From 814c35b7d71ea398b2c492a3e8313014c423b9f7 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 30 Sep 2018 18:36:22 -0700 Subject: [PATCH 001/215] [rllib] Simplify sample batch size and num envs config, n_step adjustment (#2995) * simplify vec batch requirements * Update rllib-training.rst * Update rllib-training.rst * Update rllib-training.rst * Update rllib-training.rst * Update rllib-training.rst * Update rllib-models.rst --- doc/source/rllib-models.rst | 2 +- doc/source/rllib-training.rst | 32 +++++++++++++++-- python/ray/rllib/agents/dqn/dqn.py | 4 +-- .../agents/impala/vtrace_policy_graph.py | 3 +- .../ray/rllib/evaluation/policy_evaluator.py | 34 ++++++++----------- python/ray/rllib/evaluation/sampler.py | 20 +++++------ .../ray/rllib/test/test_policy_evaluator.py | 20 ++++------- .../ray/rllib/tuned_examples/atari-a2c.yaml | 2 +- .../ray/rllib/tuned_examples/atari-apex.yaml | 2 +- .../rllib/tuned_examples/atari-impala.yaml | 2 +- .../ray/rllib/tuned_examples/atari-ppo.yaml | 2 +- .../pong-impala-vectorized.yaml | 2 +- 12 files changed, 68 insertions(+), 57 deletions(-) diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index a234ba002242..c279855ac89c 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -13,7 +13,7 @@ Built-in Models and Preprocessors RLlib picks default models based on a simple heuristic: a `vision network `__ for image observations, and a `fully connected network `__ for everything else. These models can be configured via the ``model`` config key, documented in the model `catalog `__. Note that you'll probably have to configure ``conv_filters`` if your environment observations have custom sizes, e.g., ``"model": {"dim": 42, "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]}`` for 42x42 observations. -In addition, if you set ``"model": {"use_lstm": true}``, then the model output will be further processed by a `LSTM cell `__. More generally, RLlib supports the use of recurrent models for its algorithms (A3C, PG out of the box), and RNN support is built into its policy evaluation utilities. +In addition, if you set ``"model": {"use_lstm": true}``, then the model output will be further processed by an `LSTM cell `__. More generally, RLlib supports the use of recurrent models for its policy gradient algorithms (A3C, PPO, PG, IMPALA), and RNN support is built into its policy evaluation utilities. For preprocessors, RLlib tries to pick one of its built-in preprocessors based on the environment's observation space. Discrete observations are one-hot encoded, Atari observations downscaled, and Tuple observations flattened (there isn't native tuple support yet, but you can reshape the flattened observation in a custom model). Note that for Atari, RLlib defaults to using the `DeepMind preprocessors `__, which are also used by the OpenAI baselines library. diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 25cd0d893185..6d3a142db154 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -56,7 +56,6 @@ Specifying Resources ~~~~~~~~~~~~~~~~~~~~ You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. Many agents also provide a ``num_gpus`` or ``gpu`` option. In addition, you can allocate a fraction of a GPU by setting ``gpu_fraction: f``. For example, with DQN you can pack five agents onto one GPU by setting ``gpu_fraction: 0.2``.
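To make the fractional-GPU option concrete, here is a minimal sketch of how one of those five packed DQN agents could be launched; only ``gpu_fraction`` comes from the text above, while the experiment name, env, and stopping condition are illustrative:

.. code-block:: python

    import ray
    from ray.tune import run_experiments

    ray.init()
    run_experiments({
        "dqn-packed": {  # hypothetical experiment name
            "run": "DQN",
            "env": "PongNoFrameskip-v4",
            "stop": {"training_iteration": 2},
            "config": {
                "gpu_fraction": 0.2,  # five such agents can share one GPU
            },
        },
    })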
Note that fractional GPU support requires enabling the experimental Xray backend by setting the environment variable ``RAY_USE_XRAY=1``. ->>>>>>> 01b030bd57f014386aa5e4c67a2e069938528abb Evaluating Trained Agents ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -95,7 +94,7 @@ You can run these with the ``train.py`` script as follows: Python API ---------- -The Python API provides the needed flexibility for applying RLlib to new problems. You will need to use this API if you wish to use custom environments, preprocesors, or models with RLlib. +The Python API provides the needed flexibility for applying RLlib to new problems. You will need to use this API if you wish to use `custom environments, preprocessors, or models `__ with RLlib. Here is an example of the basic usage: @@ -184,11 +183,38 @@ You can also access just the "master" copy of the agent state through ``agent.lo agent.optimizer.foreach_evaluator_with_index( lambda ev, i: ev.for_policy(lambda p: p.get_weights())) +Global Coordination +~~~~~~~~~~~~~~~~~~~ +Sometimes, it is necessary to coordinate between pieces of code that live in different processes managed by RLlib. For example, it can be useful to maintain a global average of a certain variable, or centrally control a hyperparameter used by policies. Ray provides a general way to achieve this through *named actors* (learn more about Ray actors `here `__). As an example, consider maintaining a shared global counter that is incremented by environments and read periodically from your driver program: + +.. code-block:: python + + from ray.experimental import named_actors + + @ray.remote + class Counter: + def __init__(self): + self.count = 0 + def inc(self, n): + self.count += n + def get(self): + return self.count + + # on the driver + counter = Counter.remote() + named_actors.register_actor("global_counter", counter) + print(ray.get(counter.get.remote())) # get the latest count + + # in your envs + counter = named_actors.get_actor("global_counter") + counter.inc.remote(1) # async call to increment the global count + +Ray actors provide high levels of performance, so in more complex cases they can be used to implement communication patterns such as parameter servers and allreduce. REST API -------- -In some cases (i.e., when interacting with an external environment) it makes more sense to interact with RLlib as if it were an independently running service, rather than RLlib hosting the simulations itself. This is possible via RLlib's serving env `interface `__. ..
autoclass:: ray.rllib.utils.policy_client.PolicyClient :members: diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index c945cdbc9fe8..25320fd6a982 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -135,8 +135,8 @@ def default_resource_request(cls, config): def _init(self): # Update effective batch size to include n-step - adjusted_batch_size = ( - self.config["sample_batch_size"] + self.config["n_step"] - 1) + adjusted_batch_size = max(self.config["sample_batch_size"], + self.config["n_step"]) self.config["sample_batch_size"] = adjusted_batch_size self.exploration0 = self._make_exploration_schedule(0) diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index 23f88e51f51b..f6984687166c 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -126,8 +126,7 @@ def to_batches(tensor): else: # Important: chop the tensor into batches at known episode cut # boundaries. TODO(ekl) this is kind of a hack - T = (self.config["sample_batch_size"] // - self.config["num_envs_per_worker"]) + T = self.config["sample_batch_size"] B = tf.shape(tensor)[0] // T rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index 24eb746100d5..1152aab82b49 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -124,16 +124,14 @@ def __init__(self, in each sample batch returned from this evaluator. batch_mode (str): One of the following batch modes: "truncate_episodes": Each call to sample() will return a batch - of at most `batch_steps` in size. The batch will be exactly - `batch_steps` in size if postprocessing does not change - batch sizes. Episodes may be truncated in order to meet - this size requirement. When `num_envs > 1`, episodes will - be truncated to sequences of `batch_size / num_envs` in - length. + of at most `batch_steps * num_envs` in size. The batch will + be exactly `batch_steps * num_envs` in size if + postprocessing does not change batch sizes. Episodes may be + truncated in order to meet this size requirement. "complete_episodes": Each call to sample() will return a batch - of at least `batch_steps in size. Episodes will not be - truncated, but multiple episodes may be packed within one - batch to meet the batch size. Note that when + of at least `batch_steps * num_envs` in size. Episodes will + not be truncated, but multiple episodes may be packed + within one batch to meet the batch size. Note that when `num_envs > 1`, episode steps will be buffered until the episode completes, and hence batches may contain significant amounts of off-policy data. @@ -171,7 +169,7 @@ def __init__(self, policy_mapping_fn = (policy_mapping_fn or (lambda agent_id: DEFAULT_POLICY_ID)) self.env_creator = env_creator - self.batch_steps = batch_steps + self.sample_batch_size = batch_steps * num_envs self.batch_mode = batch_mode self.compress_observations = compress_observations @@ -246,15 +244,10 @@ def make_env(vector_index): self.num_envs = num_envs if self.batch_mode == "truncate_episodes": - if batch_steps % num_envs != 0: - raise ValueError( - "In 'truncate_episodes' batch mode, `batch_steps` must be " - "evenly divisible by `num_envs`. 
Got {} and {}.".format( - batch_steps, num_envs)) - batch_steps = batch_steps // num_envs + unroll_length = batch_steps pack_episodes = True elif self.batch_mode == "complete_episodes": - batch_steps = float("inf") # never cut episodes + unroll_length = float("inf") # never cut episodes pack_episodes = False # sampler will return 1 episode per poll else: raise ValueError("Unsupported batch mode: {}".format( @@ -266,7 +259,7 @@ def make_env(vector_index): policy_mapping_fn, self.filters, clip_rewards, - batch_steps, + unroll_length, horizon=episode_horizon, pack=pack_episodes, tf_sess=self.tf_sess) @@ -278,7 +271,7 @@ def make_env(vector_index): policy_mapping_fn, self.filters, clip_rewards, - batch_steps, + unroll_length, horizon=episode_horizon, pack=pack_episodes, tf_sess=self.tf_sess) @@ -310,7 +303,8 @@ def sample(self): else: max_batches = float("inf") - while steps_so_far < self.batch_steps and len(batches) < max_batches: + while steps_so_far < self.sample_batch_size and len( + batches) < max_batches: batch = self.sampler.get_data() steps_so_far += batch.count batches.append(batch) diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index f41c3ca739e2..ec0f13c4e445 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -36,12 +36,12 @@ def __init__(self, policy_mapping_fn, obs_filters, clip_rewards, - num_local_steps, + unroll_length, horizon=None, pack=False, tf_sess=None): self.async_vector_env = AsyncVectorEnv.wrap_async(env) - self.num_local_steps = num_local_steps + self.unroll_length = unroll_length self.horizon = horizon self.policies = policies self.policy_mapping_fn = policy_mapping_fn @@ -49,7 +49,7 @@ def __init__(self, self.extra_batches = queue.Queue() self.rollout_provider = _env_runner( self.async_vector_env, self.extra_batches.put, self.policies, - self.policy_mapping_fn, self.num_local_steps, self.horizon, + self.policy_mapping_fn, self.unroll_length, self.horizon, self._obs_filters, clip_rewards, pack, tf_sess) self.metrics_queue = queue.Queue() @@ -92,7 +92,7 @@ def __init__(self, policy_mapping_fn, obs_filters, clip_rewards, - num_local_steps, + unroll_length, horizon=None, pack=False, tf_sess=None): @@ -104,7 +104,7 @@ def __init__(self, self.queue = queue.Queue(5) self.extra_batches = queue.Queue() self.metrics_queue = queue.Queue() - self.num_local_steps = num_local_steps + self.unroll_length = unroll_length self.horizon = horizon self.policies = policies self.policy_mapping_fn = policy_mapping_fn @@ -124,7 +124,7 @@ def run(self): def _run(self): rollout_provider = _env_runner( self.async_vector_env, self.extra_batches.put, self.policies, - self.policy_mapping_fn, self.num_local_steps, self.horizon, + self.policy_mapping_fn, self.unroll_length, self.horizon, self._obs_filters, self.clip_rewards, self.pack, self.tf_sess) while True: # The timeout variable exists because apparently, if one worker @@ -182,7 +182,7 @@ def _env_runner(async_vector_env, extra_batch_callback, policies, policy_mapping_fn, - num_local_steps, + unroll_length, horizon, obs_filters, clip_rewards, @@ -197,14 +197,14 @@ def _env_runner(async_vector_env, policy_mapping_fn (func): Function that maps agent ids to policy ids. This is called when an agent first enters the environment. The agent is then "bound" to the returned policy for the episode. - num_local_steps (int): Number of episode steps before `SampleBatch` is + unroll_length (int): Number of episode steps before `SampleBatch` is yielded. 
Set to infinity to yield complete episodes. horizon (int): Horizon of the episode. obs_filters (dict): Map of policy id to filter used to process observations for the policy. clip_rewards (bool): Whether to clip rewards before postprocessing. pack (bool): Whether to pack multiple episodes into each batch. This - guarantees batches will be exactly `num_local_steps` in size. + guarantees batches will be exactly `unroll_length` in size. tf_sess (Session|None): Optional tensorflow session to use for batching TF policy evaluations. @@ -306,7 +306,7 @@ def new_episode(): # or if we've exceeded the requested batch size. if episode.batch_builder.has_pending_data(): if (all_done and not pack) or \ - episode.batch_builder.count >= num_local_steps: + episode.batch_builder.count >= unroll_length: yield episode.batch_builder.build_and_reset() elif all_done: # Make sure postprocessor stays within one episode diff --git a/python/ray/rllib/test/test_policy_evaluator.py b/python/ray/rllib/test/test_policy_evaluator.py index cc189edbf6e0..c4c2baf6e18e 100644 --- a/python/ray/rllib/test/test_policy_evaluator.py +++ b/python/ray/rllib/test/test_policy_evaluator.py @@ -129,9 +129,10 @@ def testQueryEvaluators(self): "num_workers": 2, "sample_batch_size": 5 }) - results = pg.optimizer.foreach_evaluator(lambda ev: ev.batch_steps) + results = pg.optimizer.foreach_evaluator( + lambda ev: ev.sample_batch_size) results2 = pg.optimizer.foreach_evaluator_with_index( - lambda ev, i: (i, ev.batch_steps)) + lambda ev, i: (i, ev.sample_batch_size)) self.assertEqual(results, [5, 5, 5]) self.assertEqual(results2, [(0, 5), (1, 5), (2, 5)]) @@ -198,7 +199,7 @@ def testAutoVectorization(self): env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg), policy_graph=MockPolicyGraph, batch_mode="truncate_episodes", - batch_steps=16, + batch_steps=2, num_envs=8) for _ in range(8): batch = ev.sample() @@ -216,21 +217,12 @@ def testAutoVectorization(self): indices.append(env.unwrapped.config.vector_index) self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7]) - def testBatchDivisibilityCheck(self): - self.assertRaises( - ValueError, - lambda: PolicyEvaluator( - env_creator=lambda _: MockEnv(episode_length=8), - policy_graph=MockPolicyGraph, - batch_mode="truncate_episodes", - batch_steps=15, num_envs=4)) - - def testBatchesSmallerWhenVectorized(self): + def testBatchesLargerWhenVectorized(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(episode_length=8), policy_graph=MockPolicyGraph, batch_mode="truncate_episodes", - batch_steps=16, + batch_steps=4, num_envs=4) batch = ev.sample() self.assertEqual(batch.count, 16) diff --git a/python/ray/rllib/tuned_examples/atari-a2c.yaml b/python/ray/rllib/tuned_examples/atari-a2c.yaml index 89feaee5ba8b..53d1937cdfd4 100644 --- a/python/ray/rllib/tuned_examples/atari-a2c.yaml +++ b/python/ray/rllib/tuned_examples/atari-a2c.yaml @@ -9,7 +9,7 @@ atari-a2c: - SpaceInvadersNoFrameskip-v4 run: A2C config: - sample_batch_size: 100 + sample_batch_size: 20 clip_rewards: True num_workers: 5 num_envs_per_worker: 5 diff --git a/python/ray/rllib/tuned_examples/atari-apex.yaml b/python/ray/rllib/tuned_examples/atari-apex.yaml index 6e538d038998..19036a32baa1 100644 --- a/python/ray/rllib/tuned_examples/atari-apex.yaml +++ b/python/ray/rllib/tuned_examples/atari-apex.yaml @@ -28,7 +28,7 @@ apex: # APEX num_workers: 8 num_envs_per_worker: 8 - sample_batch_size: 158 + sample_batch_size: 20 train_batch_size: 512 target_network_update_freq: 50000 timesteps_per_iteration: 25000 diff --git 
a/python/ray/rllib/tuned_examples/atari-impala.yaml b/python/ray/rllib/tuned_examples/atari-impala.yaml index 85bd801ff83b..597b41987b3f 100644 --- a/python/ray/rllib/tuned_examples/atari-impala.yaml +++ b/python/ray/rllib/tuned_examples/atari-impala.yaml @@ -9,7 +9,7 @@ atari-impala: - SpaceInvadersNoFrameskip-v4 run: IMPALA config: - sample_batch_size: 250 # 50 * num_envs_per_worker + sample_batch_size: 50 train_batch_size: 500 num_workers: 32 num_envs_per_worker: 5 diff --git a/python/ray/rllib/tuned_examples/atari-ppo.yaml b/python/ray/rllib/tuned_examples/atari-ppo.yaml index 24593d6bb929..c6be6435041c 100644 --- a/python/ray/rllib/tuned_examples/atari-ppo.yaml +++ b/python/ray/rllib/tuned_examples/atari-ppo.yaml @@ -16,7 +16,7 @@ atari-ppo: vf_clip_param: 10.0 entropy_coeff: 0.01 train_batch_size: 5000 - sample_batch_size: 500 + sample_batch_size: 100 sgd_minibatch_size: 500 num_sgd_iter: 10 num_workers: 10 diff --git a/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml b/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml index 9525f4115521..b16488b443b8 100644 --- a/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml +++ b/python/ray/rllib/tuned_examples/pong-impala-vectorized.yaml @@ -5,7 +5,7 @@ pong-impala-vectorized: env: PongNoFrameskip-v4 run: IMPALA config: - sample_batch_size: 500 # 50 * num_envs_per_worker + sample_batch_size: 50 train_batch_size: 500 num_workers: 32 num_envs_per_worker: 10 From e4bea8d10effa45d1fc5b5cb897a1305950880b2 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 30 Sep 2018 18:37:55 -0700 Subject: [PATCH 002/215] [rllib] Default to truncate_episodes and add some more config validators (#2967) * update * link it * warn about truncation * fix * Update rllib-training.rst * deprecate tests failing --- python/ray/rllib/agents/ppo/ppo.py | 36 ++++++++++++------- .../ray/rllib/tuned_examples/hopper-ppo.yaml | 1 + .../tuned_examples/humanoid-ppo-gae.yaml | 1 + .../rllib/tuned_examples/humanoid-ppo.yaml | 1 + .../rllib/tuned_examples/pendulum-ppo.yaml | 2 +- .../regression_tests/cartpole-ppo.yaml | 1 + .../regression_tests/pendulum-ppo.yaml | 1 + .../rllib/tuned_examples/walker2d-ppo.yaml | 1 + test/jenkins_tests/run_multi_node_tests.sh | 8 +---- 9 files changed, 32 insertions(+), 20 deletions(-) diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index f452f789397e..d2a991929dfa 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -48,7 +48,7 @@ # Whether to allocate CPUs for workers (if > 0). "num_cpus_per_worker": 1, # Whether to rollout "complete_episodes" or "truncate_episodes" - "batch_mode": "complete_episodes", + "batch_mode": "truncate_episodes", # Which observation filter to apply to the observation "observation_filter": "MeanStdFilter", # Use the sync samples optimizer instead of the multi-gpu one @@ -80,17 +80,7 @@ def default_resource_request(cls, config): extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) def _init(self): - waste_ratio = ( - self.config["sample_batch_size"] * self.config["num_workers"] / - self.config["train_batch_size"]) - if waste_ratio > 1: - msg = ("sample_batch_size * num_workers >> train_batch_size. " - "This means that many steps will be discarded. 
Consider " "reducing sample_batch_size, or increase train_batch_size.") - if waste_ratio > 1.5: - raise ValueError(msg) - else: - print("Warning: " + msg) + self._validate_config() self.local_evaluator = self.make_local_evaluator( self.env_creator, self._policy_graph) self.remote_evaluators = self.make_remote_evaluators( @@ -114,6 +104,28 @@ def _init(self): "standardize_fields": ["advantages"], }) + def _validate_config(self): + waste_ratio = ( + self.config["sample_batch_size"] * self.config["num_workers"] / + self.config["train_batch_size"]) + if waste_ratio > 1: + msg = ("sample_batch_size * num_workers >> train_batch_size. " + "This means that many steps will be discarded. Consider " + "reducing sample_batch_size, or increasing train_batch_size.") + if waste_ratio > 1.5: + raise ValueError(msg) + else: + print("Warning: " + msg) + if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]: + raise ValueError( + "Minibatch size {} must be <= train batch size {}.".format( + self.config["sgd_minibatch_size"], + self.config["train_batch_size"])) + if (self.config["batch_mode"] == "truncate_episodes" + and not self.config["use_gae"]): + raise ValueError( + "Episode truncation is not supported without a value function") + def _train(self): prev_steps = self.optimizer.num_steps_sampled fetches = self.optimizer.step() diff --git a/python/ray/rllib/tuned_examples/hopper-ppo.yaml b/python/ray/rllib/tuned_examples/hopper-ppo.yaml index c1c75b166e7c..5082dc7921e4 100644 --- a/python/ray/rllib/tuned_examples/hopper-ppo.yaml +++ b/python/ray/rllib/tuned_examples/hopper-ppo.yaml @@ -10,3 +10,4 @@ hopper-ppo: train_batch_size: 160000 num_workers: 64 num_gpus: 4 + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml b/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml index e176dcae26c6..9473b5df7a6a 100644 --- a/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml +++ b/python/ray/rllib/tuned_examples/humanoid-ppo-gae.yaml @@ -17,3 +17,4 @@ humanoid-ppo-gae: free_log_std: true num_workers: 64 num_gpus: 4 + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/humanoid-ppo.yaml b/python/ray/rllib/tuned_examples/humanoid-ppo.yaml index 0608f8b60353..07371d16f712 100644 --- a/python/ray/rllib/tuned_examples/humanoid-ppo.yaml +++ b/python/ray/rllib/tuned_examples/humanoid-ppo.yaml @@ -15,3 +15,4 @@ humanoid-ppo: use_gae: false num_workers: 64 num_gpus: 4 + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/pendulum-ppo.yaml b/python/ray/rllib/tuned_examples/pendulum-ppo.yaml index 60df6825bd43..b8c0293a3e33 100644 --- a/python/ray/rllib/tuned_examples/pendulum-ppo.yaml +++ b/python/ray/rllib/tuned_examples/pendulum-ppo.yaml @@ -13,4 +13,4 @@ pendulum-ppo: num_sgd_iter: 10 model: fcnet_hiddens: [64, 64] - squash_to_range: True + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml index 425958e5c109..82ea5846e733 100644 --- a/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/cartpole-ppo.yaml @@ -6,3 +6,4 @@ cartpole-ppo: time_total_s: 300 config: num_workers: 1 + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml b/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml index 8b9d69fce20a..63536d3be370 100644 ---
a/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml +++ b/python/ray/rllib/tuned_examples/regression_tests/pendulum-ppo.yaml @@ -15,3 +15,4 @@ pendulum-ppo: num_sgd_iter: 10 model: fcnet_hiddens: [64, 64] + batch_mode: complete_episodes diff --git a/python/ray/rllib/tuned_examples/walker2d-ppo.yaml b/python/ray/rllib/tuned_examples/walker2d-ppo.yaml index deb5a0038dcb..9d64720a2c5b 100644 --- a/python/ray/rllib/tuned_examples/walker2d-ppo.yaml +++ b/python/ray/rllib/tuned_examples/walker2d-ppo.yaml @@ -9,3 +9,4 @@ walker2d-v1-ppo: train_batch_size: 320000 num_workers: 64 num_gpus: 4 + batch_mode: complete_episodes diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index e12eca455c0b..43815f470cff 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -58,7 +58,7 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ - --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "use_gae": false}' + --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "use_gae": false, "batch_mode": "complete_episodes"}' docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ @@ -288,12 +288,6 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/genetic_example.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ - python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py - -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ - python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py - docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/multiagent_cartpole.py --num-iters=2 From b45bed4bce94725bd4fc11c224e555c5fde7e1f4 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 1 Oct 2018 12:49:39 -0700 Subject: [PATCH 003/215] [rllib] Propagate model options correctly in ARS / ES, to action dist of PPO (#2974) * fix * fix * fix it * propagate conf to action dist * move carla example too * rr * Update policies.py * wip * lint --- doc/source/example-a3c.rst | 2 +- doc/source/example-policy-gradient.rst | 2 +- doc/source/index.rst | 2 +- doc/source/rllib-env.rst | 2 +- doc/source/rllib-models.rst | 2 +- doc/source/rllib.rst | 2 +- examples/carla/scenarios.py | 119 ------------ examples/custom_env/README | 1 - python/ray/rllib/agents/ars/ars.py | 65 ++----- python/ray/rllib/agents/ars/policies.py | 43 +---- python/ray/rllib/agents/ars/utils.py | 21 --- python/ray/rllib/agents/es/es.py | 12 +- python/ray/rllib/agents/es/policies.py | 6 +- python/ray/rllib/agents/pg/pg_policy_graph.py | 4 +- .../ray/rllib/agents/ppo/ppo_policy_graph.py | 5 +- .../ray/rllib/examples}/carla/README | 0 .../rllib/examples}/carla/a3c_lane_keep.py | 1 - .../rllib/examples}/carla/dqn_lane_keep.py | 4 - .../ray/rllib/examples}/carla/env.py | 174 +++++++++--------- .../ray/rllib/examples}/carla/models.py | 29 ++- .../rllib/examples}/carla/ppo_lane_keep.py | 5 +- python/ray/rllib/examples/carla/scenarios.py | 131 +++++++++++++ .../ray/rllib/examples}/carla/train_a3c.py | 1 - 
.../ray/rllib/examples}/carla/train_dqn.py | 19 +- .../ray/rllib/examples}/carla/train_ppo.py | 17 +- .../ray/rllib/examples}/custom_env.py | 4 +- python/ray/rllib/models/catalog.py | 17 +- python/ray/rllib/test/test_catalog.py | 5 +- .../ray/rllib/tuned_examples/swimmer-ars.yaml | 4 +- 29 files changed, 322 insertions(+), 377 deletions(-) delete mode 100644 examples/carla/scenarios.py delete mode 100644 examples/custom_env/README rename {examples => python/ray/rllib/examples}/carla/README (100%) rename {examples => python/ray/rllib/examples}/carla/a3c_lane_keep.py (96%) rename {examples => python/ray/rllib/examples}/carla/dqn_lane_keep.py (90%) rename {examples => python/ray/rllib/examples}/carla/env.py (83%) rename {examples => python/ray/rllib/examples}/carla/models.py (83%) rename {examples => python/ray/rllib/examples}/carla/ppo_lane_keep.py (93%) create mode 100644 python/ray/rllib/examples/carla/scenarios.py rename {examples => python/ray/rllib/examples}/carla/train_a3c.py (96%) rename {examples => python/ray/rllib/examples}/carla/train_dqn.py (81%) rename {examples => python/ray/rllib/examples}/carla/train_ppo.py (80%) rename {examples/custom_env => python/ray/rllib/examples}/custom_env.py (93%) diff --git a/doc/source/example-a3c.rst b/doc/source/example-a3c.rst index 665d49a36551..4a62ec61acbd 100644 --- a/doc/source/example-a3c.rst +++ b/doc/source/example-a3c.rst @@ -13,7 +13,7 @@ View the `code for this example`_. .. note:: - For an overview of Ray's reinforcement learning library, see `Ray RLlib `__. + For an overview of Ray's reinforcement learning library, see `RLlib `__. To run the application, first install **ray** and then some dependencies: diff --git a/doc/source/example-policy-gradient.rst b/doc/source/example-policy-gradient.rst index 806764560ba9..cabadfd37a9f 100644 --- a/doc/source/example-policy-gradient.rst +++ b/doc/source/example-policy-gradient.rst @@ -6,7 +6,7 @@ View the `code for this example`_. .. note:: - For an overview of Ray's reinforcement learning library, see `Ray RLlib `__. + For an overview of Ray's reinforcement learning library, see `RLlib `__. To run this example, you will need to install `TensorFlow with GPU support`_ (at diff --git a/doc/source/index.rst b/doc/source/index.rst index b71987108be0..d951066e8842 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -77,7 +77,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin .. toctree:: :maxdepth: 1 - :caption: Ray RLlib + :caption: RLlib rllib.rst rllib-training.rst diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 6de076785707..c95def692e29 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -50,7 +50,7 @@ In the above example, note that the ``env_creator`` function takes in an ``env_c OpenAI Gym ---------- -RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition `__. You may also find the `SimpleCorridor `__ and `Carla simulator `__ example env implementations useful as a reference. +RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition `__. You may also find the `SimpleCorridor `__ and `Carla simulator `__ example env implementations useful as a reference. 
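As a concrete companion to the pointer above, here is a minimal sketch of a custom Gym env in the spirit of the ``SimpleCorridor`` example; the class name, reward values, and ``corridor_length`` key are illustrative, not the exact example code:

.. code-block:: python

    import gym
    from gym.spaces import Discrete

    class SimpleCorridor(gym.Env):
        """Walk right to the end of a corridor (illustrative sketch)."""

        def __init__(self, env_config):
            self.end_pos = env_config.get("corridor_length", 10)
            self.cur_pos = 0
            self.action_space = Discrete(2)  # 0: step left, 1: step right
            self.observation_space = Discrete(self.end_pos + 1)

        def reset(self):
            self.cur_pos = 0
            return self.cur_pos

        def step(self, action):
            if action == 1:
                self.cur_pos += 1
            elif self.cur_pos > 0:
                self.cur_pos -= 1
            done = self.cur_pos >= self.end_pos
            # small per-step penalty encourages reaching the goal quickly
            return self.cur_pos, 1.0 if done else -0.1, done, {}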
Performance ~~~~~~~~~~~ diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index c279855ac89c..5b3f88cf0e36 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -46,7 +46,7 @@ Custom models should subclass the common RLlib `model class `__ and associated `training scripts `__. The ``CarlaModel`` class defined there operates over a composite (Tuple) observation space including both images and scalar measurements. +For a full example of a custom model in code, see the `Carla RLlib model `__ and associated `training scripts `__. The ``CarlaModel`` class defined there operates over a composite (Tuple) observation space including both images and scalar measurements. Custom Preprocessors -------------------- diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index ea5bbbf58381..ba011d08c45e 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -10,7 +10,7 @@ Learn more about RLlib's design by reading the `ICML paper `__ or `TensorFlow `__. Then, install the Ray RLlib module: +RLlib has extra dependencies on top of ``ray``. First, you'll need to install either `PyTorch `__ or `TensorFlow `__. Then, install the RLlib module: .. code-block:: bash diff --git a/examples/carla/scenarios.py b/examples/carla/scenarios.py deleted file mode 100644 index e6494af1830d..000000000000 --- a/examples/carla/scenarios.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Collection of Carla scenarios, including those from the CoRL 2017 paper.""" - - -TEST_WEATHERS = [0, 2, 5, 7, 9, 10, 11, 12, 13] -TRAIN_WEATHERS = [1, 3, 4, 6, 8, 14] - - -def build_scenario( - city, start, end, vehicles, pedestrians, max_steps, weathers): - return { - "city": city, - "num_vehicles": vehicles, - "num_pedestrians": pedestrians, - "weather_distribution": weathers, - "start_pos_id": start, - "end_pos_id": end, - "max_steps": max_steps, - } - - -# Simple scenario for Town02 that involves driving down a road -DEFAULT_SCENARIO = build_scenario( - city="Town02", start=36, end=40, vehicles=20, pedestrians=40, - max_steps=200, weathers=[0]) - -# Simple scenario for Town02 that involves driving down a road -LANE_KEEP = build_scenario( - city="Town02", start=36, end=40, vehicles=0, pedestrians=0, - max_steps=2000, weathers=[0]) - -# Scenarios from the CoRL2017 paper -POSES_TOWN1_STRAIGHT = [ - [36, 40], [39, 35], [110, 114], [7, 3], [0, 4], - [68, 50], [61, 59], [47, 64], [147, 90], [33, 87], - [26, 19], [80, 76], [45, 49], [55, 44], [29, 107], - [95, 104], [84, 34], [53, 67], [22, 17], [91, 148], - [20, 107], [78, 70], [95, 102], [68, 44], [45, 69]] - - -POSES_TOWN1_ONE_CURVE = [ - [138, 17], [47, 16], [26, 9], [42, 49], [140, 124], - [85, 98], [65, 133], [137, 51], [76, 66], [46, 39], - [40, 60], [0, 29], [4, 129], [121, 140], [2, 129], - [78, 44], [68, 85], [41, 102], [95, 70], [68, 129], - [84, 69], [47, 79], [110, 15], [130, 17], [0, 17]] - -POSES_TOWN1_NAV = [ - [105, 29], [27, 130], [102, 87], [132, 27], [24, 44], - [96, 26], [34, 67], [28, 1], [140, 134], [105, 9], - [148, 129], [65, 18], [21, 16], [147, 97], [42, 51], - [30, 41], [18, 107], [69, 45], [102, 95], [18, 145], - [111, 64], [79, 45], [84, 69], [73, 31], [37, 81]] - - -POSES_TOWN2_STRAIGHT = [ - [38, 34], [4, 2], [12, 10], [62, 55], [43, 47], - [64, 66], [78, 76], [59, 57], [61, 18], [35, 39], - [12, 8], [0, 18], [75, 68], [54, 60], [45, 49], - [46, 42], [53, 46], [80, 29], [65, 63], [0, 81], - [54, 63], [51, 42], [16, 19], [17, 26], [77, 68]] - -POSES_TOWN2_ONE_CURVE = [ - [37, 76], [8, 24], [60, 69], [38, 10], [21, 1], - 
[58, 71], [74, 32], [44, 0], [71, 16], [14, 24], - [34, 11], [43, 14], [75, 16], [80, 21], [3, 23], - [75, 59], [50, 47], [11, 19], [77, 34], [79, 25], - [40, 63], [58, 76], [79, 55], [16, 61], [27, 11]] - -POSES_TOWN2_NAV = [ - [19, 66], [79, 14], [19, 57], [23, 1], - [53, 76], [42, 13], [31, 71], [33, 5], - [54, 30], [10, 61], [66, 3], [27, 12], - [79, 19], [2, 29], [16, 14], [5, 57], - [70, 73], [46, 67], [57, 50], [61, 49], [21, 12], - [51, 81], [77, 68], [56, 65], [43, 54]] - -TOWN1_STRAIGHT = [ - build_scenario("Town01", start, end, 0, 0, 300, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_STRAIGHT] - -TOWN1_ONE_CURVE = [ - build_scenario("Town01", start, end, 0, 0, 600, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_ONE_CURVE] - -TOWN1_NAVIGATION = [ - build_scenario("Town01", start, end, 0, 0, 900, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_NAV] - -TOWN1_NAVIGATION_DYNAMIC = [ - build_scenario("Town01", start, end, 20, 50, 900, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_NAV] - -TOWN2_STRAIGHT = [ - build_scenario("Town02", start, end, 0, 0, 300, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_STRAIGHT] - -TOWN2_STRAIGHT_DYNAMIC = [ - build_scenario("Town02", start, end, 20, 50, 300, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_STRAIGHT] - -TOWN2_ONE_CURVE = [ - build_scenario("Town02", start, end, 0, 0, 600, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_ONE_CURVE] - -TOWN2_NAVIGATION = [ - build_scenario("Town02", start, end, 0, 0, 900, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_NAV] - -TOWN2_NAVIGATION_DYNAMIC = [ - build_scenario("Town02", start, end, 20, 50, 900, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_NAV] - -TOWN1_ALL = ( - TOWN1_STRAIGHT + TOWN1_ONE_CURVE + TOWN1_NAVIGATION + - TOWN1_NAVIGATION_DYNAMIC) - -TOWN2_ALL = ( - TOWN2_STRAIGHT + TOWN2_ONE_CURVE + TOWN2_NAVIGATION + - TOWN2_NAVIGATION_DYNAMIC) diff --git a/examples/custom_env/README b/examples/custom_env/README deleted file mode 100644 index 75ffcad88fb3..000000000000 --- a/examples/custom_env/README +++ /dev/null @@ -1 +0,0 @@ -Example of using a custom gym env with RLlib. 
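For context on the ARS/ES refactor below: after this change the policy network is selected through the common ``model`` options rather than the ARS-specific ``policy_type`` and ``fcnet_hiddens`` keys that the diff removes. A hypothetical usage sketch, with the import path and agent class assumed from this era of the codebase rather than quoted from the patch:

.. code-block:: python

    import ray
    from ray.rllib.agents.ars import ARSAgent, DEFAULT_CONFIG

    ray.init()
    config = DEFAULT_CONFIG.copy()
    config["num_workers"] = 2
    # replaces the removed top-level "fcnet_hiddens" ARS option
    config["model"] = {"fcnet_hiddens": [32, 32]}
    agent = ARSAgent(config=config, env="Pendulum-v0")
    print(agent.train()["episode_reward_mean"])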
diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index e1a945985771..5984e2e01882 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -25,19 +25,17 @@ ]) DEFAULT_CONFIG = with_common_config({ - 'noise_stdev': 0.02, # std deviation of parameter noise - 'num_rollouts': 32, # number of perturbs to try - 'rollouts_used': 32, # number of perturbs to keep in gradient estimate - 'num_workers': 2, - 'sgd_stepsize': 0.01, # sgd step-size - 'observation_filter': "MeanStdFilter", - 'noise_size': 250000000, - 'eval_prob': 0.03, # probability of evaluating the parameter rewards - 'report_length': 10, # how many of the last rewards we average over - 'env_config': {}, - 'offset': 0, - 'policy_type': "LinearPolicy", # ["LinearPolicy", "MLPPolicy"] - "fcnet_hiddens": [32, 32], # fcnet structure of MLPPolicy + "noise_stdev": 0.02, # std deviation of parameter noise + "num_rollouts": 32, # number of perturbs to try + "rollouts_used": 32, # number of perturbs to keep in gradient estimate + "num_workers": 2, + "sgd_stepsize": 0.01, # sgd step-size + "observation_filter": "MeanStdFilter", + "noise_size": 250000000, + "eval_prob": 0.03, # probability of evaluating the parameter rewards + "report_length": 10, # how many of the last rewards we average over + "env_config": {}, + "offset": 0, }) @@ -67,15 +65,9 @@ def get_delta(self, dim): @ray.remote class Worker(object): - def __init__(self, - config, - policy_params, - env_creator, - noise, - min_task_runtime=0.2): + def __init__(self, config, env_creator, noise, min_task_runtime=0.2): self.min_task_runtime = min_task_runtime self.config = config - self.policy_params = policy_params self.noise = SharedNoiseTable(noise) self.env = env_creator(config["env_config"]) @@ -83,15 +75,9 @@ def __init__(self, self.preprocessor = models.ModelCatalog.get_preprocessor(self.env) self.sess = utils.make_session(single_threaded=True) - if config["policy_type"] == "LinearPolicy": - self.policy = policies.LinearPolicy( - self.sess, self.env.action_space, self.preprocessor, - config["observation_filter"], **policy_params) - else: - self.policy = policies.MLPPolicy( - self.sess, self.env.action_space, self.preprocessor, - config["observation_filter"], config["fcnet_hiddens"], - **policy_params) + self.policy = policies.GenericPolicy( + self.sess, self.env.action_space, self.preprocessor, + config["observation_filter"], config["model"]) def rollout(self, timestep_limit, add_noise=False): rollout_rewards, rollout_length = policies.rollout( @@ -160,25 +146,14 @@ def default_resource_request(cls, config): return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"]) def _init(self): - policy_params = {"action_noise_std": 0.0} - - # register the linear network - utils.register_linear_network() - env = self.env_creator(self.config["env_config"]) from ray.rllib import models preprocessor = models.ModelCatalog.get_preprocessor(env) self.sess = utils.make_session(single_threaded=False) - if self.config["policy_type"] == "LinearPolicy": - self.policy = policies.LinearPolicy( - self.sess, env.action_space, preprocessor, - self.config["observation_filter"], **policy_params) - else: - self.policy = policies.MLPPolicy( - self.sess, env.action_space, preprocessor, - self.config["observation_filter"], - self.config["fcnet_hiddens"], **policy_params) + self.policy = policies.GenericPolicy( + self.sess, env.action_space, preprocessor, + self.config["observation_filter"], self.config["model"]) self.optimizer = 
optimizers.SGD(self.policy, self.config["sgd_stepsize"]) @@ -194,8 +169,8 @@ def _init(self): # Create the actors. print("Creating actors.") self.workers = [ - Worker.remote(self.config, policy_params, self.env_creator, - noise_id) for _ in range(self.config["num_workers"]) + Worker.remote(self.config, self.env_creator, noise_id) + for _ in range(self.config["num_workers"]) ] self.episodes_so_far = 0 diff --git a/python/ray/rllib/agents/ars/policies.py b/python/ray/rllib/agents/ars/policies.py index 3a25d68eb6b3..6c8bd9273801 100644 --- a/python/ray/rllib/agents/ars/policies.py +++ b/python/ray/rllib/agents/ars/policies.py @@ -11,7 +11,6 @@ import ray from ray.rllib.utils.filter import get_filter -from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.models import ModelCatalog @@ -59,14 +58,8 @@ def __init__(self, action_space, preprocessor, observation_filter, - action_noise_std, - options={}): - - if len(preprocessor.shape) > 1: - raise UnsupportedSpaceException( - "Observation space {} is not supported with ARS.".format( - preprocessor.shape)) - + model_config, + action_noise_std=0.0): self.sess = sess self.action_space = action_space self.action_noise_std = action_noise_std @@ -78,9 +71,9 @@ def __init__(self, # Policy network. dist_class, dist_dim = ModelCatalog.get_action_dist( - action_space, dist_type="deterministic") + action_space, model_config, dist_type="deterministic") - model = ModelCatalog.get_model(self.inputs, dist_dim, options=options) + model = ModelCatalog.get_model(self.inputs, dist_dim, model_config) dist = dist_class(model.outputs) self.sampler = dist.sample() @@ -106,31 +99,3 @@ def set_weights(self, x): def get_weights(self): return self.variables.get_flat() - - -class LinearPolicy(GenericPolicy): - def __init__(self, sess, action_space, preprocessor, observation_filter, - action_noise_std): - options = {"custom_model": "LinearNetwork"} - GenericPolicy.__init__( - self, - sess, - action_space, - preprocessor, - observation_filter, - action_noise_std, - options=options) - - -class MLPPolicy(GenericPolicy): - def __init__(self, sess, action_space, preprocessor, observation_filter, - fcnet_hiddens, action_noise_std): - options = {"fcnet_hiddens": fcnet_hiddens} - GenericPolicy.__init__( - self, - sess, - action_space, - preprocessor, - observation_filter, - action_noise_std, - options=options) diff --git a/python/ray/rllib/agents/ars/utils.py b/python/ray/rllib/agents/ars/utils.py index a70dd97bb61a..1575e46c3837 100644 --- a/python/ray/rllib/agents/ars/utils.py +++ b/python/ray/rllib/agents/ars/utils.py @@ -7,9 +7,6 @@ import numpy as np import tensorflow as tf -from ray.rllib.models import ModelCatalog, Model -import tensorflow.contrib.slim as slim -from ray.rllib.models.misc import normc_initializer def compute_ranks(x): @@ -62,21 +59,3 @@ def batched_weighted_sum(weights, vecs, batch_size): np.asarray(batch_vecs, dtype=np.float32)) num_items_summed += len(batch_weights) return total, num_items_summed - - -class LinearNetwork(Model): - """Generic linear network.""" - - def _build_layers(self, inputs, num_outputs, _): - with tf.name_scope("linear"): - output = slim.fully_connected( - inputs, - num_outputs, - weights_initializer=normc_initializer(0.01), - activation_fn=None, - ) - return output, inputs - - -def register_linear_network(): - ModelCatalog.register_custom_model("LinearNetwork", LinearNetwork) diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index 1ce219b7c0ab..392f98f1d8f2 100644 --- 
a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -10,7 +10,7 @@ import time import ray -from ray.rllib.agents import Agent +from ray.rllib.agents import Agent, with_common_config from ray.tune.trial import Resources from ray.rllib.agents.es import optimizers @@ -24,7 +24,7 @@ "eval_returns", "eval_lengths" ]) -DEFAULT_CONFIG = { +DEFAULT_CONFIG = with_common_config({ "l2_coeff": 0.005, "noise_stdev": 0.02, "episodes_per_batch": 1000, @@ -38,7 +38,8 @@ "report_length": 10, "env": None, "env_config": {}, -} + "model": {}, +}) @ray.remote @@ -81,7 +82,7 @@ def __init__(self, self.sess = utils.make_session(single_threaded=True) self.policy = policies.GenericPolicy( self.sess, self.env.action_space, self.preprocessor, - config["observation_filter"], **policy_params) + config["observation_filter"], config["model"], **policy_params) def rollout(self, timestep_limit, add_noise=True): rollout_rewards, rollout_length = policies.rollout( @@ -161,7 +162,8 @@ def _init(self): self.sess = utils.make_session(single_threaded=False) self.policy = policies.GenericPolicy( self.sess, env.action_space, preprocessor, - self.config["observation_filter"], **policy_params) + self.config["observation_filter"], self.config["model"], + **policy_params) self.optimizer = optimizers.Adam(self.policy, self.config["stepsize"]) self.report_length = self.config["report_length"] diff --git a/python/ray/rllib/agents/es/policies.py b/python/ray/rllib/agents/es/policies.py index d62fee43c4c5..b40f2db5698c 100644 --- a/python/ray/rllib/agents/es/policies.py +++ b/python/ray/rllib/agents/es/policies.py @@ -39,7 +39,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False): class GenericPolicy(object): def __init__(self, sess, action_space, preprocessor, observation_filter, - action_noise_std): + model_options, action_noise_std): self.sess = sess self.action_space = action_space self.action_noise_std = action_noise_std @@ -51,8 +51,8 @@ def __init__(self, sess, action_space, preprocessor, observation_filter, # Policy network. 
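# after this change, model_options configures both the action distribution and the model built below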
dist_class, dist_dim = ModelCatalog.get_action_dist( - self.action_space, dist_type="deterministic") - model = ModelCatalog.get_model(self.inputs, dist_dim) + self.action_space, model_options, dist_type="deterministic") + model = ModelCatalog.get_model(self.inputs, dist_dim, model_options) dist = dist_class(model.outputs) self.sampler = dist.sample() diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py index bb831c47d4ee..7cdb8532b7c2 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy_graph.py @@ -24,8 +24,8 @@ def __init__(self, obs_space, action_space, config): obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape)) dist_class, self.logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) - self.model = ModelCatalog.get_model( - obs, self.logit_dim, options=self.config["model"]) + self.model = ModelCatalog.get_model(obs, self.logit_dim, + self.config["model"]) action_dist = dist_class(self.model.outputs) # logit for each action # Setup policy loss diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index e6fc90d1ce94..9456ebe944cc 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -54,7 +54,7 @@ def __init__(self, vf_loss_coeff (float): Coefficient of the value function loss use_gae (bool): If true, use the Generalized Advantage Estimator. """ - dist_cls, _ = ModelCatalog.get_action_dist(action_space) + dist_cls, _ = ModelCatalog.get_action_dist(action_space, {}) prev_dist = dist_cls(logits) # Make loss functions. logp_ratio = tf.exp( @@ -108,7 +108,8 @@ def __init__(self, self.config = config self.kl_coeff_val = self.config["kl_coeff"] self.kl_target = self.config["kl_target"] - dist_cls, logit_dim = ModelCatalog.get_action_dist(action_space) + dist_cls, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"]) if existing_inputs: obs_ph, value_targets_ph, adv_ph, act_ph, \ diff --git a/examples/carla/README b/python/ray/rllib/examples/carla/README similarity index 100% rename from examples/carla/README rename to python/ray/rllib/examples/carla/README diff --git a/examples/carla/a3c_lane_keep.py b/python/ray/rllib/examples/carla/a3c_lane_keep.py similarity index 96% rename from examples/carla/a3c_lane_keep.py rename to python/ray/rllib/examples/carla/a3c_lane_keep.py index 1338736d23f5..9629808ba4c7 100644 --- a/examples/carla/a3c_lane_keep.py +++ b/python/ray/rllib/examples/carla/a3c_lane_keep.py @@ -31,7 +31,6 @@ "carla-a3c": { "run": "A3C", "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, "config": { "env_config": env_config, "model": { diff --git a/examples/carla/dqn_lane_keep.py b/python/ray/rllib/examples/carla/dqn_lane_keep.py similarity index 90% rename from examples/carla/dqn_lane_keep.py rename to python/ray/rllib/examples/carla/dqn_lane_keep.py index 2746a1c4bbd8..84fed98cd5f9 100644 --- a/examples/carla/dqn_lane_keep.py +++ b/python/ray/rllib/examples/carla/dqn_lane_keep.py @@ -31,7 +31,6 @@ "carla-dqn": { "run": "DQN", "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, "config": { "env_config": env_config, "model": { @@ -49,9 +48,6 @@ "learning_starts": 1000, "schedule_max_timesteps": 100000, "gamma": 0.8, - "tf_session_args": { - "gpu_options": {"allow_growth": True}, - }, }, }, }) diff --git a/examples/carla/env.py b/python/ray/rllib/examples/carla/env.py 
similarity index 83% rename from examples/carla/env.py rename to python/ray/rllib/examples/carla/env.py index c88a71b28f51..af5b619afcdb 100644 --- a/examples/carla/env.py +++ b/python/ray/rllib/examples/carla/env.py @@ -33,8 +33,8 @@ os.makedirs(CARLA_OUT_PATH) # Set this to the path of your Carla binary -SERVER_BINARY = os.environ.get( - "CARLA_SERVER", os.path.expanduser("~/CARLA_0.7.0/CarlaUE4.sh")) +SERVER_BINARY = os.environ.get("CARLA_SERVER", + os.path.expanduser("~/CARLA_0.7.0/CarlaUE4.sh")) assert os.path.exists(SERVER_BINARY) if "CARLA_PY_PATH" in os.environ: @@ -97,7 +97,6 @@ "squash_action_logits": False, } - DISCRETE_ACTIONS = { # coast 0: [0.0, 0.0], @@ -119,7 +118,6 @@ 8: [-0.5, 0.5], } - live_carla_processes = set() @@ -133,7 +131,6 @@ def cleanup(): class CarlaEnv(gym.Env): - def __init__(self, config=ENV_CONFIG): self.config = config self.city = self.config["server_map"].split("/")[-1] @@ -143,21 +140,27 @@ def __init__(self, config=ENV_CONFIG): if config["discrete_actions"]: self.action_space = Discrete(len(DISCRETE_ACTIONS)) else: - self.action_space = Box(-1.0, 1.0, shape=(2,), dtype=np.float32) + self.action_space = Box(-1.0, 1.0, shape=(2, ), dtype=np.float32) if config["use_depth_camera"]: image_space = Box( - -1.0, 1.0, shape=( - config["y_res"], config["x_res"], - 1 * config["framestack"]), dtype=np.float32) + -1.0, + 1.0, + shape=(config["y_res"], config["x_res"], + 1 * config["framestack"]), + dtype=np.float32) else: image_space = Box( - 0, 255, shape=( - config["y_res"], config["x_res"], - 3 * config["framestack"]), dtype=np.uint8) + 0, + 255, + shape=(config["y_res"], config["x_res"], + 3 * config["framestack"]), + dtype=np.uint8) self.observation_space = Tuple( # forward_speed, dist to goal - [image_space, - Discrete(len(COMMANDS_ENUM)), # next_command - Box(-128.0, 128.0, shape=(2,), dtype=np.float32)]) + [ + image_space, + Discrete(len(COMMANDS_ENUM)), # next_command + Box(-128.0, 128.0, shape=(2, ), dtype=np.float32) + ]) # TODO(ekl) this isn't really a proper gym spec self._spec = lambda: None @@ -185,11 +188,13 @@ def init_server(self): # Create a new server process and start the client. 
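# a random high port is picked, presumably so multiple Carla server instances can coexist on one machine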
self.server_port = random.randint(10000, 60000) self.server_process = subprocess.Popen( - [SERVER_BINARY, self.config["server_map"], - "-windowed", "-ResX=400", "-ResY=300", - "-carla-server", - "-carla-world-port={}".format(self.server_port)], - preexec_fn=os.setsid, stdout=open(os.devnull, "w")) + [ + SERVER_BINARY, self.config["server_map"], "-windowed", + "-ResX=400", "-ResY=300", "-carla-server", + "-carla-world-port={}".format(self.server_port) + ], + preexec_fn=os.setsid, + stdout=open(os.devnull, "w")) live_carla_processes.add(os.getpgid(self.server_process.pid)) for i in range(RETRIES_ON_ERROR): @@ -257,14 +262,14 @@ def _reset(self): if self.config["use_depth_camera"]: camera1 = Camera("CameraDepth", PostProcessing="Depth") - camera1.set_image_size( - self.config["render_x_res"], self.config["render_y_res"]) + camera1.set_image_size(self.config["render_x_res"], + self.config["render_y_res"]) camera1.set_position(30, 0, 130) settings.add_sensor(camera1) camera2 = Camera("CameraRGB") - camera2.set_image_size( - self.config["render_x_res"], self.config["render_y_res"]) + camera2.set_image_size(self.config["render_x_res"], + self.config["render_y_res"]) camera2.set_position(30, 0, 130) settings.add_sensor(camera2) @@ -274,13 +279,14 @@ def _reset(self): self.start_pos = positions[self.scenario["start_pos_id"]] self.end_pos = positions[self.scenario["end_pos_id"]] self.start_coord = [ - self.start_pos.location.x // 100, self.start_pos.location.y // 100] + self.start_pos.location.x // 100, self.start_pos.location.y // 100 + ] self.end_coord = [ - self.end_pos.location.x // 100, self.end_pos.location.y // 100] - print( - "Start pos {} ({}), end {} ({})".format( - self.scenario["start_pos_id"], self.start_coord, - self.scenario["end_pos_id"], self.end_coord)) + self.end_pos.location.x // 100, self.end_pos.location.y // 100 + ] + print("Start pos {} ({}), end {} ({})".format( + self.scenario["start_pos_id"], self.start_coord, + self.scenario["end_pos_id"], self.end_coord)) # Notify the server that we want to start the episode at the # player_start index. 
This function blocks until the server is ready @@ -300,11 +306,10 @@ def encode_obs(self, image, py_measurements): prev_image = image if self.config["framestack"] == 2: image = np.concatenate([prev_image, image], axis=2) - obs = ( - image, - COMMAND_ORDINAL[py_measurements["next_command"]], - [py_measurements["forward_speed"], - py_measurements["distance_to_goal"]]) + obs = (image, COMMAND_ORDINAL[py_measurements["next_command"]], [ + py_measurements["forward_speed"], + py_measurements["distance_to_goal"] + ]) self.last_obs = obs return obs @@ -313,9 +318,8 @@ def step(self, action): obs = self._step(action) return obs except Exception: - print( - "Error during step, terminating episode early", - traceback.format_exc()) + print("Error during step, terminating episode early", + traceback.format_exc()) self.clear_server_state() return (self.last_obs, 0.0, True, {}) @@ -336,12 +340,14 @@ def _step(self, action): hand_brake = False if self.config["verbose"]: - print( - "steer", steer, "throttle", throttle, "brake", brake, - "reverse", reverse) + print("steer", steer, "throttle", throttle, "brake", brake, + "reverse", reverse) self.client.send_control( - steer=steer, throttle=throttle, brake=brake, hand_brake=hand_brake, + steer=steer, + throttle=throttle, + brake=brake, + hand_brake=hand_brake, reverse=reverse) # Process observations @@ -359,15 +365,14 @@ def _step(self, action): "reverse": reverse, "hand_brake": hand_brake, } - reward = compute_reward( - self, self.prev_measurement, py_measurements) + reward = compute_reward(self, self.prev_measurement, py_measurements) self.total_reward += reward py_measurements["reward"] = reward py_measurements["total_reward"] = self.total_reward - done = (self.num_steps > self.scenario["max_steps"] or - py_measurements["next_command"] == "REACH_GOAL" or - (self.config["early_terminate_on_collision"] and - collided_done(py_measurements))) + done = (self.num_steps > self.scenario["max_steps"] + or py_measurements["next_command"] == "REACH_GOAL" + or (self.config["early_terminate_on_collision"] + and collided_done(py_measurements))) py_measurements["done"] = done self.prev_measurement = py_measurements @@ -377,8 +382,7 @@ def _step(self, action): self.measurements_file = open( os.path.join( CARLA_OUT_PATH, - "measurements_{}.json".format(self.episode_id)), - "w") + "measurements_{}.json".format(self.episode_id)), "w") self.measurements_file.write(json.dumps(py_measurements)) self.measurements_file.write("\n") if done: @@ -389,9 +393,8 @@ def _step(self, action): self.num_steps += 1 image = self.preprocess_image(image) - return ( - self.encode_obs(image, py_measurements), reward, done, - py_measurements) + return (self.encode_obs(image, py_measurements), reward, done, + py_measurements) def images_to_video(self): videos_dir = os.path.join(CARLA_OUT_PATH, "Videos") @@ -413,15 +416,15 @@ def preprocess_image(self, image): if self.config["use_depth_camera"]: assert self.config["use_depth_camera"] data = (image.data - 0.5) * 2 - data = data.reshape( - self.config["render_y_res"], self.config["render_x_res"], 1) + data = data.reshape(self.config["render_y_res"], + self.config["render_x_res"], 1) data = cv2.resize( data, (self.config["x_res"], self.config["y_res"]), interpolation=cv2.INTER_AREA) data = np.expand_dims(data, 2) else: - data = image.data.reshape( - self.config["render_y_res"], self.config["render_x_res"], 3) + data = image.data.reshape(self.config["render_y_res"], + self.config["render_x_res"], 3) data = cv2.resize( data, (self.config["x_res"], 
self.config["y_res"]), interpolation=cv2.INTER_AREA) @@ -448,36 +451,39 @@ def _read_observation(self): cur = measurements.player_measurements if self.config["enable_planner"]: - next_command = COMMANDS_ENUM[ - self.planner.get_next_command( - [cur.transform.location.x, cur.transform.location.y, - GROUND_Z], - [cur.transform.orientation.x, cur.transform.orientation.y, - GROUND_Z], - [self.end_pos.location.x, self.end_pos.location.y, - GROUND_Z], - [self.end_pos.orientation.x, self.end_pos.orientation.y, - GROUND_Z]) - ] + next_command = COMMANDS_ENUM[self.planner.get_next_command( + [cur.transform.location.x, cur.transform.location.y, GROUND_Z], + [ + cur.transform.orientation.x, cur.transform.orientation.y, + GROUND_Z + ], + [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [ + self.end_pos.orientation.x, self.end_pos.orientation.y, + GROUND_Z + ])] else: next_command = "LANE_FOLLOW" if next_command == "REACH_GOAL": distance_to_goal = 0.0 # avoids crash in planner elif self.config["enable_planner"]: - distance_to_goal = self.planner.get_shortest_path_distance( - [cur.transform.location.x, cur.transform.location.y, GROUND_Z], - [cur.transform.orientation.x, cur.transform.orientation.y, - GROUND_Z], - [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], - [self.end_pos.orientation.x, self.end_pos.orientation.y, - GROUND_Z]) / 100 + distance_to_goal = self.planner.get_shortest_path_distance([ + cur.transform.location.x, cur.transform.location.y, GROUND_Z + ], [ + cur.transform.orientation.x, cur.transform.orientation.y, + GROUND_Z + ], [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [ + self.end_pos.orientation.x, self.end_pos.orientation.y, + GROUND_Z + ]) / 100 else: distance_to_goal = -1 - distance_to_goal_euclidean = float(np.linalg.norm( - [cur.transform.location.x - self.end_pos.location.x, - cur.transform.location.y - self.end_pos.location.y]) / 100) + distance_to_goal_euclidean = float( + np.linalg.norm([ + cur.transform.location.x - self.end_pos.location.x, + cur.transform.location.y - self.end_pos.location.y + ]) / 100) py_measurements = { "episode_id": self.episode_id, @@ -513,8 +519,8 @@ def _read_observation(self): if not os.path.exists(out_dir): os.makedirs(out_dir) out_file = os.path.join( - out_dir, - "{}_{:>04}.jpg".format(self.episode_id, self.num_steps)) + out_dir, "{}_{:>04}.jpg".format(self.episode_id, + self.num_steps)) scipy.misc.imsave(out_file, image.data) assert observation is not None, sensor_data @@ -621,8 +627,7 @@ def compute_reward_lane_keep(env, prev, current): def compute_reward(env, prev, current): - return REWARD_FUNCTIONS[env.config["reward_function"]]( - env, prev, current) + return REWARD_FUNCTIONS[env.config["reward_function"]](env, prev, current) def print_measurements(measurements): @@ -654,9 +659,8 @@ def sigmoid(x): def collided_done(py_measurements): m = py_measurements - collided = ( - m["collision_vehicles"] > 0 or m["collision_pedestrians"] > 0 or - m["collision_other"] > 0) + collided = (m["collision_vehicles"] > 0 or m["collision_pedestrians"] > 0 + or m["collision_other"] > 0) return bool(collided or m["total_reward"] < -100) diff --git a/examples/carla/models.py b/python/ray/rllib/examples/carla/models.py similarity index 83% rename from examples/carla/models.py rename to python/ray/rllib/examples/carla/models.py index 9233c9c8ed2b..fd20cd0c000c 100644 --- a/examples/carla/models.py +++ b/python/ray/rllib/examples/carla/models.py @@ -43,8 +43,8 @@ def _build_layers(self, inputs, num_outputs, options): 
(inputs.shape.as_list()[1:], expected_shape) # Reshape the input vector back into its components - vision_in = tf.reshape( - inputs[:, :image_size], [tf.shape(inputs)[0]] + image_shape) + vision_in = tf.reshape(inputs[:, :image_size], + [tf.shape(inputs)[0]] + image_shape) metrics_in = inputs[:, image_size:] print("Vision in shape", vision_in) print("Metrics in shape", metrics_in) @@ -53,18 +53,26 @@ def _build_layers(self, inputs, num_outputs, options): with tf.name_scope("carla_vision"): for i, (out_size, kernel, stride) in enumerate(convs[:-1], 1): vision_in = slim.conv2d( - vision_in, out_size, kernel, stride, + vision_in, + out_size, + kernel, + stride, scope="conv{}".format(i)) out_size, kernel, stride = convs[-1] vision_in = slim.conv2d( - vision_in, out_size, kernel, stride, - padding="VALID", scope="conv_out") + vision_in, + out_size, + kernel, + stride, + padding="VALID", + scope="conv_out") vision_in = tf.squeeze(vision_in, [1, 2]) # Setup metrics layer with tf.name_scope("carla_metrics"): metrics_in = slim.fully_connected( - metrics_in, 64, + metrics_in, + 64, weights_initializer=xavier_initializer(), activation_fn=activation, scope="metrics_out") @@ -79,15 +87,18 @@ def _build_layers(self, inputs, num_outputs, options): print("Shape of concatenated out is", last_layer.shape) for size in hiddens: last_layer = slim.fully_connected( - last_layer, size, + last_layer, + size, weights_initializer=xavier_initializer(), activation_fn=activation, scope="fc{}".format(i)) i += 1 output = slim.fully_connected( - last_layer, num_outputs, + last_layer, + num_outputs, weights_initializer=normc_initializer(0.01), - activation_fn=None, scope="fc_out") + activation_fn=None, + scope="fc_out") return output, last_layer diff --git a/examples/carla/ppo_lane_keep.py b/python/ray/rllib/examples/carla/ppo_lane_keep.py similarity index 93% rename from examples/carla/ppo_lane_keep.py rename to python/ray/rllib/examples/carla/ppo_lane_keep.py index 25e5acbf328c..ac0f6ff8aff0 100644 --- a/examples/carla/ppo_lane_keep.py +++ b/python/ray/rllib/examples/carla/ppo_lane_keep.py @@ -31,7 +31,6 @@ "carla-ppo": { "run": "PPO", "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, "config": { "env_config": env_config, "model": { @@ -55,7 +54,9 @@ "sgd_batchsize": 32, "devices": ["/gpu:0"], "tf_session_args": { - "gpu_options": {"allow_growth": True} + "gpu_options": { + "allow_growth": True + } } }, }, diff --git a/python/ray/rllib/examples/carla/scenarios.py b/python/ray/rllib/examples/carla/scenarios.py new file mode 100644 index 000000000000..beedd2989d5c --- /dev/null +++ b/python/ray/rllib/examples/carla/scenarios.py @@ -0,0 +1,131 @@ +"""Collection of Carla scenarios, including those from the CoRL 2017 paper.""" + +TEST_WEATHERS = [0, 2, 5, 7, 9, 10, 11, 12, 13] +TRAIN_WEATHERS = [1, 3, 4, 6, 8, 14] + + +def build_scenario(city, start, end, vehicles, pedestrians, max_steps, + weathers): + return { + "city": city, + "num_vehicles": vehicles, + "num_pedestrians": pedestrians, + "weather_distribution": weathers, + "start_pos_id": start, + "end_pos_id": end, + "max_steps": max_steps, + } + + +# Simple scenario for Town02 that involves driving down a road +DEFAULT_SCENARIO = build_scenario( + city="Town02", + start=36, + end=40, + vehicles=20, + pedestrians=40, + max_steps=200, + weathers=[0]) + +# Simple scenario for Town02 that involves driving down a road +LANE_KEEP = build_scenario( + city="Town02", + start=36, + end=40, + vehicles=0, + pedestrians=0, + max_steps=2000, + weathers=[0]) + +# Scenarios 
from the CoRL2017 paper +POSES_TOWN1_STRAIGHT = [[36, 40], [39, 35], [110, 114], [7, 3], [0, 4], [ + 68, 50 +], [61, 59], [47, 64], [147, 90], [33, 87], [26, 19], [80, 76], [45, 49], [ + 55, 44 +], [29, 107], [95, 104], [84, 34], [53, 67], [22, 17], [91, 148], [20, 107], + [78, 70], [95, 102], [68, 44], [45, 69]] + +POSES_TOWN1_ONE_CURVE = [[138, 17], [47, 16], [26, 9], [42, 49], [140, 124], [ + 85, 98 +], [65, 133], [137, 51], [76, 66], [46, 39], [40, 60], [0, 29], [4, 129], [ + 121, 140 +], [2, 129], [78, 44], [68, 85], [41, 102], [95, 70], [68, 129], [84, 69], + [47, 79], [110, 15], [130, 17], [0, 17]] + +POSES_TOWN1_NAV = [[105, 29], [27, 130], [102, 87], [132, 27], [24, 44], [ + 96, 26 +], [34, 67], [28, 1], [140, 134], [105, 9], [148, 129], [65, 18], [21, 16], [ + 147, 97 +], [42, 51], [30, 41], [18, 107], [69, 45], [102, 95], [18, 145], [111, 64], + [79, 45], [84, 69], [73, 31], [37, 81]] + +POSES_TOWN2_STRAIGHT = [[38, 34], [4, 2], [12, 10], [62, 55], [43, 47], [ + 64, 66 +], [78, 76], [59, 57], [61, 18], [35, 39], [12, 8], [0, 18], [75, 68], [ + 54, 60 +], [45, 49], [46, 42], [53, 46], [80, 29], [65, 63], [0, 81], [54, 63], + [51, 42], [16, 19], [17, 26], [77, 68]] + +POSES_TOWN2_ONE_CURVE = [[37, 76], [8, 24], [60, 69], [38, 10], [21, 1], [ + 58, 71 +], [74, 32], [44, 0], [71, 16], [14, 24], [34, 11], [43, 14], [75, 16], [ + 80, 21 +], [3, 23], [75, 59], [50, 47], [11, 19], [77, 34], [79, 25], [40, 63], + [58, 76], [79, 55], [16, 61], [27, 11]] + +POSES_TOWN2_NAV = [[19, 66], [79, 14], [19, 57], [23, 1], [53, 76], [42, 13], [ + 31, 71 +], [33, 5], [54, 30], [10, 61], [66, 3], [27, 12], [79, 19], [2, 29], [16, 14], + [5, 57], [70, 73], [46, 67], [57, 50], [61, 49], [21, 12], + [51, 81], [77, 68], [56, 65], [43, 54]] + +TOWN1_STRAIGHT = [ + build_scenario("Town01", start, end, 0, 0, 300, TEST_WEATHERS) + for (start, end) in POSES_TOWN1_STRAIGHT +] + +TOWN1_ONE_CURVE = [ + build_scenario("Town01", start, end, 0, 0, 600, TEST_WEATHERS) + for (start, end) in POSES_TOWN1_ONE_CURVE +] + +TOWN1_NAVIGATION = [ + build_scenario("Town01", start, end, 0, 0, 900, TEST_WEATHERS) + for (start, end) in POSES_TOWN1_NAV +] + +TOWN1_NAVIGATION_DYNAMIC = [ + build_scenario("Town01", start, end, 20, 50, 900, TEST_WEATHERS) + for (start, end) in POSES_TOWN1_NAV +] + +TOWN2_STRAIGHT = [ + build_scenario("Town02", start, end, 0, 0, 300, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_STRAIGHT +] + +TOWN2_STRAIGHT_DYNAMIC = [ + build_scenario("Town02", start, end, 20, 50, 300, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_STRAIGHT +] + +TOWN2_ONE_CURVE = [ + build_scenario("Town02", start, end, 0, 0, 600, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_ONE_CURVE +] + +TOWN2_NAVIGATION = [ + build_scenario("Town02", start, end, 0, 0, 900, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_NAV +] + +TOWN2_NAVIGATION_DYNAMIC = [ + build_scenario("Town02", start, end, 20, 50, 900, TRAIN_WEATHERS) + for (start, end) in POSES_TOWN2_NAV +] + +TOWN1_ALL = (TOWN1_STRAIGHT + TOWN1_ONE_CURVE + TOWN1_NAVIGATION + + TOWN1_NAVIGATION_DYNAMIC) + +TOWN2_ALL = (TOWN2_STRAIGHT + TOWN2_ONE_CURVE + TOWN2_NAVIGATION + + TOWN2_NAVIGATION_DYNAMIC) diff --git a/examples/carla/train_a3c.py b/python/ray/rllib/examples/carla/train_a3c.py similarity index 96% rename from examples/carla/train_a3c.py rename to python/ray/rllib/examples/carla/train_a3c.py index 75856aef266e..2c12cd8245cf 100644 --- a/examples/carla/train_a3c.py +++ b/python/ray/rllib/examples/carla/train_a3c.py @@ -32,7 +32,6 @@ "carla-a3c": { "run": "A3C", 
"env": "carla_env", - "trial_resources": {"cpu": 5, "extra_gpu": 2}, "config": { "env_config": env_config, "use_gpu_for_workers": True, diff --git a/examples/carla/train_dqn.py b/python/ray/rllib/examples/carla/train_dqn.py similarity index 81% rename from examples/carla/train_dqn.py rename to python/ray/rllib/examples/carla/train_dqn.py index 6180ca48f0dd..fa2dba1053aa 100644 --- a/examples/carla/train_dqn.py +++ b/python/ray/rllib/examples/carla/train_dqn.py @@ -25,21 +25,26 @@ register_carla_model() ray.init() + + +def shape_out(spec): + return (spec.config.env_config.framestack * + (spec.config.env_config.use_depth_camera and 1 or 3)) + + run_experiments({ "carla-dqn": { "run": "DQN", "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, "config": { "env_config": env_config, "model": { "custom_model": "carla", "custom_options": { "image_shape": [ - 80, 80, - lambda spec: spec.config.env_config.framestack * ( - spec.config.env_config.use_depth_camera and 1 or 3 - ), + 80, + 80, + shape_out, ], }, "conv_filters": [ @@ -53,7 +58,9 @@ "schedule_max_timesteps": 100000, "gamma": 0.8, "tf_session_args": { - "gpu_options": {"allow_growth": True}, + "gpu_options": { + "allow_growth": True + }, }, }, }, diff --git a/examples/carla/train_ppo.py b/python/ray/rllib/examples/carla/train_ppo.py similarity index 80% rename from examples/carla/train_ppo.py rename to python/ray/rllib/examples/carla/train_ppo.py index 4f3ebf5eab83..a9339ca79481 100644 --- a/examples/carla/train_ppo.py +++ b/python/ray/rllib/examples/carla/train_ppo.py @@ -28,14 +28,14 @@ "carla": { "run": "PPO", "env": "carla_env", - "trial_resources": {"cpu": 4, "gpu": 1}, "config": { "env_config": env_config, "model": { "custom_model": "carla", "custom_options": { "image_shape": [ - env_config["x_res"], env_config["y_res"], 6], + env_config["x_res"], env_config["y_res"], 6 + ], }, "conv_filters": [ [16, [8, 8], 4], @@ -44,17 +44,14 @@ ], }, "num_workers": 1, - "timesteps_per_batch": 2000, - "min_steps_per_task": 100, + "train_batch_size": 2000, + "sample_batch_size": 100, "lambda": 0.95, "clip_param": 0.2, "num_sgd_iter": 20, - "sgd_stepsize": 0.0001, - "sgd_batchsize": 32, - "devices": ["/gpu:0"], - "tf_session_args": { - "gpu_options": {"allow_growth": True} - } + "lr": 0.0001, + "sgd_minibatch_size": 32, + "num_gpus": 1, }, }, }) diff --git a/examples/custom_env/custom_env.py b/python/ray/rllib/examples/custom_env.py similarity index 93% rename from examples/custom_env/custom_env.py rename to python/ray/rllib/examples/custom_env.py index b5a3240eaad0..66c0288081f9 100644 --- a/examples/custom_env/custom_env.py +++ b/python/ray/rllib/examples/custom_env.py @@ -24,7 +24,7 @@ def __init__(self, config): self.cur_pos = 0 self.action_space = Discrete(2) self.observation_space = Box( - 0.0, self.end_pos, shape=(1,), dtype=np.float32) + 0.0, self.end_pos, shape=(1, ), dtype=np.float32) self._spec = EnvSpec("SimpleCorridor-{}-v0".format(self.end_pos)) def reset(self): @@ -32,7 +32,7 @@ def reset(self): return [self.cur_pos] def step(self, action): - assert action in [0, 1] + assert action in [0, 1], action if action == 0 and self.cur_pos > 0: self.cur_pos -= 1 elif action == 1: diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index b98061fdd02a..9a889058cc6c 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -51,14 +51,15 @@ class ModelCatalog(object): >>> prep = ModelCatalog.get_preprocessor(env) >>> observation = prep.transform(raw_observation) - >>> 
dist_cls, dist_dim = ModelCatalog.get_action_dist(env.action_space) - >>> model = ModelCatalog.get_model(inputs, dist_dim) + >>> dist_cls, dist_dim = ModelCatalog.get_action_dist( + env.action_space, {}) + >>> model = ModelCatalog.get_model(inputs, dist_dim, options) >>> dist = dist_cls(model.outputs) >>> action = dist.sample() """ @staticmethod - def get_action_dist(action_space, config=None, dist_type=None): + def get_action_dist(action_space, config, dist_type=None): """Returns action distribution class and size for the given action space. Args: @@ -90,7 +91,8 @@ def get_action_dist(action_space, config=None, dist_type=None): child_dist = [] input_lens = [] for action in action_space.spaces: - dist, action_size = ModelCatalog.get_action_dist(action) + dist, action_size = ModelCatalog.get_action_dist( + action, config) child_dist.append(dist) input_lens.append(action_size) return partial( @@ -139,11 +141,7 @@ def get_action_placeholder(action_space): " not supported".format(action_space)) @staticmethod - def get_model(inputs, - num_outputs, - options=None, - state_in=None, - seq_lens=None): + def get_model(inputs, num_outputs, options, state_in=None, seq_lens=None): """Returns a suitable model conforming to given input and output specs. Args: @@ -157,7 +155,6 @@ def get_model(inputs, model (Model): Neural network model. """ - options = options or {} model = ModelCatalog._get_model(inputs, num_outputs, options, state_in, seq_lens) diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py index e3dc1e782535..62468e123bca 100644 --- a/python/ray/rllib/test/test_catalog.py +++ b/python/ray/rllib/test/test_catalog.py @@ -69,12 +69,13 @@ def testDefaultModels(self): ray.init() with tf.variable_scope("test1"): - p1 = ModelCatalog.get_model(np.zeros((10, 3), dtype=np.float32), 5) + p1 = ModelCatalog.get_model( + np.zeros((10, 3), dtype=np.float32), 5, {}) self.assertEqual(type(p1), FullyConnectedNetwork) with tf.variable_scope("test2"): p2 = ModelCatalog.get_model( - np.zeros((10, 84, 84, 3), dtype=np.float32), 5) + np.zeros((10, 84, 84, 3), dtype=np.float32), 5, {}) self.assertEqual(type(p2), VisionNetwork) def testCustomModel(self): diff --git a/python/ray/rllib/tuned_examples/swimmer-ars.yaml b/python/ray/rllib/tuned_examples/swimmer-ars.yaml index 338c8a12c2cf..532bb00b0fa8 100644 --- a/python/ray/rllib/tuned_examples/swimmer-ars.yaml +++ b/python/ray/rllib/tuned_examples/swimmer-ars.yaml @@ -1,4 +1,3 @@ -# can expect improvement to -140 reward in ~300-500k timesteps swimmer-ars: env: Swimmer-v2 run: ARS @@ -9,8 +8,9 @@ swimmer-ars: num_workers: 1 sgd_stepsize: 0.02 noise_size: 250000000 - policy_type: LinearPolicy eval_prob: 0.2 offset: 0 observation_filter: NoFilter report_length: 3 + model: + fcnet_hiddens: [] # a linear policy From fcef4edd46795996389ec6dc0964adfcbdd11c52 Mon Sep 17 00:00:00 2001 From: Wang Qing Date: Tue, 2 Oct 2018 03:56:36 +0800 Subject: [PATCH 004/215] [Java] Fix the required-resources issue of actor member function in Java worker. (#3002) This fixes a bug in which Java actor methods inherit the resource requirements of the actor creation task. 
---
 .../ray/runtime/functionmanager/RayFunction.java   | 13 +++++++++----
 .../functionmanager/FunctionManagerTest.java       |  4 ++--
 .../ray/api/test/ResourcesManagementTest.java      | 16 ++++++++++++++++
 .../main/java/org/ray/api/test/TestListener.java   |  2 +-
 4 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java
index 3d0704c6bf48..2f39ec3dc8db 100644
--- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java
+++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/RayFunction.java
@@ -58,12 +58,17 @@ public FunctionDescriptor getFunctionDescriptor() {
   }
 
   public RayRemote getRayRemoteAnnotation() {
-    RayRemote rayRemote = executable.getAnnotation(RayRemote.class);
-    if (rayRemote == null) {
-      // If the method doesn't have a annotation, get the annotation from
-      // its wrapping class.
+    RayRemote rayRemote;
+
+    // If this method is a constructor, its task is an actor creation task,
+    // so the annotation should be inherited from the declaring class.
+    // Otherwise, it's a normal method and should not inherit the class's annotation.
+    if (isConstructor()) {
      rayRemote = executable.getDeclaringClass().getAnnotation(RayRemote.class);
+    } else {
+      rayRemote = executable.getAnnotation(RayRemote.class);
     }
+
     return rayRemote;
   }
 
diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java
index 85f482544c84..08e4d6415c54 100644
--- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java
+++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java
@@ -74,7 +74,7 @@ public void testGetFunctionFromRayFunc() {
     func = functionManager.getFunction(UniqueId.NIL, barFunc);
     Assert.assertFalse(func.isConstructor());
     Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor);
-    Assert.assertNotNull(func.getRayRemoteAnnotation());
+    Assert.assertNull(func.getRayRemoteAnnotation());
 
     // Test actor constructor
     func = functionManager.getFunction(UniqueId.NIL, barConstructor);
@@ -95,7 +95,7 @@ public void testGetFunctionFromFunctionDescriptor() {
     func = functionManager.getFunction(UniqueId.NIL, barDescriptor);
     Assert.assertFalse(func.isConstructor());
     Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor);
-    Assert.assertNotNull(func.getRayRemoteAnnotation());
+    Assert.assertNull(func.getRayRemoteAnnotation());
 
     // Test actor constructor
     func = functionManager.getFunction(UniqueId.NIL, barConstructorDescriptor);
diff --git a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java
index 001723fec4bf..69d0f57a7570 100644
--- a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java
+++ b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java
@@ -45,6 +45,13 @@ public Integer echo(Integer number) {
     }
   }
 
+  @RayRemote(resources = {@ResourceItem(name = "RES-A", value = 4)})
+  public static class Echo3 {
+    public Integer echo(Integer number) {
+      return number;
+    }
+  }
+
   @Test
   public void testMethods() {
     // This is a case that can satisfy required resources.
@@ -75,5 +82,14 @@ public void testActors() {
     Assert.assertEquals(1, waitResult.getUnready().size());
   }
 
+  @Test
+  public void testActorAndMemberMethods() {
+    // Note(qwang): This case depends on the following line:
+    // https://github.com/ray-project/ray/blob/master/java/test/src/main/java/org/ray/api/test/TestListener.java#L13
+    // If the static resource configuration there changes, this case may
+    // fail and should be updated accordingly.
+    RayActor<Echo3> echo3 = Ray.createActor(Echo3::new);
+    Assert.assertEquals(100, (int) Ray.call(Echo3::echo, echo3, 100).get());
+  }
 }
diff --git a/java/test/src/main/java/org/ray/api/test/TestListener.java b/java/test/src/main/java/org/ray/api/test/TestListener.java
index 3fb16bf4f379..efc419b34720 100644
--- a/java/test/src/main/java/org/ray/api/test/TestListener.java
+++ b/java/test/src/main/java/org/ray/api/test/TestListener.java
@@ -10,7 +10,7 @@ public class TestListener extends RunListener {
   @Override
   public void testRunStarted(Description description) {
     System.setProperty("ray.home", "../..");
-    System.setProperty("ray.resources", "CPU:4");
+    System.setProperty("ray.resources", "CPU:4,RES-A:4");
     Ray.init();
   }
 
From 2019b4122bd987202d6ea34ab19311e1771bd887 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Mon, 1 Oct 2018 13:07:11 -0700
Subject: [PATCH 005/215] [rllib] Remove legacy multiagent support (#2975)

* remove legacy

* remove reshaper
---
 .../examples/legacy_multiagent/__init__.py         |  0
 .../multiagent_mountaincar.py                      | 59 ---------------
 .../multiagent_mountaincar_env.py                  | 51 -------------
 .../legacy_multiagent/multiagent_pendulum.py       | 60 ---------------
 .../multiagent_pendulum_env.py                     | 74 ------------------
 python/ray/rllib/models/catalog.py                 |  8 --
 python/ray/rllib/models/multiagentfcnet.py         | 43 -----------
 python/ray/rllib/utils/reshaper.py                 | 49 ------------
 8 files changed, 344 deletions(-)
 delete mode 100644 python/ray/rllib/examples/legacy_multiagent/__init__.py
 delete mode 100644 python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py
 delete mode 100644 python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar_env.py
 delete mode 100644 python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py
 delete mode 100644 python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum_env.py
 delete mode 100644 python/ray/rllib/models/multiagentfcnet.py
 delete mode 100644 python/ray/rllib/utils/reshaper.py

diff --git a/python/ray/rllib/examples/legacy_multiagent/__init__.py b/python/ray/rllib/examples/legacy_multiagent/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py
deleted file mode 100644
index 9559648290da..000000000000
--- a/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar.py
+++ /dev/null
@@ -1,59 +0,0 @@
-""" Multiagent mountain car. Each agent outputs an action which
-is summed to form the total action.
This is a discrete -multiagent example -""" - -import gym -from gym.envs.registration import register - -import ray -import ray.rllib.agents.ppo as ppo -from ray.tune.registry import register_env - -env_name = "MultiAgentMountainCarEnv" - -env_version_num = 0 -env_name = env_name + '-v' + str(env_version_num) - - -def pass_params_to_gym(env_name): - global env_version_num - - register( - id=env_name, - entry_point=( - "ray.rllib.examples.legacy_multiagent.multiagent_mountaincar_env:" - "MultiAgentMountainCarEnv"), - max_episode_steps=200, - kwargs={}) - - -def create_env(env_config): - pass_params_to_gym(env_name) - env = gym.envs.make(env_name) - return env - - -if __name__ == '__main__': - register_env(env_name, lambda env_config: create_env(env_config)) - config = ppo.DEFAULT_CONFIG.copy() - horizon = 10 - num_cpus = 4 - ray.init(num_cpus=num_cpus, redirect_output=True) - config["num_workers"] = num_cpus - config["train_batch_size"] = 1000 - config["num_sgd_iter"] = 10 - config["gamma"] = 0.999 - config["horizon"] = horizon - config["use_gae"] = False - config["model"].update({"fcnet_hiddens": [256, 256]}) - options = { - "multiagent_obs_shapes": [2, 2], - "multiagent_act_shapes": [1, 1], - "multiagent_shared_model": False, - "multiagent_fcnet_hiddens": [[32, 32]] * 2 - } - config["model"].update({"custom_options": options}) - alg = ppo.PPOAgent(env=env_name, config=config) - for i in range(1): - alg.train() diff --git a/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar_env.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar_env.py deleted file mode 100644 index c120f00c99ec..000000000000 --- a/python/ray/rllib/examples/legacy_multiagent/multiagent_mountaincar_env.py +++ /dev/null @@ -1,51 +0,0 @@ -from math import cos -from gym.spaces import Box, Tuple, Discrete -import numpy as np -from gym.envs.classic_control.mountain_car import MountainCarEnv -""" -Multiagent mountain car that sums and then -averages its actions to produce the velocity -""" - - -class MultiAgentMountainCarEnv(MountainCarEnv): - def __init__(self): - self.min_position = -1.2 - self.max_position = 0.6 - self.max_speed = 0.07 - self.goal_position = 0.5 - - self.low = np.array([self.min_position, -self.max_speed]) - self.high = np.array([self.max_position, self.max_speed]) - - self.viewer = None - - self.action_space = [Discrete(3) for _ in range(2)] - self.observation_space = Tuple( - [Box(self.low, self.high, dtype=np.float32) for _ in range(2)]) - - self.seed() - self.reset() - - def step(self, action): - summed_act = 0.5 * np.sum(action) - - position, velocity = self.state - velocity += (summed_act - 1) * 0.001 - velocity += cos(3 * position) * (-0.0025) - velocity = np.clip(velocity, -self.max_speed, self.max_speed) - position += velocity - position = np.clip(position, self.min_position, self.max_position) - if (position == self.min_position and velocity < 0): - velocity = 0 - - done = bool(position >= self.goal_position) - - reward = position - - self.state = (position, velocity) - return [np.array(self.state) for _ in range(2)], reward, done, {} - - def reset(self): - self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0]) - return [np.array(self.state) for _ in range(2)] diff --git a/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py deleted file mode 100644 index b183ff2c0b15..000000000000 --- a/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py +++ /dev/null 
@@ -1,60 +0,0 @@ -""" Run script for multiagent pendulum env. Each agent outputs a -torque which is summed to form the total torque. This is a -continuous multiagent example -""" - -import gym -from gym.envs.registration import register - -import ray -import ray.rllib.agents.ppo as ppo -from ray.tune.registry import register_env - -env_name = "MultiAgentPendulumEnv" - -env_version_num = 0 -env_name = env_name + '-v' + str(env_version_num) - - -def pass_params_to_gym(env_name): - global env_version_num - - register( - id=env_name, - entry_point=( - "ray.rllib.examples.legacy_multiagent.multiagent_pendulum_env:" - "MultiAgentPendulumEnv"), - max_episode_steps=100, - kwargs={}) - - -def create_env(env_config): - pass_params_to_gym(env_name) - env = gym.envs.make(env_name) - return env - - -if __name__ == '__main__': - register_env(env_name, lambda env_config: create_env(env_config)) - config = ppo.DEFAULT_CONFIG.copy() - horizon = 10 - num_cpus = 4 - ray.init(num_cpus=num_cpus, redirect_output=True) - config["num_workers"] = num_cpus - config["train_batch_size"] = 1000 - config["sgd_minibatch_size"] = 10 - config["num_sgd_iter"] = 10 - config["gamma"] = 0.999 - config["horizon"] = horizon - config["use_gae"] = True - config["model"].update({"fcnet_hiddens": [256, 256]}) - options = { - "multiagent_obs_shapes": [3, 3], - "multiagent_act_shapes": [1, 1], - "multiagent_shared_model": True, - "multiagent_fcnet_hiddens": [[32, 32]] * 2 - } - config["model"].update({"custom_options": options}) - alg = ppo.PPOAgent(env=env_name, config=config) - for i in range(1): - alg.train() diff --git a/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum_env.py b/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum_env.py deleted file mode 100644 index 02645832729f..000000000000 --- a/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum_env.py +++ /dev/null @@ -1,74 +0,0 @@ -from gym.spaces import Box, Tuple -from gym.utils import seeding -from gym.envs.classic_control.pendulum import PendulumEnv -import numpy as np -""" - Multiagent pendulum that sums its torques to generate an action -""" - - -class MultiAgentPendulumEnv(PendulumEnv): - metadata = { - 'render.modes': ['human', 'rgb_array'], - 'video.frames_per_second': 30 - } - - def __init__(self): - self.max_speed = 8 - self.max_torque = 2. - self.dt = .05 - self.viewer = None - - high = np.array([1., 1., self.max_speed]) - self.action_space = [ - Box(low=-self.max_torque / 2, - high=self.max_torque / 2, - shape=(1, ), - dtype=np.float32) for _ in range(2) - ] - self.observation_space = Tuple( - [Box(low=-high, high=high, dtype=np.float32) for _ in range(2)]) - - self.seed() - - def seed(self, seed=None): - self.np_random, seed = seeding.np_random(seed) - return [seed] - - def step(self, u): - th, thdot = self.state # th := theta - - summed_u = np.sum(u) - g = 10. - m = 1. - length = 1. - dt = self.dt - - summed_u = np.clip(summed_u, -self.max_torque, self.max_torque) - self.last_u = summed_u # for rendering - costs = self.angle_normalize(th) ** 2 + .1 * thdot ** 2 + \ - .001 * (summed_u ** 2) - - newthdot = thdot + (-3 * g / (2 * length) * np.sin(th + np.pi) + 3. 
/ - (m * length**2) * summed_u) * dt - newth = th + newthdot * dt - newthdot = np.clip(newthdot, -self.max_speed, self.max_speed) - - self.state = np.array([newth, newthdot]) - return self._get_obs(), -costs, False, {} - - def reset(self): - high = np.array([np.pi, 1]) - self.state = self.np_random.uniform(low=-high, high=high) - self.last_u = None - return self._get_obs() - - def _get_obs(self): - theta, thetadot = self.state - return [ - np.array([np.cos(theta), np.sin(theta), thetadot]) - for _ in range(2) - ] - - def angle_normalize(self, x): - return (((x + np.pi) % (2 * np.pi)) - np.pi) diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 9a889058cc6c..370429c43f3c 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -17,7 +17,6 @@ from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.visionnet import VisionNetwork from ray.rllib.models.lstm import LSTM -from ray.rllib.models.multiagentfcnet import MultiAgentFullyConnectedNetwork MODEL_CONFIGS = [ # === Built-in options === @@ -178,13 +177,6 @@ def _get_model(inputs, num_outputs, options, state_in, seq_lens): obs_rank = len(inputs.shape) - 1 - # num_outputs > 1 used to avoid hitting this with the value function - if isinstance( - options.get("custom_options", {}).get( - "multiagent_fcnet_hiddens", 1), list) and num_outputs > 1: - return MultiAgentFullyConnectedNetwork(inputs, num_outputs, - options) - if obs_rank > 1: return VisionNetwork(inputs, num_outputs, options) diff --git a/python/ray/rllib/models/multiagentfcnet.py b/python/ray/rllib/models/multiagentfcnet.py deleted file mode 100644 index dad7f2983103..000000000000 --- a/python/ray/rllib/models/multiagentfcnet.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -from ray.rllib.models.model import Model -from ray.rllib.models.fcnet import FullyConnectedNetwork -from ray.rllib.utils.reshaper import Reshaper - - -class MultiAgentFullyConnectedNetwork(Model): - """Multiagent fully connected network.""" - - def _build_layers(self, inputs, num_outputs, options): - # Split the input and output tensors - input_shapes = options["custom_options"]["multiagent_obs_shapes"] - output_shapes = options["custom_options"]["multiagent_act_shapes"] - input_reshaper = Reshaper(input_shapes) - output_reshaper = Reshaper(output_shapes) - split_inputs = input_reshaper.split_tensor(inputs) - num_actions = output_reshaper.split_number(num_outputs) - - custom_options = options["custom_options"] - hiddens = custom_options.get("multiagent_fcnet_hiddens", - [[256, 256]] * 1) - - # check for a shared model - shared_model = custom_options.get("multiagent_shared_model", 0) - reuse = tf.AUTO_REUSE if shared_model else False - outputs = [] - for i in range(len(hiddens)): - scope = "multi" if shared_model else "multi{}".format(i) - with tf.variable_scope(scope, reuse=reuse): - sub_options = options.copy() - sub_options.update({"fcnet_hiddens": hiddens[i]}) - # TODO(ev) make this support arbitrary networks - fcnet = FullyConnectedNetwork(split_inputs[i], - int(num_actions[i]), sub_options) - output = fcnet.outputs - outputs.append(output) - overall_output = tf.concat(outputs, axis=1) - return overall_output, outputs diff --git a/python/ray/rllib/utils/reshaper.py b/python/ray/rllib/utils/reshaper.py deleted file mode 100644 index e9c16521210c..000000000000 --- 
a/python/ray/rllib/utils/reshaper.py +++ /dev/null @@ -1,49 +0,0 @@ -import numpy as np -import tensorflow as tf - - -class Reshaper(object): - """ - This class keeps track of where in the flattened observation space - we should be slicing and what the new shapes should be - """ - - def __init__(self, env_space): - self.shapes = [] - self.slice_positions = [] - self.env_space = env_space - if isinstance(env_space, list): - for space in env_space: - # Handle both gym arrays and just lists of inputs length - if hasattr(space, "n"): - arr_shape = np.asarray([1]) # discrete space - elif hasattr(space, "shape"): - arr_shape = np.asarray(space.shape) - else: - arr_shape = space - self.shapes.append(arr_shape) - if len(self.slice_positions) == 0: - self.slice_positions.append(np.product(arr_shape)) - else: - self.slice_positions.append( - np.product(arr_shape) + self.slice_positions[-1]) - else: - self.shapes.append(np.asarray(env_space.shape)) - self.slice_positions.append(np.product(env_space.shape)) - - def get_slice_lengths(self): - diffed_list = np.diff(self.slice_positions).tolist() - diffed_list.insert(0, self.slice_positions[0]) - return np.asarray(diffed_list).astype(int) - - def split_tensor(self, tensor, axis=-1): - # FIXME (ev) This won't work for mixed action distributions like - # one agent Gaussian one agent discrete - slice_rescale = int(tensor.shape.as_list()[axis] / int( - np.sum(self.get_slice_lengths()))) - return tf.split( - tensor, slice_rescale * self.get_slice_lengths(), axis=axis) - - def split_number(self, number): - slice_rescale = int(number / int(np.sum(self.get_slice_lengths()))) - return slice_rescale * self.get_slice_lengths() From 3ce8eb2d4cb7377aca74865a4064b0327b363099 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Tue, 2 Oct 2018 00:08:47 -0700 Subject: [PATCH 006/215] Test dying_worker_get and dying_worker_wait for xray. (#2997) This tests the case in which a worker is blocked in a call to ray.get or ray.wait, and then the worker dies. Then later, the object that the worker was waiting for becomes available. We need to make sure not to try to send a message to the dead worker and then die. Related to #2790. --- cmake/Modules/ArrowExternalProject.cmake | 4 +- python/ray/test/test_utils.py | 38 +++++ src/ray/common/client_connection.h | 2 +- src/ray/raylet/node_manager.cc | 76 +++++---- src/ray/raylet/node_manager.h | 19 ++- test/component_failures_test.py | 204 ++++++++++++++++++++++- test/multi_node_test.py | 42 +---- 7 files changed, 304 insertions(+), 81 deletions(-) diff --git a/cmake/Modules/ArrowExternalProject.cmake b/cmake/Modules/ArrowExternalProject.cmake index dfb25f244f9a..827673c35b28 100644 --- a/cmake/Modules/ArrowExternalProject.cmake +++ b/cmake/Modules/ArrowExternalProject.cmake @@ -14,10 +14,10 @@ # - PLASMA_SHARED_LIB set(arrow_URL https://github.com/apache/arrow.git) -# The PR for this commit is https://github.com/apache/arrow/pull/2522. We +# The PR for this commit is https://github.com/apache/arrow/pull/2664. We # include the link here to make it easier to find the right commit because # Arrow often rewrites git history and invalidates certain commits. 
-set(arrow_TAG 7104d64ff2cd6c20e29d3cf4ec5c58bc10798f66) +set(arrow_TAG 3545186d6997b943ffc3d79634f2d08eefbd7322) set(ARROW_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/arrow-install) set(ARROW_HOME ${ARROW_INSTALL_PREFIX}) diff --git a/python/ray/test/test_utils.py b/python/ray/test/test_utils.py index a29daa5a073e..6e18cd4393f6 100644 --- a/python/ray/test/test_utils.py +++ b/python/ray/test/test_utils.py @@ -6,6 +6,7 @@ import os import redis import subprocess +import sys import tempfile import time @@ -147,3 +148,40 @@ def run_and_get_output(command): with open(tmp.name, 'r') as f: result = f.readlines() return "\n".join(result) + + +def run_string_as_driver(driver_script): + """Run a driver as a separate process. + + Args: + driver_script: A string to run as a Python script. + + Returns: + The script's output. + """ + # Save the driver script as a file so we can call it using subprocess. + with tempfile.NamedTemporaryFile() as f: + f.write(driver_script.encode("ascii")) + f.flush() + out = ray.utils.decode( + subprocess.check_output([sys.executable, f.name])) + return out + + +def run_string_as_driver_nonblocking(driver_script): + """Start a driver as a separate process and return immediately. + + Args: + driver_script: A string to run as a Python script. + + Returns: + A handle to the driver process. + """ + # Save the driver script as a file so we can call it using subprocess. We + # do not delete this file because if we do then it may get removed before + # the Python process tries to run it. + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(driver_script.encode("ascii")) + f.flush() + return subprocess.Popen( + [sys.executable, f.name], stdout=subprocess.PIPE) diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 39d3c084d80f..20b232c333f0 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -114,7 +114,7 @@ class ClientConnection : public ServerConnection, MessageHandler message_handler_; /// A label used for debug messages. const std::string debug_label_; - /// Buffers for the current message being read rom the client. + /// Buffers for the current message being read from the client. int64_t read_version_; int64_t read_type_; uint64_t read_length_; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index bed03ad902c3..fcc60e030082 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -222,7 +222,8 @@ void NodeManager::KillWorker(std::shared_ptr worker) { retry_timer->expires_from_now(retry_duration); retry_timer->async_wait([retry_timer, worker](const boost::system::error_code &error) { RAY_LOG(DEBUG) << "Send SIGKILL to worker, pid=" << worker->Pid(); - // Force kill worker. + // Force kill worker. TODO(rkn): Is there some small danger that the worker + // has already died and the PID has been reassigned to a different process? kill(worker->Pid(), SIGKILL); }); } @@ -638,8 +639,25 @@ void NodeManager::ProcessGetTaskMessage( void NodeManager::ProcessDisconnectClientMessage( const std::shared_ptr &client) { - // Remove the dead worker from the pool and stop listening for messages. const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); + // This client can't be a worker and a driver. 
+ RAY_CHECK(worker == nullptr || driver == nullptr); + + // If both worker and driver are null, then this method has already been + // called, so just return. + if (worker == nullptr && driver == nullptr) { + RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " + << "been disconnected."; + return; + } + + // If the client is blocked, we need to treat it as unblocked. In particular, + // we are no longer waiting for its dependencies. If the client is not + // blocked, this won't do anything. + HandleClientUnblocked(client); + + // Remove the dead client from the pool and stop listening for messages. if (worker) { // The client is a worker. Handle the case where the worker is killed @@ -651,17 +669,15 @@ void NodeManager::ProcessDisconnectClientMessage( // If the worker was killed intentionally, e.g., when the driver that created // the task that this worker is currently executing exits, the task for this // worker has already been removed from queue, so the following are skipped. - auto const &running_tasks = local_queues_.GetRunningTasks(); - // TODO(rkn): This is too heavyweight just to get the task's driver ID. - auto const it = std::find_if( - running_tasks.begin(), running_tasks.end(), [task_id](const Task &task) { - return task.GetTaskSpecification().TaskId() == task_id; - }); - RAY_CHECK(running_tasks.size() != 0); - RAY_CHECK(it != running_tasks.end()); - const TaskSpecification &spec = it->GetTaskSpecification(); - const JobID job_id = spec.DriverId(); + task_dependency_manager_.TaskCanceled(task_id); + // task_dependency_manager_.UnsubscribeDependencies(current_task_id); + const Task &task = local_queues_.RemoveTask(task_id); + const TaskSpecification &spec = task.GetTaskSpecification(); + // Handle the task failure in order to raise an exception in the + // application. + TreatTaskAsFailed(spec); + const JobID &job_id = worker->GetAssignedDriverId(); // TODO(rkn): Define this constant somewhere else. std::string type = "worker_died"; std::ostringstream error_message; @@ -669,18 +685,12 @@ void NodeManager::ProcessDisconnectClientMessage( << "."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( job_id, type, error_message.str(), current_time_ms())); - - // Handle the task failure in order to raise an exception in the - // application. - TreatTaskAsFailed(spec); - task_dependency_manager_.TaskCanceled(spec.TaskId()); - local_queues_.RemoveTask(spec.TaskId()); } worker_pool_.DisconnectWorker(worker); // If the worker was an actor, add it to the list of dead actors. - const ActorID actor_id = worker->GetActorId(); + const ActorID &actor_id = worker->GetActorId(); if (!actor_id.is_nil()) { // TODO(rkn): Consider broadcasting a message to all of the other // node managers so that they can mark the actor as dead. @@ -715,7 +725,6 @@ void NodeManager::ProcessDisconnectClientMessage( // The client is a driver. RAY_CHECK_OK(gcs_client_->driver_table().AppendDriverData(client->GetClientID(), /*is_dead=*/true)); - const std::shared_ptr driver = worker_pool_.GetRegisteredDriver(client); RAY_CHECK(driver); auto driver_id = driver->GetAssignedTaskId(); RAY_CHECK(!driver_id.is_nil()); @@ -725,6 +734,10 @@ void NodeManager::ProcessDisconnectClientMessage( RAY_LOG(DEBUG) << "Driver (pid=" << driver->Pid() << ") is disconnected. " << "driver_id: " << driver->GetAssignedDriverId(); } + + // TODO(rkn): Tell the object manager that this client has disconnected so + // that it can clean up the wait requests for this client. Currently I think + // these can be leaked. 
} void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { @@ -798,12 +811,21 @@ void NodeManager::ProcessWaitRequestMessage( flatbuffers::Offset wait_reply = protocol::CreateWaitReply( fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining)); fbb.Finish(wait_reply); - RAY_CHECK_OK( + + auto status = client->WriteMessage(static_cast(protocol::MessageType::WaitReply), - fbb.GetSize(), fbb.GetBufferPointer())); - // The client is unblocked now because the wait call has returned. - if (client_blocked) { - HandleClientUnblocked(client); + fbb.GetSize(), fbb.GetBufferPointer()); + if (status.ok()) { + // The client is unblocked now because the wait call has returned. + if (client_blocked) { + HandleClientUnblocked(client); + } + } else { + // We failed to write to the client, so disconnect the client. + RAY_LOG(WARNING) + << "Failed to send WaitReply to client, so disconnecting client"; + // We failed to send the reply to the client, so disconnect the worker. + ProcessDisconnectClientMessage(client); } }); RAY_CHECK_OK(status); @@ -1308,9 +1330,7 @@ void NodeManager::AssignTask(Task &task) { } else { RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; // We failed to send the task to the worker, so disconnect the worker. - ProcessClientMessage(worker->Connection(), - static_cast(protocol::MessageType::DisconnectClient), - nullptr); + ProcessDisconnectClientMessage(worker->Connection()); // Queue this task for future assignment. The task will be assigned to a // worker once one becomes available. // (See design_docs/task_states.rst for the state transition diagram.) diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index d8025a6c8520..b6b23223c797 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -267,7 +267,7 @@ class NodeManager { bool CheckDependencyManagerInvariant() const; /// Process client message of RegisterClientRequest - // + /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. @@ -275,26 +275,29 @@ class NodeManager { const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of GetTask - // + /// /// \param client The client that sent the message. /// \return Void. void ProcessGetTaskMessage(const std::shared_ptr &client); - /// Process client message of DisconnectClient - // + /// Handle a client that has disconnected. This can be called multiple times + /// on the same client because this is triggered both when a client + /// disconnects and when the node manager fails to write a message to the + /// client. + /// /// \param client The client that sent the message. /// \return Void. void ProcessDisconnectClientMessage( const std::shared_ptr &client); /// Process client message of SubmitTask - // + /// /// \param message_data A pointer to the message data. /// \return Void. void ProcessSubmitTaskMessage(const uint8_t *message_data); /// Process client message of ReconstructObjects - // + /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. @@ -302,7 +305,7 @@ class NodeManager { const std::shared_ptr &client, const uint8_t *message_data); /// Process client message of WaitRequest - // + /// /// \param client The client that sent the message. /// \param message_data A pointer to the message data. /// \return Void. 
@@ -310,7 +313,7 @@ class NodeManager { const uint8_t *message_data); /// Process client message of PushErrorRequest - // + /// /// \param message_data A pointer to the message data. /// \return Void. void ProcessPushErrorRequestMessage(const uint8_t *message_data); diff --git a/test/component_failures_test.py b/test/component_failures_test.py index 64dd3712b2e7..3a57452e6115 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -2,11 +2,14 @@ from __future__ import division from __future__ import print_function -import pytest import os -import ray +import signal import time +import pytest + +import ray +from ray.test.test_utils import run_string_as_driver_nonblocking import pyarrow as pa @@ -23,6 +26,112 @@ def ray_start_workers_separate(): ray.shutdown() +@pytest.fixture +def shutdown_only(): + yield None + # The code after the yield will run as teardown code. + ray.shutdown() + + +# This test checks that when a worker dies in the middle of a get, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_worker_get_raylet(shutdown_only): + # Start the Ray processes. + ray.init(num_cpus=2) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + @ray.remote + def get_worker_pid(): + return os.getpid() + + x_id = sleep_forever.remote() + time.sleep(0.01) # Try to wait for the sleep task to get scheduled. + # Get the PID of the other worker. + worker_pid = ray.get(get_worker_pid.remote()) + + @ray.remote + def f(id_in_a_list): + ray.get(id_in_a_list[0]) + + # Have the worker wait in a get call. + result_id = f.remote([x_id]) + time.sleep(1) + + # Make sure the task hasn't finished. + ready_ids, _ = ray.wait([result_id], timeout=0) + assert len(ready_ids) == 0 + + # Kill the worker. + os.kill(worker_pid, signal.SIGKILL) + time.sleep(0.1) + + # Make sure the sleep task hasn't finished. + ready_ids, _ = ray.wait([x_id], timeout=0) + assert len(ready_ids) == 0 + # Seal the object so the store attempts to notify the worker that the + # get has been fulfilled. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + +# This test checks that when a driver dies in the middle of a get, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_driver_get(shutdown_only): + # Start the Ray processes. + address_info = ray.init(num_cpus=1) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + x_id = sleep_forever.remote() + + driver = """ +import ray +ray.init("{}") +ray.get(ray.ObjectID({})) +""".format(address_info["redis_address"], x_id.id()) + + p = run_string_as_driver_nonblocking(driver) + # Make sure the driver is running. + time.sleep(1) + assert p.poll() is None + + # Kill the driver process. + p.kill() + p.wait() + time.sleep(0.1) + + # Make sure the original task hasn't finished. + ready_ids, _ = ray.wait([x_id], timeout=0) + assert len(ready_ids) == 0 + # Seal the object so the store attempts to notify the worker that the + # get has been fulfilled. 
+ ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + # This test checks that when a worker dies in the middle of a get, the # plasma store and manager will not die. @pytest.mark.skipif( @@ -59,6 +168,97 @@ def f(): exclude=[ray.services.PROCESS_TYPE_WORKER]) +# This test checks that when a worker dies in the middle of a wait, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_worker_wait_raylet(shutdown_only): + ray.init(num_cpus=2) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + @ray.remote + def get_pid(): + return os.getpid() + + x_id = sleep_forever.remote() + # Get the PID of the worker that block_in_wait will run on (sleep a little + # to make sure that sleep_forever has already started). + time.sleep(0.1) + worker_pid = ray.get(get_pid.remote()) + + @ray.remote + def block_in_wait(object_id_in_list): + ray.wait(object_id_in_list) + + # Have the worker wait in a wait call. + block_in_wait.remote([x_id]) + time.sleep(0.1) + + # Kill the worker. + os.kill(worker_pid, signal.SIGKILL) + time.sleep(0.1) + + # Create the object. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + +# This test checks that when a driver dies in the middle of a wait, the plasma +# store and raylet will not die. +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +@pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="Not working with new GCS API.") +def test_dying_driver_wait(shutdown_only): + # Start the Ray processes. + address_info = ray.init(num_cpus=1) + + @ray.remote + def sleep_forever(): + time.sleep(10**6) + + x_id = sleep_forever.remote() + + driver = """ +import ray +ray.init("{}") +ray.wait([ray.ObjectID({})]) +""".format(address_info["redis_address"], x_id.id()) + + p = run_string_as_driver_nonblocking(driver) + # Make sure the driver is running. + time.sleep(1) + assert p.poll() is None + + # Kill the driver process. + p.kill() + p.wait() + time.sleep(0.1) + + # Make sure the original task hasn't finished. + ready_ids, _ = ray.wait([x_id], timeout=0) + assert len(ready_ids) == 0 + # Seal the object so the store attempts to notify the worker that the + # wait can return. + ray.worker.global_worker.put_object(x_id, 1) + time.sleep(0.1) + + # Make sure that nothing has died. + assert ray.services.all_processes_alive() + + # This test checks that when a worker dies in the middle of a wait, the # plasma store and manager will not die. @pytest.mark.skipif( diff --git a/test/multi_node_test.py b/test/multi_node_test.py index 657c03710962..a1f0bd87be29 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -5,49 +5,11 @@ import os import pytest import subprocess -import sys -import tempfile import time import ray -from ray.test.test_utils import run_and_get_output - - -def run_string_as_driver(driver_script): - """Run a driver as a separate process. - - Args: - driver_script: A string to run as a Python script. - - Returns: - The script's output. - """ - # Save the driver script as a file so we can call it using subprocess. 
-    with tempfile.NamedTemporaryFile() as f:
-        f.write(driver_script.encode("ascii"))
-        f.flush()
-        out = ray.utils.decode(
-            subprocess.check_output([sys.executable, f.name]))
-    return out
-
-
-def run_string_as_driver_nonblocking(driver_script):
-    """Start a driver as a separate process and return immediately.
-
-    Args:
-        driver_script: A string to run as a Python script.
-
-    Returns:
-        A handle to the driver process.
-    """
-    # Save the driver script as a file so we can call it using subprocess. We
-    # do not delete this file because if we do then it may get removed before
-    # the Python process tries to run it.
-    with tempfile.NamedTemporaryFile(delete=False) as f:
-        f.write(driver_script.encode("ascii"))
-        f.flush()
-        return subprocess.Popen(
-            [sys.executable, f.name], stdout=subprocess.PIPE)
+from ray.test.test_utils import (run_and_get_output, run_string_as_driver,
+                                 run_string_as_driver_nonblocking)
 
 
 @pytest.fixture

From 9c606ea06cdd604f7d840e0f69d6d9b5d6b46394 Mon Sep 17 00:00:00 2001
From: bibabolynn <1018527906@qq.com>
Date: Wed, 3 Oct 2018 13:53:54 +0800
Subject: [PATCH 007/215] Fix RAY_FUNC_CACHE never being populated (#3000)

Before this fix, RAY_FUNC_CACHE was only ever read with its get method and
never written to, so cache lookups always returned null. The fix is to put
the FunctionDescriptor into the cache right after creating it.
---
 .../java/org/ray/runtime/functionmanager/FunctionManager.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java
index e586741641ae..473a1f033203 100644
--- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java
+++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java
@@ -28,7 +28,7 @@ public class FunctionManager {
    * Cache from a RayFunc object to its corresponding FunctionDescriptor. Because
    * `LambdaUtils.getSerializedLambda` is expensive.
    */
-  private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, FunctionDescriptor>>
+  private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, FunctionDescriptor>>
       RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
 
   /**
@@ -51,6 +51,7 @@ public RayFunction getFunction(UniqueId driverId, RayFunc func) {
       final String methodName = serializedLambda.getImplMethodName();
       final String typeDescriptor = serializedLambda.getImplMethodSignature();
       functionDescriptor = new FunctionDescriptor(className, methodName, typeDescriptor);
+      RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor);
     }
     return getFunction(driverId, functionDescriptor);
   }
 
From cc7e2ecdd5499bfd5ba93205f17c2a0783578533 Mon Sep 17 00:00:00 2001
From: Si-Yuan
Date: Wed, 3 Oct 2018 10:03:53 -0700
Subject: [PATCH 008/215] Change logfile names and also allow plasma store socket to be passed in.
---
 .travis.yml                          |   6 +
 doc/source/index.rst                 |   1 +
 doc/source/tempfile.rst              |  87 ++++++
 .../local_scheduler_services.py      |  13 +-
 python/ray/plasma/plasma.py          |  16 +-
 python/ray/ray_constants.py          |   4 +-
 python/ray/scripts/scripts.py        |  25 +-
 python/ray/services.py               | 204 ++++++------
 python/ray/tempfile_services.py      | 292 ++++++++++++++++++
 python/ray/worker.py                 |  50 ++-
 python/ray/workers/default_worker.py |  10 +
 test/stress_tests.py                 |   9 +-
 test/tempfile_test.py                | 119 +++++++
 13 files changed, 696 insertions(+), 140 deletions(-)
 create mode 100644 doc/source/tempfile.rst
 create mode 100644 python/ray/tempfile_services.py
 create mode 100644 test/tempfile_test.py

diff --git a/.travis.yml b/.travis.yml
index 47bef360e51e..35743b764ea1 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -165,6 +165,9 @@ matrix:
       - python -m pytest -v python/ray/rllib/test/test_optimizers.py
       - python -m pytest -v python/ray/rllib/test/test_evaluators.py

+      # ray temp file tests
+      - python -m pytest -v test/tempfile_test.py
+
 install:
   - ./.travis/install-dependencies.sh

@@ -237,6 +240,9 @@ script:
   - python -m pytest -v python/ray/rllib/test/test_optimizers.py
   - python -m pytest -v python/ray/rllib/test/test_evaluators.py

+  # ray temp file tests
+  - python -m pytest -v test/tempfile_test.py
+
 deploy:
   - provider: s3
     access_key_id: AKIAJ2L7XDUSZVTXI5QA

diff --git a/doc/source/index.rst b/doc/source/index.rst
index d951066e8842..d8870bbaa054 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -118,6 +118,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learning
    plasma-object-store.rst
    resources.rst
    redis-memory-management.rst
+   tempfile.rst

 .. toctree::
    :maxdepth: 1

diff --git a/doc/source/tempfile.rst b/doc/source/tempfile.rst
new file mode 100644
index 000000000000..7e489348c4bf
--- /dev/null
+++ b/doc/source/tempfile.rst
@@ -0,0 +1,87 @@
+Temporary Files
+===============
+
+Ray produces a number of temporary files while it is running.
+They are useful for logging, debugging, and sharing the object store with
+other programs.
+
+Location of Temporary Files
+---------------------------
+
+First, we introduce the concept of a Ray session.
+
+A session contains a set of processes. A session is created by running the
+``ray start`` command or calling ``ray.init()`` in a Python script, and it is
+ended by running ``ray stop`` or calling ``ray.shutdown()``.
+
+For each session, Ray creates a *root temporary directory* that holds all of
+its temporary files. The path is ``/tmp/ray/session_{datetime}_{pid}`` by
+default. The pid is that of the startup process (the process calling
+``ray.init()``, or the Ray process executed by the shell for ``ray start``).
+You can sort the session directories by name to find the latest one.
+
+You can change the *root temporary directory* in one of these ways:
+
+* Pass ``--temp-dir={your temp path}`` to ``ray start``
+* Specify ``temp_dir`` when calling ``ray.init()``
+
+You can also use ``default_worker.py --temp-dir={your temp path}`` to
+start a new worker with the given *root temporary directory*.
+
+The *root temporary directory* you specify is used as-is; no pid or datetime
+suffix is attached to it.
+
+Layout of Temporary Files
+-------------------------
+
+A typical layout of temporary files could look like this:
+
+.. code-block:: text
+
+  /tmp
+  └── ray
+      └── session_{datetime}_{pid}
+          ├── logs  # for logging
+          │   ├── log_monitor.err
+          │   ├── log_monitor.out
+          │   ├── monitor.err
+          │   ├── monitor.out
+          │   ├── plasma_manager_0.err  # array of plasma managers' outputs
+          │   ├── plasma_manager_0.out
+          │   ├── plasma_store_0.err  # array of plasma stores' outputs
+          │   ├── plasma_store_0.out
+          │   ├── raylet_0.err  # array of raylets' outputs. Control it with `--no-redirect-worker-output` (in Ray's command line) or `redirect_worker_output` (in ray.init())
+          │   ├── raylet_0.out
+          │   ├── redis-shard_0.err  # array of redis shards' outputs
+          │   ├── redis-shard_0.out
+          │   ├── redis.err  # redis
+          │   ├── redis.out
+          │   ├── webui.err  # ipython notebook web ui
+          │   ├── webui.out
+          │   ├── worker-{worker_id}.err  # redirected output of workers
+          │   ├── worker-{worker_id}.out
+          │   └── {other workers}
+          ├── ray_ui.ipynb  # ipython notebook file
+          └── sockets  # for sockets
+              ├── plasma_store
+              └── raylet  # this could be deleted by Ray's shutdown cleanup.
+
+
+Plasma Object Store Socket
+--------------------------
+
+Plasma object store sockets can be used to share objects with other programs
+using Apache Arrow.
+
+You can specify the plasma object store socket in one of these ways:
+
+* Pass ``--plasma-store-socket-name={your socket path}`` to ``ray start``
+* Specify ``plasma_store_socket_name`` when calling ``ray.init()``
+
+The path you specify is used as-is; it is not affected by any other path
+settings.
+
+Notes
+-----
+
+Temporary file policies are defined in ``python/ray/tempfile_services.py``.
+
+Currently, ``/tmp/ray`` is still kept as the default directory for RLlib's
+temporary data files, as before. This is not ideal and may be changed later.

diff --git a/python/ray/local_scheduler/local_scheduler_services.py b/python/ray/local_scheduler/local_scheduler_services.py
index f7847ce551b0..c576014e25ce 100644
--- a/python/ray/local_scheduler/local_scheduler_services.py
+++ b/python/ray/local_scheduler/local_scheduler_services.py
@@ -4,14 +4,12 @@

 import multiprocessing
 import os
-import random
 import subprocess
 import sys
 import time

-
-def random_name():
-    return str(random.randint(0, 99999999))
+from ray.tempfile_services import (get_local_scheduler_socket_name,
+                                   get_temp_root)


 def start_local_scheduler(plasma_store_name,
@@ -71,7 +69,7 @@ def start_local_scheduler(plasma_store_name,
     local_scheduler_executable = os.path.join(
         os.path.dirname(os.path.abspath(__file__)),
         "../core/src/local_scheduler/local_scheduler")
-    local_scheduler_name = "/tmp/scheduler{}".format(random_name())
+    local_scheduler_name = get_local_scheduler_socket_name()
     command = [
         local_scheduler_executable, "-s", local_scheduler_name, "-p",
         plasma_store_name, "-h", node_ip_address, "-n",
@@ -88,11 +86,12 @@ def start_local_scheduler(plasma_store_name,
                             "--object-store-name={} "
                             "--object-store-manager-name={} "
                             "--local-scheduler-name={} "
-                            "--redis-address={}".format(
+                            "--redis-address={} "
+                            "--temp-dir={}".format(
                                 sys.executable, worker_path, node_ip_address,
                                 plasma_store_name, plasma_manager_name,
                                 local_scheduler_name,
-                                redis_address))
+                                redis_address, get_temp_root()))
     command += ["-w", start_worker_command]
     if redis_address is not None:
         command += ["-r", redis_address]

diff --git a/python/ray/plasma/plasma.py b/python/ray/plasma/plasma.py
index 60870c2b2021..262aeebfb448 100644
--- a/python/ray/plasma/plasma.py
+++ b/python/ray/plasma/plasma.py
@@ -8,6 +8,9 @@
 import sys
 import time

+from ray.tempfile_services import 
(get_object_store_socket_name, + get_plasma_manager_socket_name) + __all__ = [ "start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY" ] @@ -17,17 +20,14 @@ DEFAULT_PLASMA_STORE_MEMORY = 10**9 -def random_name(): - return str(random.randint(0, 99999999)) - - def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, use_valgrind=False, use_profiler=False, stdout_file=None, stderr_file=None, plasma_directory=None, - huge_pages=False): + huge_pages=False, + socket_name=None): """Start a plasma store process. Args: @@ -43,6 +43,8 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, be created. huge_pages: a boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. + socket_name (str): If provided, it will specify the socket + name used by the plasma store. Return: A tuple of the name of the plasma store socket and the process ID of @@ -66,7 +68,7 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, plasma_store_executable = os.path.join( os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_store_server") - plasma_store_name = "/tmp/plasma_store{}".format(random_name()) + plasma_store_name = socket_name or get_object_store_socket_name() command = [ plasma_store_executable, "-s", plasma_store_name, "-m", str(plasma_store_memory) @@ -136,7 +138,7 @@ def start_plasma_manager(store_name, plasma_manager_executable = os.path.join( os.path.abspath(os.path.dirname(__file__)), "../core/src/plasma/plasma_manager") - plasma_manager_name = "/tmp/plasma_manager{}".format(random_name()) + plasma_manager_name = get_plasma_manager_socket_name() if plasma_manager_port is not None: if num_retries != 1: raise Exception("num_retries must be 1 if port is specified.") diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index a9e4519d4cf5..d62b57b5c1cf 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -5,7 +5,7 @@ import os -import ray +from ray.local_scheduler import ObjectID def env_integer(key, default): @@ -15,7 +15,7 @@ def env_integer(key, default): ID_SIZE = 20 -NIL_JOB_ID = ray.ObjectID(ID_SIZE * b"\xff") +NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff") # If a remote function or actor (or some other export) has serialized size # greater than this quantity, print an warning. 
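Since ``start_plasma_store`` now accepts an explicit ``socket_name``, an
external program can attach to the same object store through Apache Arrow's
plasma client. A minimal sketch follows; the socket path and payload are
illustrative, and the three-argument ``plasma.connect`` call reflects the
pyarrow API of this era:

.. code-block:: python

    import pyarrow.plasma as plasma

    # Connect to the store socket that Ray created, e.g. the path passed via
    # --plasma-store-socket-name. The second argument is the (unused) manager
    # socket and the third is the release delay.
    client = plasma.connect("/tmp/my_plasma_socket", "", 0)

    # Object IDs are 20 bytes, matching ID_SIZE in ray_constants above.
    object_id = plasma.ObjectID(20 * b"a")
    client.put("hello from outside Ray", object_id)
    print(client.get([object_id])[0])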
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 0826a1387aec..d3e9417c1a81 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -177,11 +177,24 @@ def cli(logging_level, logging_format): is_flag=True, default=False, help="do not redirect non-worker stdout and stderr to files") +@click.option( + "--plasma-store-socket-name", + default=None, + help="manually specify the socket name of the plasma store") +@click.option( + "--raylet-socket-name", + default=None, + help="manually specify the socket path of the raylet process") +@click.option( + "--temp-dir", + default=None, + help="manually specify the root temporary dir of the Ray process") def start(node_ip_address, redis_address, redis_port, num_redis_shards, redis_max_clients, redis_shard_ports, object_manager_port, object_store_memory, num_workers, num_cpus, num_gpus, resources, head, no_ui, block, plasma_directory, huge_pages, autoscaling_config, - use_raylet, no_redirect_worker_output, no_redirect_output): + use_raylet, no_redirect_worker_output, no_redirect_output, + plasma_store_socket_name, raylet_socket_name, temp_dir): # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) @@ -260,7 +273,10 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, - use_raylet=use_raylet) + use_raylet=use_raylet, + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir) logger.info(address_info) logger.info( "\nStarted Ray on this node. You can add additional nodes to " @@ -329,7 +345,10 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, resources=resources, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet) + use_raylet=use_raylet, + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir) logger.info(address_info) logger.info("\nStarted Ray on this node. 
If you wish to terminate the " "processes that have been started, run\n\n" diff --git a/python/ray/services.py b/python/ray/services.py index 3a421437c566..ab632354f6a9 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -2,14 +2,12 @@ from __future__ import division from __future__ import print_function -import binascii import json import logging import multiprocessing import os import random import resource -import shutil import signal import socket import subprocess @@ -17,8 +15,6 @@ import threading import time from collections import OrderedDict, namedtuple -from datetime import datetime - import redis import pyarrow @@ -28,6 +24,14 @@ import ray.local_scheduler import ray.plasma +from ray.tempfile_services import ( + get_ipython_notebook_path, get_logs_dir_path, get_raylet_socket_name, + get_temp_redis_config_path, get_temp_root, new_global_scheduler_log_file, + new_local_scheduler_log_file, new_log_monitor_log_file, + new_monitor_log_file, new_plasma_manager_log_file, + new_plasma_store_log_file, new_raylet_log_file, new_redis_log_file, + new_webui_log_file, new_worker_log_file, set_temp_root) + PROCESS_TYPE_MONITOR = "monitor" PROCESS_TYPE_LOG_MONITOR = "log_monitor" PROCESS_TYPE_WORKER = "worker" @@ -120,10 +124,6 @@ def new_port(): return random.randint(10000, 65535) -def random_name(): - return str(random.randint(0, 99999999)) - - def kill_process(p): """Kill a process. @@ -456,8 +456,7 @@ def start_redis(node_ip_address, A tuple of the address for the primary Redis shard and a list of addresses for the remaining shards. """ - redis_stdout_file, redis_stderr_file = new_log_files( - "redis", redirect_output) + redis_stdout_file, redis_stderr_file = new_redis_log_file(redirect_output) if redis_shard_ports is None: redis_shard_ports = num_redis_shards * [None] @@ -517,8 +516,8 @@ def start_redis(node_ip_address, # prefixed by "redis-". redis_shards = [] for i in range(num_redis_shards): - redis_stdout_file, redis_stderr_file = new_log_files( - "redis-{}".format(i), redirect_output) + redis_stdout_file, redis_stderr_file = new_redis_log_file( + redirect_output, shard_number=i) if not use_credis: redis_shard_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -572,7 +571,7 @@ def _make_temp_redis_config(node_ip_address): node_ip_address: The IP address of this node. This should not be 127.0.0.1. """ - redis_config_name = "/tmp/redis_conf{}".format(random_name()) + redis_config_name = get_temp_redis_config_path() with open(redis_config_name, 'w') as f: # This allows redis clients on the same machine to connect using the # node's IP address as opposed to just 127.0.0.1. This is only relevant @@ -799,15 +798,7 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): then this process will be killed by services.cleanup() when the Python process that imported services exits. """ - new_env = os.environ.copy() - notebook_filepath = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb") - # We copy the notebook file so that the original doesn't get modified by - # the user. 
- random_ui_id = random.randint(0, 100000) - new_notebook_filepath = "/tmp/raylogs/ray_ui{}.ipynb".format(random_ui_id) - new_notebook_directory = os.path.dirname(new_notebook_filepath) - shutil.copy(notebook_filepath, new_notebook_filepath) + port = 8888 while True: try: @@ -821,7 +812,8 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): new_env["REDIS_ADDRESS"] = redis_address # We generate the token used for authentication ourselves to avoid # querying the jupyter server. - token = ray.utils.decode(binascii.hexlify(os.urandom(24))) + new_notebook_directory, webui_url, token = ( + get_ipython_notebook_path(port)) # The --ip=0.0.0.0 flag is intended to enable connecting to a notebook # running within a docker container (from the outside). command = [ @@ -847,8 +839,6 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): else: if cleanup: all_processes[PROCESS_TYPE_WEB_UI].append(ui_process) - webui_url = ("http://localhost:{}/notebooks/ray_ui{}.ipynb?token={}" - .format(port, random_ui_id, token)) logger.info("\n" + "=" * 70) logger.info("View the web UI at {}".format(webui_url)) logger.info("=" * 70 + "\n") @@ -971,6 +961,7 @@ def start_local_scheduler(redis_address, def start_raylet(redis_address, node_ip_address, + raylet_name, plasma_store_name, worker_path, resources=None, @@ -988,6 +979,7 @@ def start_raylet(redis_address, scheduler is running on. plasma_store_name (str): The name of the plasma store socket to connect to. + raylet_name (str): The name of the raylet socket to create. worker_path (str): The path of the script to use when the local scheduler starts up new workers. use_valgrind (bool): True if the raylet should be started inside @@ -1023,16 +1015,17 @@ def start_raylet(redis_address, ]) gcs_ip_address, gcs_port = redis_address.split(":") - raylet_name = "/tmp/raylet{}".format(random_name()) # Create the command that the Raylet will use to start workers. start_worker_command = ("{} {} " "--node-ip-address={} " "--object-store-name={} " "--raylet-name={} " - "--redis-address={}".format( + "--redis-address={} " + "--temp-dir={}".format( sys.executable, worker_path, node_ip_address, - plasma_store_name, raylet_name, redis_address)) + plasma_store_name, raylet_name, redis_address, + get_temp_root())) command = [ RAYLET_EXECUTABLE, @@ -1084,7 +1077,8 @@ def start_plasma_store(node_ip_address, cleanup=True, plasma_directory=None, huge_pages=False, - use_raylet=False): + use_raylet=False, + plasma_store_socket_name=None): """This method starts an object store process. Args: @@ -1158,7 +1152,8 @@ def start_plasma_store(node_ip_address, stdout_file=store_stdout_file, stderr_file=store_stderr_file, plasma_directory=plasma_directory, - huge_pages=huge_pages) + huge_pages=huge_pages, + socket_name=plasma_store_socket_name) # Start the plasma manager. 
if not use_raylet: if object_manager_port is not None: @@ -1235,7 +1230,8 @@ def start_worker(node_ip_address, "--object-store-name=" + object_store_name, "--object-store-manager-name=" + object_store_manager_name, "--local-scheduler-name=" + local_scheduler_name, - "--redis-address=" + str(redis_address) + "--redis-address=" + str(redis_address), + "--temp-dir=" + get_temp_root() ] p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: @@ -1327,7 +1323,10 @@ def start_ray_processes(address_info=None, plasma_directory=None, huge_pages=False, autoscaling_config=None, - use_raylet=False): + use_raylet=False, + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None): """Helper method to start Ray processes. Args: @@ -1385,13 +1384,22 @@ def start_ray_processes(address_info=None, autoscaling_config: path to autoscaling config file. use_raylet: True if the new raylet code path should be used. This is not supported yet. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. + temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. Returns: A dictionary of the address information for the processes that were started. """ - logger.info( - "Process STDOUT and STDERR is being redirected to /tmp/raylogs/.") + + set_temp_root(temp_dir) + + logger.info("Process STDOUT and STDERR is being redirected to {}.".format( + get_logs_dir_path())) if resources is None: resources = {} @@ -1438,8 +1446,8 @@ def start_ray_processes(address_info=None, time.sleep(0.1) # Start monitoring the processes. - monitor_stdout_file, monitor_stderr_file = new_log_files( - "monitor", redirect_output) + monitor_stdout_file, monitor_stderr_file = new_monitor_log_file( + redirect_output) start_monitor( redis_address, node_ip_address, @@ -1464,8 +1472,8 @@ def start_ray_processes(address_info=None, # Start the log monitor, if necessary. if include_log_monitor: - log_monitor_stdout_file, log_monitor_stderr_file = new_log_files( - "log_monitor", redirect_output=True) + log_monitor_stdout_file, log_monitor_stderr_file = ( + new_log_monitor_log_file()) start_log_monitor( redis_address, node_ip_address, @@ -1476,7 +1484,7 @@ def start_ray_processes(address_info=None, # Start the global scheduler, if necessary. if include_global_scheduler and not use_raylet: global_scheduler_stdout_file, global_scheduler_stderr_file = ( - new_log_files("global_scheduler", redirect_output)) + new_global_scheduler_log_file(redirect_output)) start_global_scheduler( redis_address, node_ip_address, @@ -1505,10 +1513,14 @@ def start_ray_processes(address_info=None, # Start any object stores that do not yet exist. for i in range(num_local_schedulers - len(object_store_addresses)): # Start Plasma. - plasma_store_stdout_file, plasma_store_stderr_file = new_log_files( - "plasma_store_{}".format(i), redirect_output) - plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files( - "plasma_manager_{}".format(i), redirect_output) + plasma_store_stdout_file, plasma_store_stderr_file = ( + new_plasma_store_log_file(i, redirect_output)) + + # If we use raylet, plasma manager won't be started and we don't need + # to create temp files for them. 
+ plasma_manager_stdout_file, plasma_manager_stderr_file = ( + new_plasma_manager_log_file(i, redirect_output and not use_raylet)) + object_store_address = start_plasma_store( node_ip_address, redis_address, @@ -1521,7 +1533,8 @@ def start_ray_processes(address_info=None, cleanup=cleanup, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet) + use_raylet=use_raylet, + plasma_store_socket_name=plasma_store_socket_name) object_store_addresses.append(object_store_address) time.sleep(0.1) @@ -1546,9 +1559,8 @@ def start_ray_processes(address_info=None, # redirect the worker output, then we cannot redirect the local # scheduler output. local_scheduler_stdout_file, local_scheduler_stderr_file = ( - new_log_files( - "local_scheduler_{}".format(i), - redirect_output=redirect_worker_output)) + new_local_scheduler_log_file( + i, redirect_output=redirect_worker_output)) local_scheduler_name = start_local_scheduler( redis_address, node_ip_address, @@ -1571,12 +1583,13 @@ def start_ray_processes(address_info=None, else: # Start any raylets that do not exist yet. for i in range(len(raylet_socket_names), num_local_schedulers): - raylet_stdout_file, raylet_stderr_file = new_log_files( - "raylet_{}".format(i), redirect_output=redirect_worker_output) + raylet_stdout_file, raylet_stderr_file = new_raylet_log_file( + i, redirect_output=redirect_worker_output) address_info["raylet_socket_names"].append( start_raylet( redis_address, node_ip_address, + raylet_socket_name or get_raylet_socket_name(), object_store_addresses[i].name, worker_path, resources=resources[i], @@ -1592,8 +1605,8 @@ def start_ray_processes(address_info=None, object_store_address = object_store_addresses[i] local_scheduler_name = local_scheduler_socket_names[i] for j in range(num_local_scheduler_workers): - worker_stdout_file, worker_stderr_file = new_log_files( - "worker_{}_{}".format(i, j), redirect_output) + worker_stdout_file, worker_stderr_file = new_worker_log_file( + i, j, redirect_output) start_worker( node_ip_address, object_store_address.name, @@ -1611,8 +1624,7 @@ def start_ray_processes(address_info=None, # Try to start the web UI. if include_webui: - ui_stdout_file, ui_stderr_file = new_log_files( - "webui", redirect_output=True) + ui_stdout_file, ui_stderr_file = new_webui_log_file() address_info["webui_url"] = start_ui( redis_address, stdout_file=ui_stdout_file, @@ -1637,7 +1649,10 @@ def start_ray_node(node_ip_address, resources=None, plasma_directory=None, huge_pages=False, - use_raylet=False): + use_raylet=False, + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None): """Start the Ray processes for a single node. This assumes that the Ray processes on some master node have already been @@ -1672,6 +1687,12 @@ def start_ray_node(node_ip_address, Store with hugetlbfs support. Requires plasma_directory. use_raylet: True if the new raylet code path should be used. This is not supported yet. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. + temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. 
Returns: A dictionary of the address information for the processes that were @@ -1695,7 +1716,10 @@ def start_ray_node(node_ip_address, resources=resources, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet) + use_raylet=use_raylet, + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir) def start_ray_head(address_info=None, @@ -1718,7 +1742,10 @@ def start_ray_head(address_info=None, plasma_directory=None, huge_pages=False, autoscaling_config=None, - use_raylet=False): + use_raylet=False, + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None): """Start Ray in local mode. Args: @@ -1770,6 +1797,12 @@ def start_ray_head(address_info=None, autoscaling_config: path to autoscaling config file. use_raylet: True if the new raylet code path should be used. This is not supported yet. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. + temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. Returns: A dictionary of the address information for the processes that were @@ -1799,58 +1832,7 @@ def start_ray_head(address_info=None, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, - use_raylet=use_raylet) - - -def try_to_create_directory(directory_path): - """Attempt to create a directory that is globally readable/writable. - - Args: - directory_path: The path of the directory to create. - """ - if not os.path.exists(directory_path): - try: - os.makedirs(directory_path) - except OSError as e: - if e.errno != os.errno.EEXIST: - raise e - logger.warning( - "Attempted to create '{}', but the directory already " - "exists.".format(directory_path)) - # Change the log directory permissions so others can use it. This is - # important when multiple people are using the same machine. - os.chmod(directory_path, 0o0777) - - -def new_log_files(name, redirect_output): - """Generate partially randomized filenames for log files. - - Args: - name (str): descriptive string for this log file. - redirect_output (bool): True if files should be generated for logging - stdout and stderr and false if stdout and stderr should not be - redirected. - - Returns: - If redirect_output is true, this will return a tuple of two - filehandles. The first is for redirecting stdout and the second is - for redirecting stderr. If redirect_output is false, this will - return a tuple of two None objects. - """ - if not redirect_output: - return None, None - - # Create a directory to be used for process log files. - logs_dir = "/tmp/raylogs" - try_to_create_directory(logs_dir) - # Create another directory that will be used by some of the RL algorithms. 
-    try_to_create_directory("/tmp/ray")
-
-    log_id = random.randint(0, 10000)
-    date_str = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
-    log_stdout = "{}/{}-{}-{:05d}.out".format(logs_dir, name, date_str, log_id)
-    log_stderr = "{}/{}-{}-{:05d}.err".format(logs_dir, name, date_str, log_id)
-    # Line-buffer the output (mode 1)
-    log_stdout_file = open(log_stdout, "a", buffering=1)
-    log_stderr_file = open(log_stderr, "a", buffering=1)
-    return log_stdout_file, log_stderr_file
+        use_raylet=use_raylet,
+        plasma_store_socket_name=plasma_store_socket_name,
+        raylet_socket_name=raylet_socket_name,
+        temp_dir=temp_dir)
diff --git a/python/ray/tempfile_services.py b/python/ray/tempfile_services.py
new file mode 100644
index 000000000000..3b5adfa802db
--- /dev/null
+++ b/python/ray/tempfile_services.py
@@ -0,0 +1,292 @@
+import binascii
+import collections
+import datetime
+import errno
+import logging
+import os
+import shutil
+import tempfile
+
+import ray.utils
+
+logger = logging.getLogger(__name__)
+_incremental_dict = collections.defaultdict(lambda: 0)
+_temp_root = None
+
+
+def make_inc_temp(suffix="", prefix="", directory_name="/tmp/ray"):
+    """Return an incremental temporary file name. The file is not created.
+
+    Args:
+        suffix (str): The suffix of the temp file.
+        prefix (str): The prefix of the temp file.
+        directory_name (str): The base directory of the temp file.
+
+    Returns:
+        A file name as a string. If a file with the same name already exists,
+        the returned name will look like
+        "{directory_name}/{prefix}.{unique_index}{suffix}"
+    """
+    index = _incremental_dict[suffix, prefix, directory_name]
+    # `tempfile.TMP_MAX` could be extremely large,
+    # so using `range` in Python 2.x should be avoided.
+    while index < tempfile.TMP_MAX:
+        if index == 0:
+            filename = os.path.join(directory_name, prefix + suffix)
+        else:
+            filename = os.path.join(directory_name,
+                                    prefix + "." + str(index) + suffix)
+        index += 1
+        if not os.path.exists(filename):
+            _incremental_dict[suffix, prefix,
+                              directory_name] = index  # Save the index.
+            return filename
+
+    raise FileExistsError(errno.EEXIST, "No usable temporary filename found")
+
+
+def try_to_create_directory(directory_path):
+    """Attempt to create a directory that is globally readable/writable.
+
+    Args:
+        directory_path: The path of the directory to create.
+    """
+    if not os.path.exists(directory_path):
+        try:
+            os.makedirs(directory_path)
+        except OSError as e:
+            if e.errno != os.errno.EEXIST:
+                raise e
+            logger.warning(
+                "Attempted to create '{}', but the directory already "
+                "exists.".format(directory_path))
+        # Change the log directory permissions so others can use it. This is
+        # important when multiple people are using the same machine.
+        os.chmod(directory_path, 0o0777)
+
+
+def get_temp_root():
+    """Get the path of the temporary root, creating it if it does not exist.
+    """
+    global _temp_root
+
+    date_str = datetime.datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
+
+    # Lazy creation. Avoid creating directories that are never used.
+    if _temp_root is None:
+        _temp_root = make_inc_temp(
+            prefix="session_{date_str}_{pid}".format(
+                pid=os.getpid(), date_str=date_str),
+            directory_name="/tmp/ray")
+    try_to_create_directory(_temp_root)
+    return _temp_root
+
+
+def set_temp_root(path):
+    """Set the path of the temporary root. 
It will be created lazily.""" + global _temp_root + _temp_root = path + + +def get_logs_dir_path(): + """Get a temp dir for logging.""" + logs_dir = os.path.join(get_temp_root(), "logs") + try_to_create_directory(logs_dir) + return logs_dir + + +def get_sockets_dir_path(): + """Get a temp dir for sockets.""" + sockets_dir = os.path.join(get_temp_root(), "sockets") + try_to_create_directory(sockets_dir) + return sockets_dir + + +def get_raylet_socket_name(suffix=""): + """Get a socket name for raylet.""" + sockets_dir = get_sockets_dir_path() + + raylet_socket_name = make_inc_temp( + prefix="raylet", directory_name=sockets_dir, suffix=suffix) + return raylet_socket_name + + +def get_object_store_socket_name(): + """Get a socket name for plasma object store.""" + sockets_dir = get_sockets_dir_path() + return make_inc_temp(prefix="plasma_store", directory_name=sockets_dir) + + +def get_plasma_manager_socket_name(): + """Get a socket name for plasma manager.""" + sockets_dir = get_sockets_dir_path() + return make_inc_temp(prefix="plasma_manager", directory_name=sockets_dir) + + +def get_local_scheduler_socket_name(suffix=""): + """Get a socket name for local scheduler. + + This function could be unsafe. The socket name may + refer to a file that did not exist at some point, but by the time + you get around to creating it, someone else may have beaten you to + the punch. + """ + sockets_dir = get_sockets_dir_path() + raylet_socket_name = make_inc_temp( + prefix="scheduler", directory_name=sockets_dir, suffix=suffix) + + return raylet_socket_name + + +def get_ipython_notebook_path(port): + """Get a new ipython notebook path""" + + notebook_filepath = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb") + # We copy the notebook file so that the original doesn't get modified by + # the user. + notebook_name = make_inc_temp( + suffix=".ipynb", prefix="ray_ui", directory_name=get_temp_root()) + new_notebook_filepath = os.path.join(get_logs_dir_path(), notebook_name) + shutil.copy(notebook_filepath, new_notebook_filepath) + new_notebook_directory = os.path.dirname(new_notebook_filepath) + token = ray.utils.decode(binascii.hexlify(os.urandom(24))) + webui_url = ("http://localhost:{}/notebooks/{}?token={}".format( + port, os.path.basename(notebook_name), token)) + return new_notebook_directory, webui_url, token + + +def get_temp_redis_config_path(): + """Get a temp name of the redis config file.""" + redis_config_name = make_inc_temp( + prefix="redis_conf", directory_name=get_temp_root()) + return redis_config_name + + +def new_log_files(name, redirect_output): + """Generate partially randomized filenames for log files. + + Args: + name (str): descriptive string for this log file. + redirect_output (bool): True if files should be generated for logging + stdout and stderr and false if stdout and stderr should not be + redirected. + + Returns: + If redirect_output is true, this will return a tuple of two + filehandles. The first is for redirecting stdout and the second is + for redirecting stderr. If redirect_output is false, this will + return a tuple of two None objects. + """ + if not redirect_output: + return None, None + + # Create a directory to be used for process log files. + logs_dir = get_logs_dir_path() + # Create another directory that will be used by some of the RL algorithms. + + # TODO(suquark): This is done by the old code. + # We should be able to control its path later. 
+ try_to_create_directory("/tmp/ray") + + log_stdout = make_inc_temp( + suffix=".out", prefix=name, directory_name=logs_dir) + log_stderr = make_inc_temp( + suffix=".err", prefix=name, directory_name=logs_dir) + # Line-buffer the output (mode 1) + log_stdout_file = open(log_stdout, "a", buffering=1) + log_stderr_file = open(log_stderr, "a", buffering=1) + return log_stdout_file, log_stderr_file + + +def new_redis_log_file(redirect_output, shard_number=None): + """Create new logging files for redis""" + if shard_number is None: + redis_stdout_file, redis_stderr_file = new_log_files( + "redis", redirect_output) + else: + redis_stdout_file, redis_stderr_file = new_log_files( + "redis-shard_{}".format(shard_number), redirect_output) + return redis_stdout_file, redis_stderr_file + + +def new_raylet_log_file(local_scheduler_index, redirect_output): + """Create new logging files for raylet.""" + raylet_stdout_file, raylet_stderr_file = new_log_files( + "raylet_{}".format(local_scheduler_index), + redirect_output=redirect_output) + return raylet_stdout_file, raylet_stderr_file + + +def new_local_scheduler_log_file(local_scheduler_index, redirect_output): + """Create new logging files for local scheduler. + + It is only used in non-raylet versions. + """ + local_scheduler_stdout_file, local_scheduler_stderr_file = (new_log_files( + "local_scheduler_{}".format(local_scheduler_index), + redirect_output=redirect_output)) + return local_scheduler_stdout_file, local_scheduler_stderr_file + + +def new_webui_log_file(): + """Create new logging files for web ui.""" + ui_stdout_file, ui_stderr_file = new_log_files( + "webui", redirect_output=True) + return ui_stdout_file, ui_stderr_file + + +def new_worker_log_file(local_scheduler_index, worker_index, redirect_output): + """Create new logging files for workers with local scheduler index. + + It is only used in non-raylet versions. + """ + worker_stdout_file, worker_stderr_file = new_log_files( + "worker_{}_{}".format(local_scheduler_index, worker_index), + redirect_output) + return worker_stdout_file, worker_stderr_file + + +def new_worker_redirected_log_file(worker_id): + """Create new logging files for workers to redirect its output.""" + worker_stdout_file, worker_stderr_file = (new_log_files( + "worker-" + ray.utils.binary_to_hex(worker_id), True)) + return worker_stdout_file, worker_stderr_file + + +def new_log_monitor_log_file(): + """Create new logging files for the log monitor.""" + log_monitor_stdout_file, log_monitor_stderr_file = new_log_files( + "log_monitor", redirect_output=True) + return log_monitor_stdout_file, log_monitor_stderr_file + + +def new_global_scheduler_log_file(redirect_output): + """Create new logging files for the new global scheduler. + + It is only used in non-raylet versions. 
+ """ + global_scheduler_stdout_file, global_scheduler_stderr_file = ( + new_log_files("global_scheduler", redirect_output)) + return global_scheduler_stdout_file, global_scheduler_stderr_file + + +def new_plasma_store_log_file(local_scheduler_index, redirect_output): + """Create new logging files for the plasma store.""" + plasma_store_stdout_file, plasma_store_stderr_file = new_log_files( + "plasma_store_{}".format(local_scheduler_index), redirect_output) + return plasma_store_stdout_file, plasma_store_stderr_file + + +def new_plasma_manager_log_file(local_scheduler_index, redirect_output): + """Create new logging files for the plasma manager.""" + plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files( + "plasma_manager_{}".format(local_scheduler_index), redirect_output) + return plasma_manager_stdout_file, plasma_manager_stderr_file + + +def new_monitor_log_file(redirect_output): + """Create new logging files for the monitor.""" + monitor_stdout_file, monitor_stderr_file = new_log_files( + "monitor", redirect_output) + return monitor_stdout_file, monitor_stderr_file diff --git a/python/ray/worker.py b/python/ray/worker.py index 2d1d45f65b1c..c0714b1fc4b7 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -27,6 +27,7 @@ import ray.serialization as serialization import ray.services as services import ray.signature +import ray.tempfile_services as tempfile_services import ray.local_scheduler import ray.plasma import ray.ray_constants as ray_constants @@ -1528,7 +1529,10 @@ def _init(address_info=None, plasma_directory=None, huge_pages=False, include_webui=True, - use_raylet=None): + use_raylet=None, + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None): """Helper method to connect to an existing Ray cluster or start a new one. This method handles two cases. Either a Ray cluster already exists and we @@ -1584,6 +1588,12 @@ def _init(address_info=None, include_webui: Boolean flag indicating whether to start the web UI, which is a Jupyter notebook. use_raylet: True if the new raylet code path should be used. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. + temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. Returns: Address information about the started processes. @@ -1658,7 +1668,10 @@ def _init(address_info=None, plasma_directory=plasma_directory, huge_pages=huge_pages, include_webui=include_webui, - use_raylet=use_raylet) + use_raylet=use_raylet, + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir) else: if redis_address is None: raise Exception("When connecting to an existing cluster, " @@ -1690,6 +1703,15 @@ def _init(address_info=None, if huge_pages: raise Exception("When connecting to an existing cluster, " "huge_pages must not be provided.") + if temp_dir is not None: + raise Exception("When connecting to an existing cluster, " + "temp_dir must not be provided.") + if plasma_store_socket_name is not None: + raise Exception("When connecting to an existing cluster, " + "plasma_store_socket_name must not be provided.") + if raylet_socket_name is not None: + raise Exception("When connecting to an existing cluster, " + "raylet_socket_name must not be provided.") # Get the node IP address if one is not provided. 
if node_ip_address is None: node_ip_address = services.get_node_ip_address(redis_address) @@ -1719,6 +1741,9 @@ def _init(address_info=None, else: driver_address_info["raylet_socket_name"] = ( address_info["raylet_socket_names"][0]) + + # We only pass `temp_dir` to a worker (WORKER_MODE). + # It can't be a worker here. connect( driver_address_info, object_id_seed=object_id_seed, @@ -1750,7 +1775,10 @@ def init(redis_address=None, use_raylet=None, configure_logging=True, logging_level=logging.INFO, - logging_format=ray_constants.LOGGER_FORMAT): + logging_format=ray_constants.LOGGER_FORMAT, + plasma_store_socket_name=None, + raylet_socket_name=None, + temp_dir=None): """Connect to an existing Ray cluster or start one and connect to it. This method handles two cases. Either a Ray cluster already exists and we @@ -1815,6 +1843,12 @@ def init(redis_address=None, logging_level: Logging level, default will be loging.INFO. logging_format: Logging format, default will be "%(message)s" which means only contains the message. + plasma_store_socket_name (str): If provided, it will specify the socket + name used by the plasma store. + raylet_socket_name (str): If provided, it will specify the socket path + used by the raylet process. + temp_dir (str): If provided, it will specify the root temporary + directory for the Ray process. Returns: Address information about the started processes. @@ -1863,7 +1897,10 @@ def init(redis_address=None, huge_pages=huge_pages, include_webui=include_webui, object_store_memory=object_store_memory, - use_raylet=use_raylet) + use_raylet=use_raylet, + plasma_store_socket_name=plasma_store_socket_name, + raylet_socket_name=raylet_socket_name, + temp_dir=temp_dir) for hook in _post_init_hooks: hook() return ret @@ -2135,8 +2172,9 @@ def connect(info, else: redirect_worker_output = 0 if redirect_worker_output: - log_stdout_file, log_stderr_file = services.new_log_files( - "worker", True) + log_stdout_file, log_stderr_file = ( + tempfile_services.new_worker_redirected_log_file( + worker.worker_id)) sys.stdout = log_stdout_file sys.stderr = log_stderr_file services.record_log_files_in_redis( diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index cd7b3f4a45c3..72679722fa88 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -9,6 +9,7 @@ import ray import ray.actor import ray.ray_constants as ray_constants +import ray.tempfile_services as tempfile_services parser = argparse.ArgumentParser( description=("Parse addresses for the worker " @@ -53,6 +54,12 @@ type=str, default=ray_constants.LOGGER_FORMAT, help=ray_constants.LOGGER_FORMAT_HELP) +parser.add_argument( + "--temp-dir", + required=False, + type=str, + default=None, + help="Specify the path of the temporary directory use by Ray process.") if __name__ == "__main__": args = parser.parse_args() @@ -70,6 +77,9 @@ level=logging.getLevelName(args.logging_level.upper()), format=args.logging_format) + # Override the temporary directory. 
+ tempfile_services.set_temp_root(args.temp_dir) + ray.worker.connect( info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None)) diff --git a/test/stress_tests.py b/test/stress_tests.py index 6cea02d82028..6fc7cc487e58 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -8,6 +8,7 @@ import time import ray +import ray.tempfile_services import ray.ray_constants as ray_constants @@ -173,10 +174,10 @@ def ray_start_reconstruction(request): plasma_addresses = [] objstore_memory = plasma_store_memory // num_local_schedulers for i in range(num_local_schedulers): - store_stdout_file, store_stderr_file = ray.services.new_log_files( - "plasma_store_{}".format(i), True) - manager_stdout_file, manager_stderr_file = (ray.services.new_log_files( - "plasma_manager_{}".format(i), True)) + store_stdout_file, store_stderr_file = ( + ray.tempfile_services.new_plasma_store_log_file(i, True)) + manager_stdout_file, manager_stderr_file = ( + ray.tempfile_services.new_plasma_manager_log_file(i, True)) plasma_addresses.append( ray.services.start_plasma_store( node_ip_address, diff --git a/test/tempfile_test.py b/test/tempfile_test.py new file mode 100644 index 000000000000..a174b621b549 --- /dev/null +++ b/test/tempfile_test.py @@ -0,0 +1,119 @@ +import os +import shutil +import time +import pytest +import ray +import ray.tempfile_services as tempfile_services + + +def test_conn_cluster(): + # plasma_store_socket_name + with pytest.raises(Exception) as exc_info: + ray.init( + use_raylet=True, + redis_address="127.0.0.1:6379", + plasma_store_socket_name="/tmp/this_should_fail") + assert exc_info.value.args[0] == ( + "When connecting to an existing cluster, " + "plasma_store_socket_name must not be provided.") + + # raylet_socket_name + with pytest.raises(Exception) as exc_info: + ray.init( + use_raylet=True, + redis_address="127.0.0.1:6379", + raylet_socket_name="/tmp/this_should_fail") + assert exc_info.value.args[0] == ( + "When connecting to an existing cluster, " + "raylet_socket_name must not be provided.") + + # temp_dir + with pytest.raises(Exception) as exc_info: + ray.init( + use_raylet=True, + redis_address="127.0.0.1:6379", + temp_dir="/tmp/this_should_fail") + assert exc_info.value.args[0] == ( + "When connecting to an existing cluster, " + "temp_dir must not be provided.") + + +def test_tempdir(): + ray.init(use_raylet=True, temp_dir="/tmp/i_am_a_temp_dir") + assert os.path.exists( + "/tmp/i_am_a_temp_dir"), "Specified temp dir not found." + ray.shutdown() + shutil.rmtree("/tmp/i_am_a_temp_dir", ignore_errors=True) + + +def test_raylet_socket_name(): + ray.init(use_raylet=True, raylet_socket_name="/tmp/i_am_a_temp_socket") + assert os.path.exists( + "/tmp/i_am_a_temp_socket"), "Specified socket path not found." + ray.shutdown() + try: + os.remove("/tmp/i_am_a_temp_socket") + except Exception: + pass + + +def test_temp_plasma_store_socket(): + ray.init( + use_raylet=True, plasma_store_socket_name="/tmp/i_am_a_temp_socket") + assert os.path.exists( + "/tmp/i_am_a_temp_socket"), "Specified socket path not found." 
+ ray.shutdown() + try: + os.remove("/tmp/i_am_a_temp_socket") + except Exception: + pass + + +def test_raylet_tempfiles(): + ray.init(use_raylet=True, redirect_worker_output=False) + top_levels = set(os.listdir(tempfile_services.get_temp_root())) + assert top_levels == {"ray_ui.ipynb", "sockets", "logs"} + log_files = set(os.listdir(tempfile_services.get_logs_dir_path())) + assert log_files == { + "log_monitor.out", "log_monitor.err", "plasma_store_0.out", + "plasma_store_0.err", "webui.out", "webui.err", "monitor.out", + "monitor.err", "redis-shard_0.out", "redis-shard_0.err", "redis.out", + "redis.err" + } # without raylet logs + socket_files = set(os.listdir(tempfile_services.get_sockets_dir_path())) + assert socket_files == {"plasma_store", "raylet"} + ray.shutdown() + + ray.init(use_raylet=True, redirect_worker_output=True, num_workers=0) + top_levels = set(os.listdir(tempfile_services.get_temp_root())) + assert top_levels == {"ray_ui.ipynb", "sockets", "logs"} + log_files = set(os.listdir(tempfile_services.get_logs_dir_path())) + assert log_files == { + "log_monitor.out", "log_monitor.err", "plasma_store_0.out", + "plasma_store_0.err", "webui.out", "webui.err", "monitor.out", + "monitor.err", "redis-shard_0.out", "redis-shard_0.err", "redis.out", + "redis.err", "raylet_0.out", "raylet_0.err" + } # with raylet logs + socket_files = set(os.listdir(tempfile_services.get_sockets_dir_path())) + assert socket_files == {"plasma_store", "raylet"} + ray.shutdown() + + ray.init(use_raylet=True, redirect_worker_output=True, num_workers=2) + top_levels = set(os.listdir(tempfile_services.get_temp_root())) + assert top_levels == {"ray_ui.ipynb", "sockets", "logs"} + time.sleep(3) # wait workers to start + log_files = set(os.listdir(tempfile_services.get_logs_dir_path())) + assert log_files.issuperset({ + "log_monitor.out", "log_monitor.err", "plasma_store_0.out", + "plasma_store_0.err", "webui.out", "webui.err", "monitor.out", + "monitor.err", "redis-shard_0.out", "redis-shard_0.err", "redis.out", + "redis.err", "raylet_0.out", "raylet_0.err" + }) # with raylet logs + + # Check numbers of worker log file. + assert sum( + 1 for filename in log_files if filename.startswith("worker")) == 4 + + socket_files = set(os.listdir(tempfile_services.get_sockets_dir_path())) + assert socket_files == {"plasma_store", "raylet"} + ray.shutdown() From d73ee36e60bde79436282b08188fe6ca3ac41956 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Wed, 3 Oct 2018 13:43:40 -0700 Subject: [PATCH 009/215] Update links to use latest 0.5.3 wheels instead of 0.5.2. (#3018) --- doc/source/installation.rst | 16 ++++++++-------- python/ray/autoscaler/aws/example-full.yaml | 6 +++--- python/ray/autoscaler/gcp/example-full.yaml | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/doc/source/installation.rst b/doc/source/installation.rst index 4c4bc3f165ef..ebee27f9f028 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -30,14 +30,14 @@ features but may be subject to more bugs. To install these wheels, run the follo =================== =================== -.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp36-cp36m-manylinux1_x86_64.whl -.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp35-cp35m-manylinux1_x86_64.whl -.. _`Linux Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp34-cp34m-manylinux1_x86_64.whl -.. 
_`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp27-cp27mu-manylinux1_x86_64.whl -.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp36-cp36m-macosx_10_6_intel.whl -.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp35-cp35m-macosx_10_6_intel.whl -.. _`MacOS Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp34-cp34m-macosx_10_6_intel.whl -.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp27-cp27m-macosx_10_6_intel.whl +.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp36-cp36m-manylinux1_x86_64.whl +.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp35-cp35m-manylinux1_x86_64.whl +.. _`Linux Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp34-cp34m-manylinux1_x86_64.whl +.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp27-cp27mu-manylinux1_x86_64.whl +.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp36-cp36m-macosx_10_6_intel.whl +.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp35-cp35m-macosx_10_6_intel.whl +.. _`MacOS Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp34-cp34m-macosx_10_6_intel.whl +.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp27-cp27m-macosx_10_6_intel.whl Building Ray from source diff --git a/python/ray/autoscaler/aws/example-full.yaml b/python/ray/autoscaler/aws/example-full.yaml index 55691863fffb..0d04e0dceaee 100644 --- a/python/ray/autoscaler/aws/example-full.yaml +++ b/python/ray/autoscaler/aws/example-full.yaml @@ -89,9 +89,9 @@ setup_commands: # has your Ray repo pre-cloned. Then, you can replace the pip installs # below with a git checkout (and possibly a recompile). 
    - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc
-    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp27-cp27mu-manylinux1_x86_64.whl
-    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp35-cp35m-manylinux1_x86_64.whl
-    - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp36-cp36m-manylinux1_x86_64.whl
+    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp27-cp27mu-manylinux1_x86_64.whl
+    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp35-cp35m-manylinux1_x86_64.whl
+    - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp36-cp36m-manylinux1_x86_64.whl
     # Consider uncommenting these if you also want to run apt-get commands during setup
     # - sudo pkill -9 apt-get || true
     # - sudo pkill -9 dpkg || true

diff --git a/python/ray/autoscaler/gcp/example-full.yaml b/python/ray/autoscaler/gcp/example-full.yaml
index e9a95e8543be..a3df6ad612c2 100644
--- a/python/ray/autoscaler/gcp/example-full.yaml
+++ b/python/ray/autoscaler/gcp/example-full.yaml
@@ -124,9 +124,9 @@ setup_commands:
       pip install google-api-python-client==1.6.7 cython==0.27.3

-    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp27-cp27mu-manylinux1_x86_64.whl
-    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp35-cp35m-manylinux1_x86_64.whl
-    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.2-cp36-cp36m-manylinux1_x86_64.whl
+    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp27-cp27mu-manylinux1_x86_64.whl
+    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp35-cp35m-manylinux1_x86_64.whl
+    # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp36-cp36m-manylinux1_x86_64.whl
     - >-
       cd ~ && git clone https://github.com/ray-project/ray || true

From 9948e8c11b92bc0f9f72c0402b87f8445b27c7cb Mon Sep 17 00:00:00 2001
From: Yuhong Guo
Date: Thu, 4 Oct 2018 07:21:04 +0800
Subject: [PATCH 010/215] Move function/actor exporting & loading code to
 function_manager.py (#3003)

Move the function/actor exporting and loading code to function_manager.py,
in preparation for the upcoming function descriptor change for Python.
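As a rough illustration of where this refactor is headed: a Python function
descriptor would carry enough metadata to identify a function without
shipping pickled bytes, much like the Java runtime's
``FunctionDescriptor(className, methodName, typeDescriptor)`` seen in patch
007 above. The sketch below is hypothetical; these names are not the
interface this series eventually adds:

.. code-block:: python

    from collections import namedtuple

    # Hypothetical shape of a Python function descriptor.
    FunctionDescriptor = namedtuple(
        "FunctionDescriptor", ["module_name", "class_name", "function_name"])


    def describe(func):
        """Build a descriptor identifying `func` (illustrative only)."""
        qualname = getattr(func, "__qualname__", func.__name__)
        class_name = qualname.rsplit(".", 1)[0] if "." in qualname else ""
        return FunctionDescriptor(func.__module__, class_name, func.__name__)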
--- python/ray/actor.py | 299 ++------------------ python/ray/function_manager.py | 486 +++++++++++++++++++++++++++++++++ python/ray/import_thread.py | 55 +--- python/ray/remote_function.py | 18 +- python/ray/utils.py | 18 ++ python/ray/worker.py | 185 ++----------- 6 files changed, 559 insertions(+), 502 deletions(-) create mode 100644 python/ray/function_manager.py diff --git a/python/ray/actor.py b/python/ray/actor.py index 3886e1927a02..d61fac7d7579 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -5,31 +5,19 @@ import copy import hashlib import inspect -import json import traceback import ray.cloudpickle as pickle +from ray.function_manager import FunctionActorManager import ray.local_scheduler import ray.ray_constants as ray_constants import ray.signature as signature import ray.worker -from ray.utils import ( - decode, - _random_string, - check_oversized_pickle, - is_cython, - push_error_to_driver, -) +from ray.utils import _random_string DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1 -def is_classmethod(f): - """Returns whether the given method is a classmethod.""" - - return hasattr(f, "__self__") and f.__self__ is not None - - def compute_actor_handle_id(actor_handle_id, num_forks): """Deterministically compute an actor handle ID. @@ -96,24 +84,6 @@ def compute_actor_creation_function_id(class_id): return ray.ObjectID(class_id) -def compute_actor_method_function_id(class_name, attr): - """Get the function ID corresponding to an actor method. - - Args: - class_name (str): The class name of the actor. - attr (str): The attribute name of the method. - - Returns: - Function ID corresponding to the method. - """ - function_id_hash = hashlib.sha1() - function_id_hash.update(class_name.encode("ascii")) - function_id_hash.update(attr.encode("ascii")) - function_id = function_id_hash.digest() - assert len(function_id) == ray_constants.ID_SIZE - return ray.ObjectID(function_id) - - def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, frontier): """Set the most recent checkpoint associated with a given actor ID. @@ -134,28 +104,6 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint, }) -def get_actor_checkpoint(worker, actor_id): - """Get the most recent checkpoint associated with a given actor ID. - - Args: - worker: The worker to use to get the checkpoint. - actor_id: The actor ID of the actor to get the checkpoint for. - - Returns: - If a checkpoint exists, this returns a tuple of the number of tasks - included in the checkpoint, the saved checkpoint state, and the - task frontier at the time of the checkpoint. If no checkpoint - exists, all objects are set to None. The checkpoint index is the . - executed on the actor before the checkpoint was made. - """ - actor_key = b"Actor:" + actor_id - checkpoint_index, checkpoint, frontier = worker.redis_client.hmget( - actor_key, ["checkpoint_index", "checkpoint", "frontier"]) - if checkpoint_index is not None: - checkpoint_index = int(checkpoint_index) - return checkpoint_index, checkpoint, frontier - - def save_and_log_checkpoint(worker, actor): """Save a checkpoint on the actor and log any errors. @@ -205,219 +153,26 @@ def restore_and_log_checkpoint(worker, actor): return checkpoint_resumed -def make_actor_method_executor(worker, method_name, method, actor_imported): - """Make an executor that wraps a user-defined actor method. - - The wrapped method updates the worker's internal state and performs any - necessary checkpointing operations. 
+def get_actor_checkpoint(worker, actor_id): + """Get the most recent checkpoint associated with a given actor ID. Args: - worker (Worker): The worker that is executing the actor. - method_name (str): The name of the actor method. - method (instancemethod): The actor method to wrap. This should be a - method defined on the actor class and should therefore take an - instance of the actor as the first argument. - actor_imported (bool): Whether the actor has been imported. - Checkpointing operations will not be run if this is set to False. + worker: The worker to use to get the checkpoint. + actor_id: The actor ID of the actor to get the checkpoint for. Returns: - A function that executes the given actor method on the worker's stored - instance of the actor. The function also updates the worker's - internal state to record the executed method. - """ - - def actor_method_executor(dummy_return_id, actor, *args): - # Update the actor's task counter to reflect the task we're about to - # execute. - worker.actor_task_counter += 1 - - # If this is the first task to execute on the actor, try to resume from - # a checkpoint. - if actor_imported and worker.actor_task_counter == 1: - checkpoint_resumed = restore_and_log_checkpoint(worker, actor) - if checkpoint_resumed: - # NOTE(swang): Since we did not actually execute the __init__ - # method, this will put None as the return value. If the - # __init__ method is supposed to return multiple values, an - # exception will be logged. - return - - # Determine whether we should checkpoint the actor. - checkpointing_on = (actor_imported - and worker.actor_checkpoint_interval > 0) - # We should checkpoint the actor if user checkpointing is on, we've - # executed checkpoint_interval tasks since the last checkpoint, and the - # method we're about to execute is not a checkpoint. - save_checkpoint = ( - checkpointing_on and - (worker.actor_task_counter % worker.actor_checkpoint_interval == 0 - and method_name != "__ray_checkpoint__")) - - # Execute the assigned method and save a checkpoint if necessary. - try: - if is_classmethod(method): - method_returns = method(*args) - else: - method_returns = method(actor, *args) - except Exception: - # Save the checkpoint before allowing the method exception to be - # thrown. - if save_checkpoint: - save_and_log_checkpoint(worker, actor) - raise - else: - # Save the checkpoint before returning the method's return values. - if save_checkpoint: - save_and_log_checkpoint(worker, actor) - return method_returns - - return actor_method_executor - - -def fetch_and_register_actor(actor_class_key, worker): - """Import an actor. - - This will be called by the worker's import thread when the worker receives - the actor_class export, assuming that the worker is an actor for that - class. - - Args: - actor_class_key: The key in Redis to use to fetch the actor. - worker: The worker to use. 
- """ - actor_id_str = worker.actor_id - (driver_id, class_id, class_name, module, pickled_class, - checkpoint_interval, actor_method_names) = worker.redis_client.hmget( - actor_class_key, [ - "driver_id", "class_id", "class_name", "module", "class", - "checkpoint_interval", "actor_method_names" - ]) - - class_name = decode(class_name) - module = decode(module) - checkpoint_interval = int(checkpoint_interval) - actor_method_names = json.loads(decode(actor_method_names)) - - # Create a temporary actor with some temporary methods so that if the actor - # fails to be unpickled, the temporary actor can be used (just to produce - # error messages and to prevent the driver from hanging). - class TemporaryActor(object): - pass - - worker.actors[actor_id_str] = TemporaryActor() - worker.actor_checkpoint_interval = checkpoint_interval - - def temporary_actor_method(*xs): - raise Exception("The actor with name {} failed to be imported, and so " - "cannot execute this method".format(class_name)) - - # Register the actor method executors. - for actor_method_name in actor_method_names: - function_id = compute_actor_method_function_id(class_name, - actor_method_name).id() - temporary_executor = make_actor_method_executor( - worker, - actor_method_name, - temporary_actor_method, - actor_imported=False) - worker.function_execution_info[driver_id][function_id] = ( - ray.worker.FunctionExecutionInfo( - function=temporary_executor, - function_name=actor_method_name, - max_calls=0)) - worker.num_task_executions[driver_id][function_id] = 0 - - try: - unpickled_class = pickle.loads(pickled_class) - worker.actor_class = unpickled_class - except Exception: - # If an exception was thrown when the actor was imported, we record the - # traceback and notify the scheduler of the failure. - traceback_str = ray.utils.format_error_message(traceback.format_exc()) - # Log the error message. - push_error_to_driver( - worker, - ray_constants.REGISTER_ACTOR_PUSH_ERROR, - traceback_str, - driver_id, - data={"actor_id": actor_id_str}) - # TODO(rkn): In the future, it might make sense to have the worker exit - # here. However, currently that would lead to hanging if someone calls - # ray.get on a method invoked on the actor. - else: - # TODO(pcm): Why is the below line necessary? - unpickled_class.__module__ = module - worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class) - - def pred(x): - return (inspect.isfunction(x) or inspect.ismethod(x) - or is_cython(x)) - - actor_methods = inspect.getmembers(unpickled_class, predicate=pred) - for actor_method_name, actor_method in actor_methods: - function_id = compute_actor_method_function_id( - class_name, actor_method_name).id() - executor = make_actor_method_executor( - worker, actor_method_name, actor_method, actor_imported=True) - worker.function_execution_info[driver_id][function_id] = ( - ray.worker.FunctionExecutionInfo( - function=executor, - function_name=actor_method_name, - max_calls=0)) - # We do not set worker.function_properties[driver_id][function_id] - # because we currently do need the actor worker to submit new tasks - # for the actor. - - -def publish_actor_class_to_key(key, actor_class_info, worker): - """Push an actor class definition to Redis. - - The is factored out as a separate function because it is also called - on cached actor class definitions when a worker connects for the first - time. - - Args: - key: The key to store the actor class info at. - actor_class_info: Information about the actor class. 
-        worker: The worker to use to connect to Redis.
+        If a checkpoint exists, this returns a tuple of the number of tasks
+        included in the checkpoint, the saved checkpoint state, and the
+        task frontier at the time of the checkpoint. If no checkpoint
+        exists, all objects are set to None. The checkpoint index is the
+        number of tasks executed on the actor before the checkpoint was made.
     """
-    # We set the driver ID here because it may not have been available when the
-    # actor class was defined.
-    actor_class_info["driver_id"] = worker.task_driver_id.id()
-    worker.redis_client.hmset(key, actor_class_info)
-    worker.redis_client.rpush("Exports", key)
-
-
-def export_actor_class(class_id, Class, actor_method_names,
-                       checkpoint_interval, worker):
-    key = b"ActorClass:" + class_id
-    actor_class_info = {
-        "class_name": Class.__name__,
-        "module": Class.__module__,
-        "class": pickle.dumps(Class),
-        "checkpoint_interval": checkpoint_interval,
-        "actor_method_names": json.dumps(list(actor_method_names))
-    }
-
-    check_oversized_pickle(actor_class_info["class"],
-                           actor_class_info["class_name"], "actor", worker)
-
-    if worker.mode is None:
-        # This means that 'ray.init()' has not been called yet and so we must
-        # cache the actor class definition and export it when 'ray.init()' is
-        # called.
-        assert worker.cached_remote_functions_and_actors is not None
-        worker.cached_remote_functions_and_actors.append(
-            ("actor", (key, actor_class_info)))
-        # This caching code path is currently not used because we only export
-        # actor class definitions lazily when we instantiate the actor for the
-        # first time.
-        assert False, "This should be unreachable."
-    else:
-        publish_actor_class_to_key(key, actor_class_info, worker)
-    # TODO(rkn): Currently we allow actor classes to be defined within tasks.
-    # I tried to disable this, but it may be necessary because of
-    # https://github.com/ray-project/ray/issues/1146.
+    actor_key = b"Actor:" + actor_id
+    checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
+        actor_key, ["checkpoint_index", "checkpoint", "frontier"])
+    if checkpoint_index is not None:
+        checkpoint_index = int(checkpoint_index)
+    return checkpoint_index, checkpoint, frontier
 
 
 def method(*args, **kwargs):
@@ -518,13 +273,8 @@ def __init__(self, modified_class, class_id, checkpoint_interval, num_cpus,
         self._actor_method_cpus = actor_method_cpus
         self._exported = False
 
-        # Get the actor methods of the given class.
-        def pred(x):
-            return (inspect.isfunction(x) or inspect.ismethod(x)
-                    or is_cython(x))
-
         self._actor_methods = inspect.getmembers(
-            self._modified_class, predicate=pred)
+            self._modified_class, ray.utils.is_function_or_method)
         # Extract the signatures of each of the methods. This will be used
         # to catch some errors if the methods are called with inappropriate
         # arguments.
@@ -537,7 +287,7 @@ def pred(x):
             # don't support, there may not be much the user can do about it.
             signature.check_signature_supported(method, warn=True)
             self._method_signatures[method_name] = signature.extract_signature(
-                method, ignore_first=not is_classmethod(method))
+                method, ignore_first=not ray.utils.is_class_method(method))
 
             # Set the default number of return values for this method.
             if hasattr(method, "__ray_num_return_vals__"):
@@ -614,9 +364,9 @@ def _submit(self,
         else:
             # Export the actor.
            if not self._exported:
-                export_actor_class(self._class_id, self._modified_class,
-                                   self._actor_method_names,
-                                   self._checkpoint_interval, worker)
+                worker.function_actor_manager.export_actor_class(
+                    self._class_id, self._modified_class,
+                    self._actor_method_names, self._checkpoint_interval)
                 self._exported = True
 
             resources = ray.utils.resources_from_resource_arguments(
@@ -801,8 +551,8 @@ def _actor_method_call(self,
         else:
             actor_handle_id = self._ray_actor_handle_id
 
-        function_id = compute_actor_method_function_id(self._ray_class_name,
-                                                       method_name)
+        function_id = FunctionActorManager.compute_actor_method_function_id(
+            self._ray_class_name, method_name)
         object_ids = worker.submit_task(
             function_id,
             args,
@@ -1068,5 +818,4 @@ def __ray_checkpoint_restore__(self):
                       resources, actor_method_cpus)
 
 
-ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor
 ray.worker.global_worker.make_actor = make_actor
diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py
new file mode 100644
index 000000000000..0e123bd67ead
--- /dev/null
+++ b/python/ray/function_manager.py
@@ -0,0 +1,486 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import hashlib
+import inspect
+import json
+import time
+import traceback
+from collections import (
+    namedtuple,
+    defaultdict,
+)
+
+import ray
+from ray import profiling
+from ray import ray_constants
+from ray import cloudpickle as pickle
+from ray.utils import (
+    is_cython,
+    is_function_or_method,
+    is_class_method,
+    check_oversized_pickle,
+    decode,
+    format_error_message,
+    push_error_to_driver,
+)
+
+FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
+                                   ["function", "function_name", "max_calls"])
+"""FunctionExecutionInfo: A named tuple storing remote function information."""
+
+
+class FunctionActorManager(object):
+    """A class used to export/load remote functions and actors.
+
+    Attributes:
+        _worker: The worker that this manager is associated with.
+        _functions_to_export: The remote functions to export when
+            the worker gets connected.
+        _actors_to_export: The actors to export when the worker gets
+            connected.
+        _function_execution_info: The map from driver_id to function_id
+            and execution_info.
+        _num_task_executions: The map from driver_id to per-function
+            execution counts.
+    """
+
+    def __init__(self, worker):
+        self._worker = worker
+        self._functions_to_export = []
+        self._actors_to_export = []
+        # This field is a dictionary that maps a driver ID to a dictionary of
+        # functions (and information about those functions) that have been
+        # registered for that driver (this inner dictionary maps function IDs
+        # to a FunctionExecutionInfo object). This should only be used on
+        # workers that execute remote functions.
+        self._function_execution_info = defaultdict(lambda: {})
+        self._num_task_executions = defaultdict(lambda: {})
+
+    def increase_task_counter(self, driver_id, function_id):
+        self._num_task_executions[driver_id][function_id] += 1
+
+    def get_task_counter(self, driver_id, function_id):
+        return self._num_task_executions[driver_id][function_id]
+
+    def export_cached(self):
+        """Export cached remote functions and actors.
+
+        Note: this should be called only once when the worker is connected.
+ """ + for remote_function in self._functions_to_export: + self._do_export(remote_function) + self._functions_to_export = None + for info in self._actors_to_export: + (key, actor_class_info) = info + self._publish_actor_class_to_key(key, actor_class_info) + + def reset_cache(self): + self._functions_to_export = [] + self._actors_to_export = [] + + def export(self, remote_function): + """Export a remote function. + + Args: + remote_function: the RemoteFunction object. + """ + if self._worker.mode is None: + # If the worker isn't connected, cache the function + # and export it later. + self._functions_to_export.append(remote_function) + return + if self._worker.mode != ray.worker.SCRIPT_MODE: + # Don't need to export if the worker is not a driver. + return + self._do_export(remote_function) + + def _do_export(self, remote_function): + """Pickle a remote function and export it to redis. + + Args: + remote_function: the RemoteFunction object. + """ + # Work around limitations of Python pickling. + function = remote_function._function + function_name_global_valid = function.__name__ in function.__globals__ + function_name_global_value = function.__globals__.get( + function.__name__) + # Allow the function to reference itself as a global variable + if not is_cython(function): + function.__globals__[function.__name__] = remote_function + try: + pickled_function = pickle.dumps(function) + finally: + # Undo our changes + if function_name_global_valid: + function.__globals__[function.__name__] = ( + function_name_global_value) + else: + del function.__globals__[function.__name__] + + check_oversized_pickle(pickled_function, + remote_function._function_name, + "remote function", self._worker) + + key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" + + remote_function._function_id) + self._worker.redis_client.hmset( + key, { + "driver_id": self._worker.task_driver_id.id(), + "function_id": remote_function._function_id, + "name": remote_function._function_name, + "module": function.__module__, + "function": pickled_function, + "max_calls": remote_function._max_calls + }) + self._worker.redis_client.rpush("Exports", key) + + def fetch_and_register_remote_function(self, key): + """Import a remote function.""" + (driver_id, function_id_str, function_name, serialized_function, + num_return_vals, module, resources, + max_calls) = self._worker.redis_client.hmget(key, [ + "driver_id", "function_id", "name", "function", "num_return_vals", + "module", "resources", "max_calls" + ]) + function_id = ray.ObjectID(function_id_str) + function_name = decode(function_name) + max_calls = int(max_calls) + module = decode(module) + + # This is a placeholder in case the function can't be unpickled. This + # will be overwritten if the function is successfully registered. + def f(): + raise Exception("This function was not imported properly.") + + self._function_execution_info[driver_id][function_id.id()] = ( + FunctionExecutionInfo( + function=f, function_name=function_name, max_calls=max_calls)) + self._num_task_executions[driver_id][function_id.id()] = 0 + + try: + function = pickle.loads(serialized_function) + except Exception as e: + # If an exception was thrown when the remote function was imported, + # we record the traceback and notify the scheduler of the failure. + traceback_str = format_error_message(traceback.format_exc()) + # Log the error message. 
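The global-variable swap in ``_do_export`` above is what lets a remote function refer to itself by name: while the plain function is being pickled, the decorated ``RemoteFunction`` is temporarily placed under the function's own name in its ``__globals__``. A small, illustrative usage sketch of the pattern this enables:

.. code-block:: python

    import ray

    ray.init()

    @ray.remote
    def fib(n):
        if n < 2:
            return n
        # `fib` here resolves to the decorated remote function, which is
        # why the export path patches it into the function's __globals__.
        return ray.get(fib.remote(n - 1)) + ray.get(fib.remote(n - 2))

    assert ray.get(fib.remote(6)) == 8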
+ push_error_to_driver( + self._worker, + ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, + traceback_str, + driver_id=driver_id, + data={ + "function_id": function_id.id(), + "function_name": function_name + }) + else: + # The below line is necessary. Because in the driver process, + # if the function is defined in the file where the python script + # was started from, its module is `__main__`. + # However in the worker process, the `__main__` module is a + # different module, which is `default_worker.py` + function.__module__ = module + self._function_execution_info[driver_id][function_id.id()] = ( + FunctionExecutionInfo( + function=function, + function_name=function_name, + max_calls=max_calls)) + # Add the function to the function table. + self._worker.redis_client.rpush( + b"FunctionTable:" + function_id.id(), self._worker.worker_id) + + def get_execution_info(self, driver_id, function_id): + """Get the FunctionExecutionInfo of a remote function. + + Args: + driver_id: ID of the driver that the function belongs to. + function_id: ID of the function to get. + + Returns: + A FunctionExecutionInfo object. + """ + # Wait until the function to be executed has actually been registered + # on this worker. We will push warnings to the user if we spend too + # long in this loop. + with profiling.profile("wait_for_function", worker=self._worker): + self._wait_for_function(function_id, driver_id) + return self._function_execution_info[driver_id][function_id.id()] + + def _wait_for_function(self, function_id, driver_id, timeout=10): + """Wait until the function to be executed is present on this worker. + + This method will simply loop until the import thread has imported the + relevant function. If we spend too long in this loop, that may indicate + a problem somewhere and we will push an error message to the user. + + If this worker is an actor, then this will wait until the actor has + been defined. + + Args: + function_id (str): The ID of the function that we want to execute. + driver_id (str): The ID of the driver to push the error message to + if this times out. + """ + start_time = time.time() + # Only send the warning once. + warning_sent = False + while True: + with self._worker.lock: + if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID + and (function_id.id() in + self._function_execution_info[driver_id])): + break + elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and ( + self._worker.actor_id in self._worker.actors): + break + if time.time() - start_time > timeout: + warning_message = ("This worker was asked to execute a " + "function that it does not have " + "registered. You may have to restart " + "Ray.") + if not warning_sent: + ray.utils.push_error_to_driver( + self._worker, + ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, + warning_message, + driver_id=driver_id) + warning_sent = True + time.sleep(0.001) + + @classmethod + def compute_actor_method_function_id(cls, class_name, attr): + """Get the function ID corresponding to an actor method. + + Args: + class_name (str): The class name of the actor. + attr (str): The attribute name of the method. + + Returns: + Function ID corresponding to the method. 
+        """
+        function_id_hash = hashlib.sha1()
+        function_id_hash.update(class_name.encode("ascii"))
+        function_id_hash.update(attr.encode("ascii"))
+        function_id = function_id_hash.digest()
+        assert len(function_id) == ray_constants.ID_SIZE
+        return ray.ObjectID(function_id)
+
+    def _publish_actor_class_to_key(self, key, actor_class_info):
+        """Push an actor class definition to Redis.
+
+        This is factored out as a separate function because it is also called
+        on cached actor class definitions when a worker connects for the first
+        time.
+
+        Args:
+            key: The key to store the actor class info at.
+            actor_class_info: Information about the actor class.
+        """
+        # We set the driver ID here because it may not have been available when
+        # the actor class was defined.
+        actor_class_info["driver_id"] = self._worker.task_driver_id.id()
+        self._worker.redis_client.hmset(key, actor_class_info)
+        self._worker.redis_client.rpush("Exports", key)
+
+    def export_actor_class(self, class_id, Class, actor_method_names,
+                           checkpoint_interval):
+        key = b"ActorClass:" + class_id
+        actor_class_info = {
+            "class_name": Class.__name__,
+            "module": Class.__module__,
+            "class": pickle.dumps(Class),
+            "checkpoint_interval": checkpoint_interval,
+            "actor_method_names": json.dumps(list(actor_method_names))
+        }
+
+        check_oversized_pickle(actor_class_info["class"],
+                               actor_class_info["class_name"], "actor",
+                               self._worker)
+
+        if self._worker.mode is None:
+            # This means that 'ray.init()' has not been called yet and so we
+            # must cache the actor class definition and export it when
+            # 'ray.init()' is called.
+            assert self._actors_to_export is not None
+            self._actors_to_export.append((key, actor_class_info))
+            # This caching code path is currently not used because we only
+            # export actor class definitions lazily when we instantiate the
+            # actor for the first time.
+            assert False, "This should be unreachable."
+        else:
+            self._publish_actor_class_to_key(key, actor_class_info)
+        # TODO(rkn): Currently we allow actor classes to be defined
+        # within tasks. I tried to disable this, but it may be necessary
+        # because of https://github.com/ray-project/ray/issues/1146.
+
+    def fetch_and_register_actor(self, actor_class_key):
+        """Import an actor.
+
+        This will be called by the worker's import thread when the worker
+        receives the actor_class export, assuming that the worker is an actor
+        for that class.
+
+        Args:
+            actor_class_key: The key in Redis to use to fetch the actor.
+        """
+        actor_id_str = self._worker.actor_id
+        (driver_id, class_id, class_name, module, pickled_class,
+         checkpoint_interval,
+         actor_method_names) = self._worker.redis_client.hmget(
+             actor_class_key, [
+                 "driver_id", "class_id", "class_name", "module", "class",
+                 "checkpoint_interval", "actor_method_names"
+             ])
+
+        class_name = decode(class_name)
+        module = decode(module)
+        checkpoint_interval = int(checkpoint_interval)
+        actor_method_names = json.loads(decode(actor_method_names))
+
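The deterministic method-ID scheme in ``compute_actor_method_function_id`` above can be checked standalone; a SHA-1 digest is 20 bytes, which is what ``ray_constants.ID_SIZE`` asserts. The class and method names below are placeholders:

.. code-block:: python

    import hashlib

    h = hashlib.sha1()
    h.update("MyActor".encode("ascii"))    # class name
    h.update("my_method".encode("ascii"))  # attribute name of the method
    function_id = h.digest()
    # sha1 produces a 20-byte digest, the same length Ray uses for IDs.
    assert len(function_id) == 20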
+        # Create a temporary actor with some temporary methods so that if
+        # the actor fails to be unpickled, the temporary actor can be used
+        # (just to produce error messages and to prevent the driver from
+        # hanging).
+        class TemporaryActor(object):
+            pass
+
+        self._worker.actors[actor_id_str] = TemporaryActor()
+        self._worker.actor_checkpoint_interval = checkpoint_interval
+
+        def temporary_actor_method(*xs):
+            raise Exception(
+                "The actor with name {} failed to be imported, "
+                "and so cannot execute this method".format(class_name))
+
+        # Register the actor method executors.
+        for actor_method_name in actor_method_names:
+            function_id = (
+                FunctionActorManager.compute_actor_method_function_id(
+                    class_name, actor_method_name).id())
+            temporary_executor = self._make_actor_method_executor(
+                actor_method_name,
+                temporary_actor_method,
+                actor_imported=False)
+            self._function_execution_info[driver_id][function_id] = (
+                FunctionExecutionInfo(
+                    function=temporary_executor,
+                    function_name=actor_method_name,
+                    max_calls=0))
+            self._num_task_executions[driver_id][function_id] = 0
+
+        try:
+            unpickled_class = pickle.loads(pickled_class)
+            self._worker.actor_class = unpickled_class
+        except Exception:
+            # If an exception was thrown when the actor was imported, we record
+            # the traceback and notify the scheduler of the failure.
+            traceback_str = ray.utils.format_error_message(
+                traceback.format_exc())
+            # Log the error message.
+            push_error_to_driver(
+                self._worker,
+                ray_constants.REGISTER_ACTOR_PUSH_ERROR,
+                traceback_str,
+                driver_id,
+                data={"actor_id": actor_id_str})
+            # TODO(rkn): In the future, it might make sense to have the worker
+            # exit here. However, currently that would lead to hanging if
+            # someone calls ray.get on a method invoked on the actor.
+        else:
+            # TODO(pcm): Why is the below line necessary?
+            unpickled_class.__module__ = module
+            self._worker.actors[actor_id_str] = unpickled_class.__new__(
+                unpickled_class)
+
+            actor_methods = inspect.getmembers(
+                unpickled_class, predicate=is_function_or_method)
+            for actor_method_name, actor_method in actor_methods:
+                function_id = (
+                    FunctionActorManager.compute_actor_method_function_id(
+                        class_name, actor_method_name).id())
+                executor = self._make_actor_method_executor(
+                    actor_method_name, actor_method, actor_imported=True)
+                self._function_execution_info[driver_id][function_id] = (
+                    FunctionExecutionInfo(
+                        function=executor,
+                        function_name=actor_method_name,
+                        max_calls=0))
+            # We do not set function_properties[driver_id][function_id]
+            # because we currently do not need the actor worker to submit new
+            # tasks for the actor.
+
+    def _make_actor_method_executor(self, method_name, method, actor_imported):
+        """Make an executor that wraps a user-defined actor method.
+
+        The wrapped method updates the worker's internal state and performs any
+        necessary checkpointing operations.
+
+        Args:
+            method_name (str): The name of the actor method.
+            method (instancemethod): The actor method to wrap. This should be a
+                method defined on the actor class and should therefore take an
+                instance of the actor as the first argument.
+            actor_imported (bool): Whether the actor has been imported.
+                Checkpointing operations will not be run if this is set to
+                False.
+
+        Returns:
+            A function that executes the given actor method on the worker's
+            stored instance of the actor. The function also updates the
+            worker's internal state to record the executed method.
+        """
+
+        def actor_method_executor(dummy_return_id, actor, *args):
+            # Update the actor's task counter to reflect the task we're about
+            # to execute.
+ self._worker.actor_task_counter += 1 + + # If this is the first task to execute on the actor, try to resume + # from a checkpoint. + if actor_imported and self._worker.actor_task_counter == 1: + checkpoint_resumed = ray.actor.restore_and_log_checkpoint( + self._worker, actor) + if checkpoint_resumed: + # NOTE(swang): Since we did not actually execute the + # __init__ method, this will put None as the return value. + # If the __init__ method is supposed to return multiple + # values, an exception will be logged. + return + + # Determine whether we should checkpoint the actor. + checkpointing_on = (actor_imported + and self._worker.actor_checkpoint_interval > 0) + # We should checkpoint the actor if user checkpointing is on, we've + # executed checkpoint_interval tasks since the last checkpoint, and + # the method we're about to execute is not a checkpoint. + save_checkpoint = (checkpointing_on + and (self._worker.actor_task_counter % + self._worker.actor_checkpoint_interval == 0 + and method_name != "__ray_checkpoint__")) + + # Execute the assigned method and save a checkpoint if necessary. + try: + if is_class_method(method): + method_returns = method(*args) + else: + method_returns = method(actor, *args) + except Exception: + # Save the checkpoint before allowing the method exception + # to be thrown. + if save_checkpoint: + ray.actor.save_and_log_checkpoint(self._worker, actor) + raise + else: + # Save the checkpoint before returning the method's return + # values. + if save_checkpoint: + ray.actor.save_and_log_checkpoint(self._worker, actor) + return method_returns + + return actor_method_executor diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 659cdf1ce281..85fe0b89dfe1 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -88,7 +88,8 @@ def _process_key(self, key): if key.startswith(b"RemoteFunction"): with profiling.profile( "register_remote_function", worker=self.worker): - self.fetch_and_register_remote_function(key) + (self.worker.function_actor_manager. + fetch_and_register_remote_function(key)) elif key.startswith(b"FunctionsToRun"): with profiling.profile( "fetch_and_run_function", worker=self.worker): @@ -103,58 +104,6 @@ def _process_key(self, key): else: raise Exception("This code should be unreachable.") - def fetch_and_register_remote_function(self, key): - """Import a remote function.""" - from ray.worker import FunctionExecutionInfo - (driver_id, function_id_str, function_name, serialized_function, - num_return_vals, module, resources, - max_calls) = self.redis_client.hmget(key, [ - "driver_id", "function_id", "name", "function", "num_return_vals", - "module", "resources", "max_calls" - ]) - function_id = ray.ObjectID(function_id_str) - function_name = utils.decode(function_name) - max_calls = int(max_calls) - module = utils.decode(module) - - # This is a placeholder in case the function can't be unpickled. This - # will be overwritten if the function is successfully registered. - def f(): - raise Exception("This function was not imported properly.") - - self.worker.function_execution_info[driver_id][function_id.id()] = ( - FunctionExecutionInfo( - function=f, function_name=function_name, max_calls=max_calls)) - self.worker.num_task_executions[driver_id][function_id.id()] = 0 - - try: - function = pickle.loads(serialized_function) - except Exception: - # If an exception was thrown when the remote function was imported, - # we record the traceback and notify the scheduler of the failure. 
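To make the checkpoint condition in ``actor_method_executor`` above concrete, here is the same arithmetic in isolation, assuming an ``actor_checkpoint_interval`` of 3: a checkpoint is saved after tasks 3, 6, 9, and so on, and never for ``__ray_checkpoint__`` itself.

.. code-block:: python

    interval = 3  # stands in for worker.actor_checkpoint_interval

    def should_checkpoint(task_counter, method_name, actor_imported=True):
        # Mirrors the save_checkpoint condition in the executor above.
        checkpointing_on = actor_imported and interval > 0
        return (checkpointing_on and task_counter % interval == 0
                and method_name != "__ray_checkpoint__")

    assert [t for t in range(1, 10)
            if should_checkpoint(t, "work")] == [3, 6, 9]
    assert not should_checkpoint(3, "__ray_checkpoint__")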
-            traceback_str = utils.format_error_message(traceback.format_exc())
-            # Log the error message.
-            utils.push_error_to_driver(
-                self.worker,
-                ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
-                traceback_str,
-                driver_id=driver_id,
-                data={
-                    "function_id": function_id.id(),
-                    "function_name": function_name
-                })
-        else:
-            # TODO(rkn): Why is the below line necessary?
-            function.__module__ = module
-            self.worker.function_execution_info[driver_id][
-                function_id.id()] = (FunctionExecutionInfo(
-                    function=function,
-                    function_name=function_name,
-                    max_calls=max_calls))
-            # Add the function to the function table.
-            self.redis_client.rpush(b"FunctionTable:" + function_id.id(),
-                                    self.worker.worker_id)
-
     def fetch_and_execute_function_to_run(self, key):
         """Run on arbitrary function on the worker."""
         (driver_id, serialized_function,
diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py
index 287d3d045539..b96f5d7e7126 100644
--- a/python/ray/remote_function.py
+++ b/python/ray/remote_function.py
@@ -22,7 +22,7 @@ def compute_function_id(function):
         func: The actual function.
 
     Returns:
-        This returns the function ID.
+        Raw bytes of the function ID.
     """
     function_id_hash = hashlib.sha1()
     # Include the function module and name in the hash.
@@ -39,8 +39,6 @@ def compute_function_id(function):
     # Compute the function ID.
     function_id = function_id_hash.digest()
     assert len(function_id) == ray_constants.ID_SIZE
-    function_id = ray.ObjectID(function_id)
-
     return function_id
 
 
@@ -72,7 +70,7 @@ def __init__(self, function, num_cpus, num_gpus, resources,
         # TODO(rkn): We store the function ID as a string, so that
         # RemoteFunction objects can be pickled. We should undo this when
         # we allow ObjectIDs to be pickled.
-        self._function_id = compute_function_id(self._function).id()
+        self._function_id = compute_function_id(function)
         self._function_name = (
             self._function.__module__ + '.' + self._function.__name__)
         self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS
@@ -90,11 +88,7 @@ def __init__(self, function, num_cpus, num_gpus, resources,
         #
         # Export the function.
         worker = ray.worker.get_global_worker()
-        if worker.mode == ray.worker.SCRIPT_MODE:
-            self._export()
-        elif worker.mode is None:
-            worker.cached_remote_functions_and_actors.append(
-                ("remote_function", self))
+        worker.function_actor_manager.export(self)
 
     def __call__(self, *args, **kwargs):
         raise Exception("Remote functions cannot be called directly. Instead "
@@ -141,9 +135,3 @@ def _submit(self,
             return object_ids[0]
         elif len(object_ids) > 1:
             return object_ids
-
-    def _export(self):
-        worker = ray.worker.get_global_worker()
-        worker.export_remote_function(
-            ray.ObjectID(self._function_id), self._function_name,
-            self._function, self._max_calls, self)
diff --git a/python/ray/utils.py b/python/ray/utils.py
index 0f6adaea9868..55f85c8ac519 100644
--- a/python/ray/utils.py
+++ b/python/ray/utils.py
@@ -5,6 +5,7 @@
 import binascii
 import functools
 import hashlib
+import inspect
 import numpy as np
 import os
 import subprocess
@@ -144,6 +145,23 @@ def check_cython(x):
             (hasattr(obj, "__func__") and check_cython(obj.__func__))
 
 
+def is_function_or_method(obj):
+    """Check if an object is a function or method.
+
+    Args:
+        obj: The Python object in question.
+
+    Returns:
+        True if the object is a function or method.
+ """ + return (inspect.isfunction(obj) or inspect.ismethod(obj) or is_cython(obj)) + + +def is_class_method(f): + """Returns whether the given method is a class_method.""" + return hasattr(f, "__self__") and f.__self__ is not None + + def random_string(): """Generate a random string to use as an ID. diff --git a/python/ray/worker.py b/python/ray/worker.py index c0714b1fc4b7..5299c2d07b3a 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -3,7 +3,6 @@ from __future__ import print_function import atexit -import collections import colorama import hashlib import inspect @@ -33,6 +32,7 @@ import ray.ray_constants as ray_constants from ray import import_thread from ray import profiling +from ray.function_manager import FunctionActorManager from ray.utils import ( binary_to_hex, check_oversized_pickle, @@ -176,11 +176,6 @@ def __str__(self): self.task_error)) -FunctionExecutionInfo = collections.namedtuple( - "FunctionExecutionInfo", ["function", "function_name", "max_calls"]) -"""FunctionExecutionInfo: A named tuple storing remote function information.""" - - class Worker(object): """A class used to define the control flow of a worker process. @@ -189,19 +184,9 @@ class Worker(object): functions outside of this class are considered exposed. Attributes: - function_execution_info (Dict[str, FunctionExecutionInfo]): A - dictionary mapping the name of a remote function to the remote - function itself. This is the set of remote functions that can be - executed by this worker. connected (bool): True if Ray has been started and False otherwise. mode: The mode of the worker. One of SCRIPT_MODE, LOCAL_MODE, and WORKER_MODE. - cached_remote_functions_and_actors: A list of information for exporting - remote functions and actor classes definitions that were defined - before the worker called connect. When the worker eventually does - call connect, if it is a driver, it will export these functions and - actors. If cached_remote_functions_and_actors is None, that means - that connect has been called already. cached_functions_to_run (List): A list of functions to run on all of the workers that should be exported as soon as connect is called. profiler: the profiler used to aggregate profiling information. @@ -216,24 +201,15 @@ class Worker(object): def __init__(self): """Initialize a Worker object.""" - # This field is a dictionary that maps a driver ID to a dictionary of - # functions (and information about those functions) that have been - # registered for that driver (this inner dictionary maps function IDs - # to a FunctionExecutionInfo object. This should only be used on - # workers that execute remote functions. - self.function_execution_info = collections.defaultdict(lambda: {}) # This is a dictionary mapping driver ID to a dictionary that maps # remote function IDs for that driver to a counter of the number of # times that remote function has been executed on this worker. The # counter is incremented every time the function is executed on this # worker. When the counter reaches the maximum number of executions # allowed for a particular function, the worker is killed. - self.num_task_executions = collections.defaultdict(lambda: {}) self.connected = False self.mode = None - self.cached_remote_functions_and_actors = [] self.cached_functions_to_run = [] - self.fetch_and_register_actor = None self.actor_init_error = None self.make_actor = None self.actors = {} @@ -255,6 +231,7 @@ def __init__(self): self.serialization_context_map = {} # Identity of the driver that this worker is processing. 
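The two helpers added to ``ray/utils.py`` above behave as follows on a plain Python 3 class (the Cython case is ignored here); note that ``is_class_method`` keys off ``__self__``, so it is true for a classmethod accessed on its class but false for an undecorated function:

.. code-block:: python

    import inspect

    def is_function_or_method(obj):
        return inspect.isfunction(obj) or inspect.ismethod(obj)

    def is_class_method(f):
        return hasattr(f, "__self__") and f.__self__ is not None

    class A(object):
        def m(self):
            pass

        @classmethod
        def c(cls):
            pass

    # Both methods are picked up by inspect.getmembers with this predicate.
    names = [n for n, _ in
             inspect.getmembers(A, predicate=is_function_or_method)]
    assert names == ["c", "m"]
    assert is_class_method(A.c)      # bound to the class, __self__ is A
    assert not is_class_method(A.m)  # a plain function in Python 3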
self.task_driver_id = None + self.function_actor_manager = FunctionActorManager(self) def mark_actor_init_failed(self, error): """Called to mark this actor as failed during initialization.""" @@ -674,57 +651,6 @@ def submit_task(self, return task.returns() - def export_remote_function(self, function_id, function_name, function, - max_calls, decorated_function): - """Export a remote function. - - Args: - function_id: The ID of the function. - function_name: The name of the function. - function: The raw undecorated function to export. - max_calls: The maximum number of times a given worker can execute - this function before exiting. - decorated_function: The decorated function (this is used to enable - the remote function to recursively call itself). - """ - if self.mode != SCRIPT_MODE: - raise Exception("export_remote_function can only be called on a " - "driver.") - - key = (b"RemoteFunction:" + self.task_driver_id.id() + b":" + - function_id.id()) - - # Work around limitations of Python pickling. - function_name_global_valid = function.__name__ in function.__globals__ - function_name_global_value = function.__globals__.get( - function.__name__) - # Allow the function to reference itself as a global variable - if not is_cython(function): - function.__globals__[function.__name__] = decorated_function - try: - pickled_function = pickle.dumps(function) - finally: - # Undo our changes - if function_name_global_valid: - function.__globals__[function.__name__] = ( - function_name_global_value) - else: - del function.__globals__[function.__name__] - - check_oversized_pickle(pickled_function, function_name, - "remote function", self) - - self.redis_client.hmset( - key, { - "driver_id": self.task_driver_id.id(), - "function_id": function_id.id(), - "name": function_name, - "module": function.__module__, - "function": pickled_function, - "max_calls": max_calls - }) - self.redis_client.rpush("Exports", key) - def run_function_on_all_workers(self, function, run_on_other_drivers=False): """Run arbitrary code on all of the workers. @@ -783,47 +709,6 @@ def run_function_on_all_workers(self, function, # operations into a transaction (or by implementing a custom # command that does all three things). - def _wait_for_function(self, function_id, driver_id, timeout=10): - """Wait until the function to be executed is present on this worker. - - This method will simply loop until the import thread has imported the - relevant function. If we spend too long in this loop, that may indicate - a problem somewhere and we will push an error message to the user. - - If this worker is an actor, then this will wait until the actor has - been defined. - - Args: - function_id (str): The ID of the function that we want to execute. - driver_id (str): The ID of the driver to push the error message to - if this times out. - """ - start_time = time.time() - # Only send the warning once. - warning_sent = False - while True: - with self.lock: - if (self.actor_id == NIL_ACTOR_ID - and (function_id.id() in - self.function_execution_info[driver_id])): - break - elif self.actor_id != NIL_ACTOR_ID and ( - self.actor_id in self.actors): - break - if time.time() - start_time > timeout: - warning_message = ("This worker was asked to execute a " - "function that it does not have " - "registered. 
You may have to restart " - "Ray.") - if not warning_sent: - ray.utils.push_error_to_driver( - self, - ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, - warning_message, - driver_id=driver_id) - warning_sent = True - time.sleep(0.001) - def _get_arguments_for_execution(self, function_name, serialized_args): """Retrieve the arguments for the remote function. @@ -891,7 +776,7 @@ def _store_outputs_in_objstore(self, object_ids, outputs): self.put_object(object_ids[i], outputs[i]) - def _process_task(self, task): + def _process_task(self, task, function_execution_info): """Execute a task assigned to this worker. This method deserializes a task from the scheduler, and attempts to @@ -913,10 +798,8 @@ def _process_task(self, task): return_object_ids = task.returns() if task.actor_id().id() != NIL_ACTOR_ID: dummy_return_id = return_object_ids.pop() - function_executor = self.function_execution_info[ - self.task_driver_id.id()][function_id.id()].function - function_name = self.function_execution_info[self.task_driver_id.id()][ - function_id.id()].function_name + function_executor = function_execution_info.function + function_name = function_execution_info.function_name # Get task arguments from the object store. try: @@ -926,12 +809,12 @@ def _process_task(self, task): arguments = self._get_arguments_for_execution( function_name, args) except (RayGetError, RayGetArgumentError) as e: - self._handle_process_task_failure(function_id, return_object_ids, - e, None) + self._handle_process_task_failure(function_id, function_name, + return_object_ids, e, None) return except Exception as e: self._handle_process_task_failure( - function_id, return_object_ids, e, + function_id, function_name, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) return @@ -950,8 +833,9 @@ def _process_task(self, task): task_exception = task.actor_id().id() == NIL_ACTOR_ID traceback_str = ray.utils.format_error_message( traceback.format_exc(), task_exception=task_exception) - self._handle_process_task_failure(function_id, return_object_ids, - e, traceback_str) + self._handle_process_task_failure(function_id, function_name, + return_object_ids, e, + traceback_str) return # Store the outputs in the local object store. @@ -966,13 +850,11 @@ def _process_task(self, task): self._store_outputs_in_objstore(return_object_ids, outputs) except Exception as e: self._handle_process_task_failure( - function_id, return_object_ids, e, + function_id, function_name, return_object_ids, e, ray.utils.format_error_message(traceback.format_exc())) - def _handle_process_task_failure(self, function_id, return_object_ids, - error, backtrace): - function_name = self.function_execution_info[self.task_driver_id.id()][ - function_id.id()].function_name + def _handle_process_task_failure(self, function_id, function_name, + return_object_ids, error, backtrace): failure_object = RayTaskError(function_name, error, backtrace) failure_objects = [ failure_object for _ in range(len(return_object_ids)) @@ -1014,7 +896,7 @@ def _become_actor(self, task): time.sleep(0.001) with self.lock: - self.fetch_and_register_actor(key, self) + self.function_actor_manager.fetch_and_register_actor(key) def _wait_for_and_process_task(self, task): """Wait for a task to be ready and process the task. @@ -1031,11 +913,8 @@ def _wait_for_and_process_task(self, task): self._become_actor(task) return - # Wait until the function to be executed has actually been registered - # on this worker. We will push warnings to the user if we spend too - # long in this loop. 
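The hunks that follow move per-function execution counting behind ``FunctionActorManager``. The worker-exit condition they preserve looks like this in isolation, as a sketch with hypothetical stand-ins for the manager's counters:

.. code-block:: python

    from collections import defaultdict

    num_task_executions = defaultdict(lambda: defaultdict(int))

    def record_and_check(driver_id, function_id, max_calls):
        """Increment the counter; True means the worker should exit."""
        num_task_executions[driver_id][function_id] += 1
        return num_task_executions[driver_id][function_id] == max_calls

    assert not record_and_check("d1", "f1", max_calls=2)
    assert record_and_check("d1", "f1", max_calls=2)
    # On True, the real worker disconnects and calls os._exit(0). A
    # max_calls of 0 (as used for actor methods) never triggers an exit,
    # since the counter starts above zero and only increases.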
- with profiling.profile("wait_for_function", worker=self): - self._wait_for_function(function_id, driver_id) + execution_info = self.function_actor_manager.get_execution_info( + driver_id, function_id) # Execute the task. # TODO(rkn): Consider acquiring this lock with a timeout and pushing a @@ -1043,9 +922,7 @@ def _wait_for_and_process_task(self, task): # because that may indicate that the system is hanging, and it'd be # good to know where the system is hanging. with self.lock: - - function_name = (self.function_execution_info[driver_id][ - function_id.id()]).function_name + function_name = execution_info.function_name if not self.use_raylet: extra_data = { "function_name": function_name, @@ -1058,7 +935,7 @@ def _wait_for_and_process_task(self, task): "task_id": task.task_id().hex() } with profiling.profile("task", extra_data=extra_data, worker=self): - self._process_task(task) + self._process_task(task, execution_info) # In the non-raylet code path, push all of the log events to the global # state store. In the raylet code path, this is done periodically in a @@ -1067,11 +944,11 @@ def _wait_for_and_process_task(self, task): self.profiler.flush_profile_data() # Increase the task execution counter. - self.num_task_executions[driver_id][function_id.id()] += 1 + self.function_actor_manager.increase_task_counter( + driver_id, function_id.id()) - reached_max_executions = ( - self.num_task_executions[driver_id][function_id.id()] == self. - function_execution_info[driver_id][function_id.id()].max_calls) + reached_max_executions = (self.function_actor_manager.get_task_counter( + driver_id, function_id.id()) == execution_info.max_calls) if reached_max_executions: self.local_scheduler_client.disconnect() os._exit(0) @@ -2112,7 +1989,6 @@ def connect(info, error_message = "Perhaps you called ray.init twice by accident?" assert not worker.connected, error_message assert worker.cached_functions_to_run is not None, error_message - assert worker.cached_remote_functions_and_actors is not None, error_message # Initialize some fields. worker.worker_id = random_string() @@ -2350,18 +2226,9 @@ def connect(info, # Export cached functions_to_run. for function in worker.cached_functions_to_run: worker.run_function_on_all_workers(function) - # Export cached remote functions to the workers. - for cached_type, info in worker.cached_remote_functions_and_actors: - if cached_type == "remote_function": - info._export() - elif cached_type == "actor": - (key, actor_class_info) = info - ray.actor.publish_actor_class_to_key(key, actor_class_info, - worker) - else: - assert False, "This code should be unreachable." + # Export cached remote functions and actors to the workers. + worker.function_actor_manager.export_cached() worker.cached_functions_to_run = None - worker.cached_remote_functions_and_actors = None def disconnect(worker=global_worker): @@ -2372,7 +2239,7 @@ def disconnect(worker=global_worker): # tests. worker.connected = False worker.cached_functions_to_run = [] - worker.cached_remote_functions_and_actors = [] + worker.function_actor_manager.reset_cache() worker.serialization_context_map.clear() From f2dbd3096c7a454fa0c8afefaf2774541e57d18c Mon Sep 17 00:00:00 2001 From: Si-Yuan Date: Wed, 3 Oct 2018 21:08:20 -0700 Subject: [PATCH 011/215] Minor improvements and fixes in Python code. (#3022) This commit fix some small defects. 1. Remove a comment that should have been removed in #3003 2. Remove `redis_protected_mode` that is never used in `ray.init()` 3. 
Fix `object_id_seed` that is forgotten to be passed into `ray._init()` 4. Remove several redundant brackets. --- python/ray/scripts/scripts.py | 40 +++++++++++++++++------------------ python/ray/services.py | 9 +++++--- python/ray/worker.py | 12 +++-------- 3 files changed, 29 insertions(+), 32 deletions(-) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index d3e9417c1a81..f8e0c5484f75 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -432,24 +432,24 @@ def stop(): "--min-workers", required=False, type=int, - help=("Override the configured min worker node count for the cluster.")) + help="Override the configured min worker node count for the cluster.") @click.option( "--max-workers", required=False, type=int, - help=("Override the configured max worker node count for the cluster.")) + help="Override the configured max worker node count for the cluster.") @click.option( "--cluster-name", "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") @click.option( "--yes", "-y", is_flag=True, default=False, - help=("Don't ask for confirmation.")) + help="Don't ask for confirmation.") def create_or_update(cluster_config_file, min_workers, max_workers, no_restart, restart_only, yes, cluster_name): if restart_only or no_restart: @@ -465,19 +465,19 @@ def create_or_update(cluster_config_file, min_workers, max_workers, no_restart, "--workers-only", is_flag=True, default=False, - help=("Only destroy the workers.")) + help="Only destroy the workers.") @click.option( "--yes", "-y", is_flag=True, default=False, - help=("Don't ask for confirmation.")) + help="Don't ask for confirmation.") @click.option( "--cluster-name", "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") def teardown(cluster_config_file, yes, workers_only, cluster_name): teardown_cluster(cluster_config_file, yes, workers_only, cluster_name) @@ -488,17 +488,17 @@ def teardown(cluster_config_file, yes, workers_only, cluster_name): "--start", is_flag=True, default=False, - help=("Start the cluster if needed.")) + help="Start the cluster if needed.") @click.option( - "--tmux", is_flag=True, default=False, help=("Run the command in tmux.")) + "--tmux", is_flag=True, default=False, help="Run the command in tmux.") @click.option( "--cluster-name", "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") @click.option( - "--new", "-N", is_flag=True, help=("Force creation of a new screen.")) + "--new", "-N", is_flag=True, help="Force creation of a new screen.") def attach(cluster_config_file, start, tmux, cluster_name, new): attach_cluster(cluster_config_file, start, tmux, cluster_name, new) @@ -512,7 +512,7 @@ def attach(cluster_config_file, start, tmux, cluster_name, new): "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") def rsync_down(cluster_config_file, source, target, cluster_name): rsync(cluster_config_file, source, target, cluster_name, down=True) @@ -526,7 +526,7 @@ def rsync_down(cluster_config_file, source, target, cluster_name): "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") def rsync_up(cluster_config_file, source, target, cluster_name): rsync(cluster_config_file, source, target, cluster_name, 
down=False) @@ -538,27 +538,27 @@ def rsync_up(cluster_config_file, source, target, cluster_name): "--stop", is_flag=True, default=False, - help=("Stop the cluster after the command finishes running.")) + help="Stop the cluster after the command finishes running.") @click.option( "--start", is_flag=True, default=False, - help=("Start the cluster if needed.")) + help="Start the cluster if needed.") @click.option( "--screen", is_flag=True, default=False, - help=("Run the command in a screen.")) + help="Run the command in a screen.") @click.option( - "--tmux", is_flag=True, default=False, help=("Run the command in tmux.")) + "--tmux", is_flag=True, default=False, help="Run the command in tmux.") @click.option( "--cluster-name", "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") @click.option( - "--port-forward", required=False, type=int, help=("Port to forward.")) + "--port-forward", required=False, type=int, help="Port to forward.") def exec_cmd(cluster_config_file, cmd, screen, tmux, stop, start, cluster_name, port_forward): assert not (screen and tmux), "Can specify only one of `screen` or `tmux`." @@ -576,7 +576,7 @@ def exec_cmd(cluster_config_file, cmd, screen, tmux, stop, start, cluster_name, "-n", required=False, type=str, - help=("Override the configured cluster name.")) + help="Override the configured cluster name.") def get_head_ip(cluster_config_file, cluster_name): click.echo(get_head_node_ip(cluster_config_file, cluster_name)) diff --git a/python/ray/services.py b/python/ray/services.py index ab632354f6a9..f66001ba6df9 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -187,18 +187,21 @@ def cleanup(): logger.warning("Ray did not shut down properly.") -def all_processes_alive(exclude=[]): +def all_processes_alive(exclude=None): """Check if all of the processes are still alive. Args: exclude: Don't check the processes whose types are in this list. """ + + if exclude is None: + exclude = [] for process_type, processes in all_processes.items(): # Note that p.poll() returns the exit code that the process exited # with, so an exit code of None indicates that the process is still # alive. processes_alive = [p.poll() is None for p in processes] - if (not all(processes_alive) and process_type not in exclude): + if not all(processes_alive) and process_type not in exclude: logger.warning( "A process of type {} has died.".format(process_type)) return False @@ -358,7 +361,7 @@ def _compute_version_info(): ray_version = ray.__version__ python_version = ".".join(map(str, sys.version_info[:3])) pyarrow_version = pyarrow.__version__ - return (ray_version, python_version, pyarrow_version) + return ray_version, python_version, pyarrow_version def _put_version_info_in_redis(redis_client): diff --git a/python/ray/worker.py b/python/ray/worker.py index 5299c2d07b3a..a88503bf0536 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -201,12 +201,6 @@ class Worker(object): def __init__(self): """Initialize a Worker object.""" - # This is a dictionary mapping driver ID to a dictionary that maps - # remote function IDs for that driver to a counter of the number of - # times that remote function has been executed on this worker. The - # counter is incremented every time the function is executed on this - # worker. When the counter reaches the maximum number of executions - # allowed for a particular function, the worker is killed. 
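The ``all_processes_alive`` change above replaces a mutable default argument (``exclude=[]``) with ``None``. Python evaluates default values once, at function definition time, so a mutable default is shared across calls; a minimal demonstration of the pitfall:

.. code-block:: python

    def bad(exclude=[]):
        exclude.append("worker")
        return exclude

    assert bad() == ["worker"]
    assert bad() == ["worker", "worker"]  # state leaked from the first call

    def good(exclude=None):
        if exclude is None:
            exclude = []
        exclude.append("worker")
        return exclude

    assert good() == ["worker"]
    assert good() == ["worker"]  # fresh list on every call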
self.connected = False self.mode = None self.cached_functions_to_run = [] @@ -1645,7 +1639,6 @@ def init(redis_address=None, ignore_reinit_error=False, num_redis_shards=None, redis_max_clients=None, - redis_protected_mode=True, plasma_directory=None, huge_pages=False, include_webui=True, @@ -1761,6 +1754,7 @@ def init(redis_address=None, address_info=info, start_ray_local=(redis_address is None), num_workers=num_workers, + object_id_seed=object_id_seed, local_mode=local_mode, driver_mode=driver_mode, redirect_worker_output=redirect_worker_output, @@ -1823,9 +1817,9 @@ def shutdown(worker=global_worker): # besides possibly the worker itself. for process_type, processes in services.all_processes.items(): if process_type == services.PROCESS_TYPE_WORKER: - assert (len(processes)) <= 1 + assert len(processes) <= 1 else: - assert (len(processes) == 0) + assert len(processes) == 0 worker.set_mode(None) From 01bb0735698c4fffde78f90e05f96f1ebd43c8f9 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 4 Oct 2018 00:06:35 -0700 Subject: [PATCH 012/215] Suppress errors when worker or driver intentionally disconnects. (#2935) --- .../format/local_scheduler.fbs | 5 ++- src/local_scheduler/local_scheduler_client.cc | 11 ++++-- src/ray/raylet/format/node_manager.fbs | 7 ++-- src/ray/raylet/node_manager.cc | 34 +++++++++++++------ src/ray/raylet/node_manager.h | 3 +- test/failure_test.py | 13 +++++++ 6 files changed, 56 insertions(+), 17 deletions(-) diff --git a/src/local_scheduler/format/local_scheduler.fbs b/src/local_scheduler/format/local_scheduler.fbs index ffdf13d6aea4..a23bb28f05f3 100644 --- a/src/local_scheduler/format/local_scheduler.fbs +++ b/src/local_scheduler/format/local_scheduler.fbs @@ -17,9 +17,12 @@ enum MessageType:int { // Send a reply confirming the successful registration of a worker or driver. // This is sent from the local scheduler to a worker or driver. RegisterClientReply, - // Notify the local scheduler that this client is disconnecting gracefully. + // Notify the local scheduler that this client disconnected unexpectedly. // This is sent from a worker to a local scheduler. DisconnectClient, + // Notify the local scheduler that this client is disconnecting gracefully. + // This is sent from a worker to a local scheduler. + IntentionalDisconnectClient, // Get a new task from the local scheduler. This is sent from a worker to a // local scheduler. 
   GetTask,
diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc
index 91b5fa9c9df1..09bda7f5bd8d 100644
--- a/src/local_scheduler/local_scheduler_client.cc
+++ b/src/local_scheduler/local_scheduler_client.cc
@@ -56,8 +56,15 @@ void local_scheduler_disconnect_client(LocalSchedulerConnection *conn) {
   flatbuffers::FlatBufferBuilder fbb;
   auto message = ray::local_scheduler::protocol::CreateDisconnectClient(fbb);
   fbb.Finish(message);
-  write_message(conn->conn, static_cast<int64_t>(MessageType::DisconnectClient),
-                fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
+  if (conn->use_raylet) {
+    write_message(conn->conn, static_cast<int64_t>(
+                                  MessageType::IntentionalDisconnectClient),
+                  fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
+  } else {
+    write_message(conn->conn,
+                  static_cast<int64_t>(MessageType::DisconnectClient),
+                  fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex);
+  }
 }
 
 void local_scheduler_log_event(LocalSchedulerConnection *conn,
diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs
index 4ede8f2b3e70..72f934a727b4 100644
--- a/src/ray/raylet/format/node_manager.fbs
+++ b/src/ray/raylet/format/node_manager.fbs
@@ -23,9 +23,12 @@ enum MessageType:int {
   // Send a reply confirming the successful registration of a worker or driver.
   // This is sent from the local scheduler to a worker or driver.
   RegisterClientReply,
-  // Notify the local scheduler that this client is disconnecting gracefully.
+  // Notify the local scheduler that this client is disconnecting unexpectedly.
   // This is sent from a worker to a local scheduler.
   DisconnectClient,
+  // Notify the local scheduler that this client is disconnecting gracefully.
+  // This is sent from a worker to a local scheduler.
+  IntentionalDisconnectClient,
   // Get a new task from the local scheduler. This is sent from a worker to a
   // local scheduler.
   GetTask,
@@ -183,7 +186,7 @@ table PushErrorRequest {
 }
 
 table FreeObjectsRequest {
-  // Whether keep this request with local object store 
+  // Whether to keep this request with the local object store
   // or send it to all the object stores.
   local_only: bool;
   // List of object ids we'll delete from object store.
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index fcc60e030082..2d0bcf86144a 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -26,6 +26,8 @@ RAY_CHECK_ENUM(protocol::MessageType::RegisterClientReply,
                local_scheduler_protocol::MessageType::RegisterClientReply);
 RAY_CHECK_ENUM(protocol::MessageType::DisconnectClient,
                local_scheduler_protocol::MessageType::DisconnectClient);
+RAY_CHECK_ENUM(protocol::MessageType::IntentionalDisconnectClient,
+               local_scheduler_protocol::MessageType::IntentionalDisconnectClient);
 RAY_CHECK_ENUM(protocol::MessageType::GetTask,
                local_scheduler_protocol::MessageType::GetTask);
 RAY_CHECK_ENUM(protocol::MessageType::ExecuteTask,
@@ -539,18 +541,19 @@ void NodeManager::ProcessClientMessage(
   RAY_LOG(DEBUG) << "Message of type " << message_type;
 
   auto registered_worker = worker_pool_.GetRegisteredWorker(client);
+  auto message_type_value = static_cast<protocol::MessageType>(message_type);
   if (registered_worker && registered_worker->IsDead()) {
     // For a worker that is marked as dead (because the driver has died already),
     // all the messages are ignored except DisconnectClient.
-    if (static_cast<protocol::MessageType>(message_type) !=
-        protocol::MessageType::DisconnectClient) {
+    if ((message_type_value != protocol::MessageType::DisconnectClient) &&
+        (message_type_value != protocol::MessageType::IntentionalDisconnectClient)) {
       // Listen for more messages.
       client->ProcessMessages();
       return;
     }
   }

-  switch (static_cast<protocol::MessageType>(message_type)) {
+  switch (message_type_value) {
   case protocol::MessageType::RegisterClientRequest: {
     ProcessRegisterClientRequestMessage(client, message_data);
   } break;
@@ -563,6 +566,12 @@ void NodeManager::ProcessClientMessage(
     // because it's already disconnected.
     return;
   } break;
+  case protocol::MessageType::IntentionalDisconnectClient: {
+    ProcessDisconnectClientMessage(client, /* push_warning = */ false);
+    // We don't need to receive future messages from this client,
+    // because it's already disconnected.
+    return;
+  } break;
   case protocol::MessageType::SubmitTask: {
     ProcessSubmitTaskMessage(message_data);
   } break;
@@ -638,7 +647,7 @@ void NodeManager::ProcessGetTaskMessage(
 }

 void NodeManager::ProcessDisconnectClientMessage(
-    const std::shared_ptr<LocalClientConnection> &client) {
+    const std::shared_ptr<LocalClientConnection> &client, bool push_warning) {
   const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
   const std::shared_ptr<Worker> driver = worker_pool_.GetRegisteredDriver(client);
   // This client can't be a worker and a driver.
@@ -678,13 +687,16 @@ void NodeManager::ProcessDisconnectClientMessage(
       TreatTaskAsFailed(spec);

       const JobID &job_id = worker->GetAssignedDriverId();
-      // TODO(rkn): Define this constant somewhere else.
-      std::string type = "worker_died";
-      std::ostringstream error_message;
-      error_message << "A worker died or was killed while executing task " << task_id
-                    << ".";
-      RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
-          job_id, type, error_message.str(), current_time_ms()));
+
+      if (push_warning) {
+        // TODO(rkn): Define this constant somewhere else.
+        std::string type = "worker_died";
+        std::ostringstream error_message;
+        error_message << "A worker died or was killed while executing task " << task_id
+                      << ".";
+        RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
+            job_id, type, error_message.str(), current_time_ms()));
+      }
     }

     worker_pool_.DisconnectWorker(worker);
diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h
index b6b23223c797..e3d2ca1416ce 100644
--- a/src/ray/raylet/node_manager.h
+++ b/src/ray/raylet/node_manager.h
@@ -286,9 +286,10 @@ class NodeManager {
   /// client.
   ///
   /// \param client The client that sent the message.
+  /// \param push_warning Propagate the error message if true.
   /// \return Void.
   void ProcessDisconnectClientMessage(
-      const std::shared_ptr<LocalClientConnection> &client);
+      const std::shared_ptr<LocalClientConnection> &client, bool push_warning = true);

   /// Process client message of SubmitTask
   ///
diff --git a/test/failure_test.py b/test/failure_test.py
index 149d4c74d485..3f631ae59a2e 100644
--- a/test/failure_test.py
+++ b/test/failure_test.py
@@ -371,6 +371,19 @@ def getpid(self):
     ray.get(task2)


+def test_actor_scope_or_intentionally_killed_message(ray_start_regular):
+    @ray.remote
+    class Actor(object):
+        pass
+
+    a = Actor.remote()
+    a = Actor.remote()
+    a.__ray_terminate__.remote()
+    time.sleep(1)
+    assert len(ray.error_info()) == 0, (
+        "Should not have propagated an error - {}".format(ray.error_info()))
+
+
 @pytest.fixture
 def ray_start_object_store_memory():
     # Start the Ray processes.
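The net effect, as a minimal driver-side sketch (assuming a local Ray instance; the calls mirror the regression test above, and none of the names are new):

```python
import time
import ray

ray.init()

@ray.remote
class Actor(object):
    pass

a = Actor.remote()
# Graceful exit: the worker sends IntentionalDisconnectClient, so the node
# manager skips the "worker_died" warning it pushes for unexpected deaths.
a.__ray_terminate__.remote()
time.sleep(1)

# No error should have been pushed to the driver.
assert len(ray.error_info()) == 0
```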
From faa31ae0185d9b67a2bb115ee28a22c0735ad987 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 4 Oct 2018 10:35:39 -0700 Subject: [PATCH 013/215] Introduce concept of resources required for placing a task. (#2837) * Introduce concept of resources required for placement. * Add placement resources to task spec * Update java worker * Update taskinfo.java --- .../org/ray/runtime/AbstractRayRuntime.java | 1 - .../org/ray/runtime/generated/TaskInfo.java | 23 ++-- .../ray/runtime/raylet/RayletClientImpl.java | 8 +- python/ray/actor.py | 13 +- python/ray/worker.py | 12 +- src/common/format/common.fbs | 3 + src/common/lib/python/common_extension.cc | 130 +++++++++++------- src/ray/raylet/node_manager.cc | 17 ++- src/ray/raylet/scheduling_policy.cc | 35 ++--- src/ray/raylet/scheduling_resources.cc | 55 +++++--- src/ray/raylet/task_spec.cc | 19 ++- src/ray/raylet/task_spec.h | 70 ++++++++-- src/ray/raylet/worker_pool_test.cc | 2 +- test/failure_test.py | 19 +++ 14 files changed, 285 insertions(+), 122 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index b035f3b52bc0..ddbb93c70763 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -286,4 +286,3 @@ public FunctionManager getFunctionManager() { return functionManager; } } - diff --git a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java index 8c0512afbc4f..01113096036f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java +++ b/java/runtime/src/main/java/org/ray/runtime/generated/TaskInfo.java @@ -48,9 +48,12 @@ public final class TaskInfo extends Table { public ResourcePair requiredResources(int j) { return requiredResources(new ResourcePair(), j); } public ResourcePair requiredResources(ResourcePair obj, int j) { int o = __offset(30); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } public int requiredResourcesLength() { int o = __offset(30); return o != 0 ? __vector_len(o) : 0; } - public int language() { int o = __offset(32); return o != 0 ? bb.getInt(o + bb_pos) : 0; } - public String functionDescriptor(int j) { int o = __offset(34); return o != 0 ? __string(__vector(o) + j * 4) : null; } - public int functionDescriptorLength() { int o = __offset(34); return o != 0 ? __vector_len(o) : 0; } + public ResourcePair requiredPlacementResources(int j) { return requiredPlacementResources(new ResourcePair(), j); } + public ResourcePair requiredPlacementResources(ResourcePair obj, int j) { int o = __offset(32); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } + public int requiredPlacementResourcesLength() { int o = __offset(32); return o != 0 ? __vector_len(o) : 0; } + public int language() { int o = __offset(34); return o != 0 ? bb.getInt(o + bb_pos) : 0; } + public String functionDescriptor(int j) { int o = __offset(36); return o != 0 ? __string(__vector(o) + j * 4) : null; } + public int functionDescriptorLength() { int o = __offset(36); return o != 0 ? 
__vector_len(o) : 0; } public static int createTaskInfo(FlatBufferBuilder builder, int driver_idOffset, @@ -67,11 +70,13 @@ public static int createTaskInfo(FlatBufferBuilder builder, int argsOffset, int returnsOffset, int required_resourcesOffset, + int required_placement_resourcesOffset, int language, int function_descriptorOffset) { - builder.startObject(16); + builder.startObject(17); TaskInfo.addFunctionDescriptor(builder, function_descriptorOffset); TaskInfo.addLanguage(builder, language); + TaskInfo.addRequiredPlacementResources(builder, required_placement_resourcesOffset); TaskInfo.addRequiredResources(builder, required_resourcesOffset); TaskInfo.addReturns(builder, returnsOffset); TaskInfo.addArgs(builder, argsOffset); @@ -89,7 +94,7 @@ public static int createTaskInfo(FlatBufferBuilder builder, return TaskInfo.endTaskInfo(builder); } - public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(16); } + public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(17); } public static void addDriverId(FlatBufferBuilder builder, int driverIdOffset) { builder.addOffset(0, driverIdOffset, 0); } public static void addTaskId(FlatBufferBuilder builder, int taskIdOffset) { builder.addOffset(1, taskIdOffset, 0); } public static void addParentTaskId(FlatBufferBuilder builder, int parentTaskIdOffset) { builder.addOffset(2, parentTaskIdOffset, 0); } @@ -110,8 +115,11 @@ public static int createTaskInfo(FlatBufferBuilder builder, public static void addRequiredResources(FlatBufferBuilder builder, int requiredResourcesOffset) { builder.addOffset(13, requiredResourcesOffset, 0); } public static int createRequiredResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startRequiredResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } - public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(14, language, 0); } - public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(15, functionDescriptorOffset, 0); } + public static void addRequiredPlacementResources(FlatBufferBuilder builder, int requiredPlacementResourcesOffset) { builder.addOffset(14, requiredPlacementResourcesOffset, 0); } + public static int createRequiredPlacementResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } + public static void startRequiredPlacementResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } + public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(15, language, 0); } + public static void addFunctionDescriptor(FlatBufferBuilder builder, int functionDescriptorOffset) { builder.addOffset(16, functionDescriptorOffset, 0); } public static int createFunctionDescriptorVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); } public static void startFunctionDescriptorVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); } public static int endTaskInfo(FlatBufferBuilder builder) { @@ -136,4 +144,3 @@ public 
ByteBuffer returnsAsByteBuffer(int j) { return src; } } - diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 1a78f22debec..2152495045f2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -209,6 +209,11 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue()); } int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets); + + int[] requiredPlacementResourcesOffsets = new int[0]; + int requiredPlacementResourcesOffset = + fbb.createVectorOfTables(requiredPlacementResourcesOffsets); + int[] functionDescriptorOffsets = new int[]{ fbb.createString(task.functionDescriptor.className), fbb.createString(task.functionDescriptor.name), @@ -222,7 +227,8 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { actorCreateIdOffset, actorCreateDummyIdOffset, actorIdOffset, actorHandleIdOffset, actorCounter, false, functionIdOffset, - argsOffset, returnsOffset, requiredResourcesOffset, TaskLanguage.JAVA, + argsOffset, returnsOffset, requiredResourcesOffset, + requiredPlacementResourcesOffset, TaskLanguage.JAVA, functionDescriptorOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); diff --git a/python/ray/actor.py b/python/ray/actor.py index d61fac7d7579..65ddc266f944 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -373,6 +373,15 @@ def _submit(self, self._num_cpus, self._num_gpus, self._resources, num_cpus, num_gpus, resources) + # If the actor methods require CPU resources, then set the required + # placement resources. If actor_placement_resources is empty, then + # the required placement resources will be the same as resources. + actor_placement_resources = {} + assert self._actor_method_cpus in [0, 1] + if self._actor_method_cpus == 1: + actor_placement_resources = resources.copy() + actor_placement_resources["CPU"] += 1 + creation_args = [self._class_id] function_id = compute_actor_creation_function_id(self._class_id) [actor_cursor] = worker.submit_task( @@ -380,7 +389,8 @@ def _submit(self, creation_args, actor_creation_id=actor_id, num_return_vals=1, - resources=resources) + resources=resources, + placement_resources=actor_placement_resources) # We initialize the actor counter at 1 to account for the actor # creation task. @@ -566,6 +576,7 @@ def _actor_method_call(self, # We add one for the dummy return ID. num_return_vals=num_return_vals + 1, resources={"CPU": self._ray_actor_method_cpus}, + placement_resources={}, driver_id=self._ray_actor_driver_id) # Update the actor counter and cursor to reflect the most recent # invocation. diff --git a/python/ray/worker.py b/python/ray/worker.py index a88503bf0536..4739b2e7cc7c 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -550,6 +550,7 @@ def submit_task(self, execution_dependencies=None, num_return_vals=None, resources=None, + placement_resources=None, driver_id=None): """Submit a remote task to the scheduler. @@ -575,6 +576,9 @@ def submit_task(self, num_return_vals: The number of return values this function should have. resources: The resource requirements for this task. + placement_resources: The resources required for placing the task. + If this is not provided or if it is an empty dictionary, then + the placement resources will be equal to resources. 
driver_id: The ID of the relevant driver. This is almost always the driver ID of the driver that is currently running. However, in the exceptional case that an actor task is being dispatched to @@ -628,6 +632,9 @@ def submit_task(self, raise ValueError( "Resource quantities must all be whole numbers.") + if placement_resources is None: + placement_resources = {} + with self.state_lock: # Increment the worker's task index to track how many tasks # have been submitted by the current task so far. @@ -640,7 +647,8 @@ def submit_task(self, num_return_vals, self.current_task_id, task_index, actor_creation_id, actor_creation_dummy_object_id, actor_id, actor_handle_id, actor_counter, is_actor_checkpoint_method, - execution_dependencies, resources, self.use_raylet) + execution_dependencies, resources, placement_resources, + self.use_raylet) self.local_scheduler_client.submit(task) return task.returns() @@ -2138,7 +2146,7 @@ def connect(info, worker.current_task_id, worker.task_index, ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), - nil_actor_counter, False, [], {"CPU": 0}, worker.use_raylet) + nil_actor_counter, False, [], {"CPU": 0}, {}, worker.use_raylet) # Add the driver task to the task table. if not worker.use_raylet: diff --git a/src/common/format/common.fbs b/src/common/format/common.fbs index 9dc9f651a3e3..a5b2177f1c30 100644 --- a/src/common/format/common.fbs +++ b/src/common/format/common.fbs @@ -60,6 +60,9 @@ table TaskInfo { // The required_resources vector indicates the quantities of the different // resources required by this task. required_resources: [ResourcePair]; + // The resources required for placing this task on a node. If this is empty, + // then the placement resources are equal to the required_resources. + required_placement_resources: [ResourcePair]; // The language that this task belongs to language: TaskLanguage; // Function descriptor, which is a list of strings that can diff --git a/src/common/lib/python/common_extension.cc b/src/common/lib/python/common_extension.cc index 68965e270980..3a1a44f5e148 100644 --- a/src/common/lib/python/common_extension.cc +++ b/src/common/lib/python/common_extension.cc @@ -295,49 +295,100 @@ PyTypeObject PyObjectIDType = { PyType_GenericNew, /* tp_new */ }; -/* Define the PyTask class. */ +// Define the PyTask class. + +int resource_map_from_python_dict( + PyObject *resource_map, + std::unordered_map &out) { + RAY_CHECK(out.size() == 0); + + PyObject *key, *value; + Py_ssize_t position = 0; + if (!PyDict_Check(resource_map)) { + PyErr_SetString(PyExc_TypeError, "resource_map must be a dictionary"); + return -1; + } + + while (PyDict_Next(resource_map, &position, &key, &value)) { +#if PY_MAJOR_VERSION >= 3 + if (!PyUnicode_Check(key)) { + PyErr_SetString(PyExc_TypeError, + "the keys in resource_map must be strings"); + return -1; + } +#else + if (!PyBytes_Check(key)) { + PyErr_SetString(PyExc_TypeError, + "the keys in resource_map must be strings"); + return -1; + } +#endif + + // Check that the resource quantities are numbers. + if (!(PyFloat_Check(value) || PyInt_Check(value) || PyLong_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "the values in resource_map must be floats"); + return -1; + } + // Handle the case where the key is a bytes object and the case where it + // is a unicode object. 
+ std::string resource_name; + if (PyUnicode_Check(key)) { + PyObject *ascii_key = PyUnicode_AsASCIIString(key); + resource_name = + std::string(PyBytes_AsString(ascii_key), PyBytes_Size(ascii_key)); + Py_DECREF(ascii_key); + } else { + resource_name = std::string(PyBytes_AsString(key), PyBytes_Size(key)); + } + out[resource_name] = PyFloat_AsDouble(value); + } + return 0; +} static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { - /* ID of the driver that this task originates from. */ + // ID of the driver that this task originates from. UniqueID driver_id; - /* ID of the actor this task should run on. */ + // ID of the actor this task should run on. UniqueID actor_id = ActorID::nil(); - /* ID of the actor handle used to submit this task. */ + // ID of the actor handle used to submit this task. UniqueID actor_handle_id = ActorHandleID::nil(); - /* How many tasks have been launched on the actor so far? */ + // How many tasks have been launched on the actor so far? int actor_counter = 0; - /* True if this is an actor checkpoint task and false otherwise. */ + // True if this is an actor checkpoint task and false otherwise. PyObject *is_actor_checkpoint_method_object = nullptr; - /* ID of the function this task executes. */ + // ID of the function this task executes. FunctionID function_id; - /* Arguments of the task (can be PyObjectIDs or Python values). */ + // Arguments of the task (can be PyObjectIDs or Python values). PyObject *arguments; - /* Number of return values of this task. */ + // Number of return values of this task. int num_returns; - /* The ID of the task that called this task. */ + // The ID of the task that called this task. TaskID parent_task_id; - /* The number of tasks that the parent task has called prior to this one. */ + // The number of tasks that the parent task has called prior to this one. int parent_counter; // The actor creation ID. ActorID actor_creation_id = ActorID::nil(); // The dummy object for the actor creation task (if this is an actor method). ObjectID actor_creation_dummy_object_id = ObjectID::nil(); - /* Arguments of the task that are execution-dependent. These must be - * PyObjectIDs). */ + // Arguments of the task that are execution-dependent. These must be + // PyObjectIDs). PyObject *execution_arguments = nullptr; - /* Dictionary of resource requirements for this task. */ + // Dictionary of resource requirements for this task. PyObject *resource_map = nullptr; + // Dictionary of required placement resources for this task. + PyObject *placement_resource_map = nullptr; // True if we should use the raylet code path and false otherwise. PyObject *use_raylet_object = nullptr; if (!PyArg_ParseTuple( - args, "O&O&OiO&i|O&O&O&O&iOOOO", &PyObjectToUniqueID, &driver_id, + args, "O&O&OiO&i|O&O&O&O&iOOOOO", &PyObjectToUniqueID, &driver_id, &PyObjectToUniqueID, &function_id, &arguments, &num_returns, &PyObjectToUniqueID, &parent_task_id, &parent_counter, &PyObjectToUniqueID, &actor_creation_id, &PyObjectToUniqueID, &actor_creation_dummy_object_id, &PyObjectToUniqueID, &actor_id, &PyObjectToUniqueID, &actor_handle_id, &actor_counter, &is_actor_checkpoint_method_object, &execution_arguments, - &resource_map, &use_raylet_object)) { + &resource_map, &placement_resource_map, &use_raylet_object)) { return -1; } @@ -349,48 +400,25 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { // Parse the resource map. 
std::unordered_map required_resources; + std::unordered_map required_placement_resources; - bool found_CPU_requirements = false; - PyObject *key, *value; - Py_ssize_t position = 0; if (resource_map != nullptr) { - if (!PyDict_Check(resource_map)) { - PyErr_SetString(PyExc_TypeError, "resource_map must be a dictionary"); + if (resource_map_from_python_dict(resource_map, required_resources) != 0) { return -1; } - while (PyDict_Next(resource_map, &position, &key, &value)) { - if (!(PyBytes_Check(key) || PyUnicode_Check(key))) { - PyErr_SetString(PyExc_TypeError, - "the keys in resource_map must be strings"); - return -1; - } - if (!(PyFloat_Check(value) || PyInt_Check(value) || - PyLong_Check(value))) { - PyErr_SetString(PyExc_TypeError, - "the values in resource_map must be floats"); - return -1; - } - // Handle the case where the key is a bytes object and the case where it - // is a unicode object. - std::string resource_name; - if (PyUnicode_Check(key)) { - PyObject *ascii_key = PyUnicode_AsASCIIString(key); - resource_name = - std::string(PyBytes_AsString(ascii_key), PyBytes_Size(ascii_key)); - Py_DECREF(ascii_key); - } else { - resource_name = std::string(PyBytes_AsString(key), PyBytes_Size(key)); - } - if (resource_name == std::string("CPU")) { - found_CPU_requirements = true; - } - required_resources[resource_name] = PyFloat_AsDouble(value); - } } - if (!found_CPU_requirements) { + + if (required_resources.count("CPU") == 0) { required_resources["CPU"] = 1.0; } + if (placement_resource_map != nullptr) { + if (resource_map_from_python_dict(placement_resource_map, + required_placement_resources) != 0) { + return -1; + } + } + Py_ssize_t num_args = PyList_Size(arguments); bool use_raylet = false; @@ -463,7 +491,7 @@ static int PyTask_init(PyTask *self, PyObject *args, PyObject *kwds) { driver_id, parent_task_id, parent_counter, actor_creation_id, actor_creation_dummy_object_id, actor_id, actor_handle_id, actor_counter, function_id, args, num_returns, required_resources, - Language::PYTHON); + required_placement_resources, Language::PYTHON); } /* Set the task's execution dependencies. */ diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 2d0bcf86144a..c4e75dbb8c6a 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -945,17 +945,20 @@ void NodeManager::ScheduleTasks( // TODO(rkn): Define this constant somewhere else. std::string type = "infeasible_task"; std::ostringstream error_message; - error_message << "The task with ID " << task.GetTaskSpecification().TaskId() - << " is infeasible and cannot currently be executed. " - << "It requested " - << task.GetTaskSpecification().GetRequiredResources().ToString(); + error_message + << "The task with ID " << task.GetTaskSpecification().TaskId() + << " is infeasible and cannot currently be executed. It requires " + << task.GetTaskSpecification().GetRequiredResources().ToString() + << " for execution and " + << task.GetTaskSpecification().GetRequiredPlacementResources().ToString() + << " for placement. Check the client table to view node resources."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( task.GetTaskSpecification().DriverId(), type, error_message.str(), current_time_ms())); } // Assert that this placeable task is not feasible locally (necessary but not // sufficient). 
- RAY_CHECK(!task.GetTaskSpecification().GetRequiredResources().IsSubset( + RAY_CHECK(!task.GetTaskSpecification().GetRequiredPlacementResources().IsSubset( cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] .GetTotalResources())); } @@ -1009,7 +1012,7 @@ void NodeManager::TreatTaskAsFailed(const TaskSpecification &spec) { void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineage, bool forwarded) { - const TaskID task_id = task.GetTaskSpecification().TaskId(); + const TaskID &task_id = task.GetTaskSpecification().TaskId(); if (local_queues_.HasTask(task_id)) { RAY_LOG(WARNING) << "Submitted task " << task_id << " is already queued and will not be reconstructed. This is most " @@ -1287,6 +1290,8 @@ void NodeManager::AssignTask(Task &task) { this->cluster_resource_map_[my_client_id].Acquire(spec.GetRequiredResources())); if (spec.IsActorCreationTask()) { + // Check that we are not placing an actor creation task on a node with 0 CPUs. + RAY_CHECK(cluster_resource_map_[my_client_id].GetTotalResources().GetNumCpus() != 0); worker->SetLifetimeResourceIds(acquired_resources); } else { worker->SetTaskResourceIds(acquired_resources); diff --git a/src/ray/raylet/scheduling_policy.cc b/src/ray/raylet/scheduling_policy.cc index 0ed4efb27335..eb2d41632339 100644 --- a/src/ray/raylet/scheduling_policy.cc +++ b/src/ray/raylet/scheduling_policy.cc @@ -35,12 +35,9 @@ std::unordered_map SchedulingPolicy::Schedule( // Iterate over running tasks, get their resource demand and try to schedule. for (const auto &t : scheduling_queue_.GetPlaceableTasks()) { // Get task's resource demand - const auto &resource_demand = t.GetTaskSpecification().GetRequiredResources(); - const TaskID &task_id = t.GetTaskSpecification().TaskId(); - RAY_LOG(DEBUG) << "[SchedulingPolicy]: task=" << task_id - << " numforwards=" << t.GetTaskExecutionSpec().NumForwards() - << " resources=" - << t.GetTaskSpecification().GetRequiredResources().ToString(); + const auto &spec = t.GetTaskSpecification(); + const auto &resource_demand = spec.GetRequiredPlacementResources(); + const TaskID &task_id = spec.TaskId(); // TODO(atumanov): try to place tasks locally first. // Construct a set of viable node candidates and randomly pick between them. @@ -97,7 +94,7 @@ std::unordered_map SchedulingPolicy::Schedule( std::uniform_int_distribution distribution(0, client_keys.size() - 1); int client_key_index = distribution(gen_); const ClientID &dst_client_id = client_keys[client_key_index]; - decision[t.GetTaskSpecification().TaskId()] = dst_client_id; + decision[task_id] = dst_client_id; // Update dst_client_id's load to keep track of remote task load until // the next heartbeat. ResourceSet new_load(cluster_resources[dst_client_id].GetLoadResources()); @@ -107,9 +104,11 @@ std::unordered_map SchedulingPolicy::Schedule( // There are no nodes that can feasibly execute this task. The task remains // placeable until cluster capacity becomes available. // TODO(rkn): Propagate a warning to the user. - RAY_LOG(INFO) << "This task requires " - << t.GetTaskSpecification().GetRequiredResources().ToString() - << ", but no nodes have the necessary resources."; + RAY_LOG(INFO) << "The task with ID " << task_id << " requires " + << spec.GetRequiredResources().ToString() << " for execution and " + << spec.GetRequiredPlacementResources().ToString() + << " for placement, but no nodes have the necessary resources. 
" + << "Check the client table to view node resources."; } } } @@ -126,19 +125,21 @@ std::vector SchedulingPolicy::SpillOver( // Check if we can accommodate an infeasible task. for (const auto &task : scheduling_queue_.GetInfeasibleTasks()) { - if (task.GetTaskSpecification().GetRequiredResources().IsSubset( + const auto &spec = task.GetTaskSpecification(); + if (spec.GetRequiredPlacementResources().IsSubset( remote_scheduling_resources.GetTotalResources())) { - decision.push_back(task.GetTaskSpecification().TaskId()); - new_load.AddResources(task.GetTaskSpecification().GetRequiredResources()); + decision.push_back(spec.TaskId()); + new_load.AddResources(spec.GetRequiredResources()); } } for (const auto &task : scheduling_queue_.GetReadyTasks()) { - if (!task.GetTaskSpecification().IsActorTask()) { - if (task.GetTaskSpecification().GetRequiredResources().IsSubset( + const auto &spec = task.GetTaskSpecification(); + if (!spec.IsActorTask()) { + if (spec.GetRequiredPlacementResources().IsSubset( remote_scheduling_resources.GetTotalResources())) { - decision.push_back(task.GetTaskSpecification().TaskId()); - new_load.AddResources(task.GetTaskSpecification().GetRequiredResources()); + decision.push_back(spec.TaskId()); + new_load.AddResources(spec.GetRequiredResources()); break; } } diff --git a/src/ray/raylet/scheduling_resources.cc b/src/ray/raylet/scheduling_resources.cc index 49519c493290..e85f3eaa7fa6 100644 --- a/src/ray/raylet/scheduling_resources.cc +++ b/src/ray/raylet/scheduling_resources.cc @@ -17,7 +17,7 @@ ResourceSet::ResourceSet(const std::vector &resource_labels, const std::vector resource_capacity) { RAY_CHECK(resource_labels.size() == resource_capacity.size()); for (uint i = 0; i < resource_labels.size(); i++) { - RAY_CHECK(this->AddResource(resource_labels[i], resource_capacity[i])); + RAY_CHECK(AddResource(resource_labels[i], resource_capacity[i])); } } @@ -119,11 +119,11 @@ bool ResourceSet::GetResource(const std::string &resource_name, double *value) c if (!value) { return false; } - if (this->resource_capacity_.count(resource_name) == 0) { + if (resource_capacity_.count(resource_name) == 0) { *value = std::nan(""); return false; } - *value = this->resource_capacity_.at(resource_name); + *value = resource_capacity_.at(resource_name); return true; } @@ -135,15 +135,25 @@ double ResourceSet::GetNumCpus() const { const std::string ResourceSet::ToString() const { std::string return_string = ""; - for (const auto &resource_pair : this->resource_capacity_) { - return_string += - "{" + resource_pair.first + "," + std::to_string(resource_pair.second) + "}, "; + + auto it = resource_capacity_.begin(); + + // Convert the first element to a string. + if (it != resource_capacity_.end()) { + return_string += "{" + it->first + "," + std::to_string(it->second) + "}"; + } + it++; + + // Add the remaining elements to the string (along with a comma). 
+ for (; it != resource_capacity_.end(); ++it) { + return_string += ",{" + it->first + "," + std::to_string(it->second) + "}"; } + return return_string; } const std::unordered_map &ResourceSet::GetResourceMap() const { - return this->resource_capacity_; + return resource_capacity_; }; /// ResourceIds class implementation @@ -400,11 +410,20 @@ ResourceSet ResourceIdSet::ToResourceSet() const { std::string ResourceIdSet::ToString() const { std::string return_string = "AvailableResources: "; - for (auto const &resource_pair : available_resources_) { - return_string += resource_pair.first + ": {"; - return_string += resource_pair.second.ToString(); - return_string += "}, "; + + auto it = available_resources_.begin(); + + // Convert the first element to a string. + if (it != available_resources_.end()) { + return_string += (it->first + ": {" + it->second.ToString() + "}"); } + it++; + + // Add the remaining elements to the string (along with a comma). + for (; it != available_resources_.end(); ++it) { + return_string += (", " + it->first + ": {" + it->second.ToString() + "}"); + } + return return_string; } @@ -450,26 +469,26 @@ SchedulingResources::~SchedulingResources() {} ResourceAvailabilityStatus SchedulingResources::CheckResourcesSatisfied( ResourceSet &resources) const { - if (!resources.IsSubset(this->resources_total_)) { + if (!resources.IsSubset(resources_total_)) { return ResourceAvailabilityStatus::kInfeasible; } // Resource demand specified is feasible. Check if it's available. - if (!resources.IsSubset(this->resources_available_)) { + if (!resources.IsSubset(resources_available_)) { return ResourceAvailabilityStatus::kResourcesUnavailable; } return ResourceAvailabilityStatus::kFeasible; } const ResourceSet &SchedulingResources::GetAvailableResources() const { - return this->resources_available_; + return resources_available_; } void SchedulingResources::SetAvailableResources(ResourceSet &&newset) { - this->resources_available_ = newset; + resources_available_ = newset; } const ResourceSet &SchedulingResources::GetTotalResources() const { - return this->resources_total_; + return resources_total_; } void SchedulingResources::SetLoadResources(ResourceSet &&newset) { @@ -482,12 +501,12 @@ const ResourceSet &SchedulingResources::GetLoadResources() const { // Return specified resources back to SchedulingResources. bool SchedulingResources::Release(const ResourceSet &resources) { - return this->resources_available_.AddResourcesStrict(resources); + return resources_available_.AddResourcesStrict(resources); } // Take specified resources from SchedulingResources. 
bool SchedulingResources::Acquire(const ResourceSet &resources) { - return this->resources_available_.SubtractResourcesStrict(resources); + return resources_available_.SubtractResourcesStrict(resources); } } // namespace raylet diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index b9fd35f02ea5..cab87a729b63 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -50,7 +50,7 @@ TaskSpecification::TaskSpecification( : TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::nil(), ObjectID::nil(), ActorID::nil(), ActorHandleID::nil(), -1, function_id, task_arguments, num_returns, required_resources, - language) {} + std::unordered_map(), language) {} TaskSpecification::TaskSpecification( const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, @@ -59,6 +59,7 @@ TaskSpecification::TaskSpecification( const FunctionID &function_id, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, + const std::unordered_map &required_placement_resources, const Language &language) : spec_() { flatbuffers::FlatBufferBuilder fbb; @@ -99,7 +100,8 @@ TaskSpecification::TaskSpecification( to_flatbuf(fbb, actor_creation_dummy_object_id), to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, false, to_flatbuf(fbb, function_id), fbb.CreateVector(arguments), - fbb.CreateVector(returns), map_to_flatbuf(fbb, required_resources), task_language); + fbb.CreateVector(returns), map_to_flatbuf(fbb, required_resources), + map_to_flatbuf(fbb, required_placement_resources), task_language); fbb.Finish(spec); AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } @@ -179,6 +181,19 @@ const ResourceSet TaskSpecification::GetRequiredResources() const { return ResourceSet(required_resources); } +const ResourceSet TaskSpecification::GetRequiredPlacementResources() const { + auto message = flatbuffers::GetRoot(spec_.data()); + auto required_placement_resources = + map_from_flatbuf(*message->required_placement_resources()); + // If the required_placement_resources field is empty, then the placement + // resources default to the required resources. + if (required_placement_resources.size() == 0) { + required_placement_resources = map_from_flatbuf(*message->required_resources()); + } + + return ResourceSet(required_placement_resources); +} + bool TaskSpecification::IsDriverTask() const { // Driver tasks are empty tasks that have no function ID set. return FunctionId().is_nil(); diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index a4075ff6a1db..49bd02bd678a 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -83,34 +83,61 @@ class TaskSpecification { TaskSpecification(const flatbuffers::String &string); // TODO(swang): Define an actor task constructor. - /// Create a task specification from the raw fields. + /// Create a task specification from the raw fields. This constructor omits + /// some values and sets them to sensible defaults. /// /// \param driver_id The driver ID, representing the job that this task is a - /// part of. + /// part of. /// \param parent_task_id The task ID of the task that spawned this task. /// \param parent_counter The number of tasks that this task's parent spawned - /// before this task. + /// before this task. /// \param function_id The ID of the function this task should execute. - /// \param arguments The list of task arguments. + /// \param task_arguments The list of task arguments. 
/// \param num_returns The number of values returned by the task. /// \param required_resources The task's resource demands. + /// \param language The language of the worker that must execute the function. TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const FunctionID &function_id, - const std::vector> &arguments, - int64_t num_returns, - const std::unordered_map &required_resources, - const Language &language); - - TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id, - int64_t parent_counter, const ActorID &actor_creation_id, - const ObjectID &actor_creation_dummy_object_id, - const ActorID &actor_id, const ActorHandleID &actor_handle_id, - int64_t actor_counter, const FunctionID &function_id, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const Language &language); + // TODO(swang): Define an actor task constructor. + /// Create a task specification from the raw fields. + /// + /// \param driver_id The driver ID, representing the job that this task is a + /// part of. + /// \param parent_task_id The task ID of the task that spawned this task. + /// \param parent_counter The number of tasks that this task's parent spawned + /// before this task. + /// \param actor_creation_id If this is an actor task, then this is the ID of + /// the corresponding actor creation task. Otherwise, this is nil. + /// \param actor_id The ID of the actor for the task. If this is not an actor + /// task, then this is nil. + /// \param actor_handle_id The ID of the actor handle that submitted this + /// task. If this is not an actor task, then this is nil. + /// \param actor_counter The number of tasks submitted before this task from + /// the same actor handle. If this is not an actor task, then this is 0. + /// \param function_id The ID of the function this task should execute. + /// \param task_arguments The list of task arguments. + /// \param num_returns The number of values returned by the task. + /// \param required_resources The task's resource demands. + /// \param required_placement_resources The resources required to place this + /// task on a node. Typically, this should be an empty map in which case it + /// will default to be equal to the required_resources argument. + /// \param language The language of the worker that must execute the function. + TaskSpecification( + const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, + const ActorID &actor_id, const ActorHandleID &actor_handle_id, + int64_t actor_counter, const FunctionID &function_id, + const std::vector> &task_arguments, + int64_t num_returns, + const std::unordered_map &required_resources, + const std::unordered_map &required_placement_resources, + const Language &language); + /// Deserialize a task specification from a flatbuffer's string data. /// /// \param string The string data for a serialized task specification @@ -141,7 +168,22 @@ class TaskSpecification { const uint8_t *ArgVal(int64_t arg_index) const; size_t ArgValLength(int64_t arg_index) const; double GetRequiredResource(const std::string &resource_name) const; + /// Return the resources that are to be acquired during the execution of this + /// task. + /// + /// \return The resources that will be acquired during the execution of this + /// task. 
const ResourceSet GetRequiredResources() const; + /// Return the resources that are required for a task to be placed on a node. + /// This will typically be the same as the resources acquired during execution + /// and will always be a superset of those resources. However, they may + /// differ, e.g., actor creation tasks may require more resources to be + /// scheduled on a machine because the actor creation task may require no + /// resources itself, but subsequent actor methods may require resources, and + /// so the placement of the actor should take this into account. + /// + /// \return The resources that are required to place a task on a node. + const ResourceSet GetRequiredPlacementResources() const; bool IsDriverTask() const; Language GetLanguage() const; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 2a7e32123c36..abaf675ff625 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -65,7 +65,7 @@ static inline TaskSpecification ExampleTaskSpec( const Language &language = Language::PYTHON) { return TaskSpecification(UniqueID::nil(), UniqueID::nil(), 0, ActorID::nil(), ObjectID::nil(), actor_id, ActorHandleID::nil(), 0, - FunctionID::nil(), {}, 0, {{}}, language); + FunctionID::nil(), {}, 0, {{}}, {{}}, language); } TEST_F(WorkerPoolTest, HandleWorkerRegistration) { diff --git a/test/failure_test.py b/test/failure_test.py index 3f631ae59a2e..9cd25962b76b 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -561,6 +561,25 @@ class Foo(object): wait_for_errors(ray_constants.INFEASIBLE_TASK_ERROR, 2) +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +def test_warning_for_infeasible_zero_cpu_actor(shutdown_only): + # Check that we cannot place an actor on a 0 CPU machine and that we get an + # infeasibility warning (even though the actor creation task itself + # requires no CPUs). + + ray.init(num_cpus=0) + + @ray.remote + class Foo(object): + pass + + # The actor creation should be infeasible. + Foo.remote() + wait_for_errors(ray_constants.INFEASIBLE_TASK_ERROR, 1) + + @pytest.fixture def ray_start_two_nodes(): # Start the Ray processes. From 0651d3b629e3bb153d3df04a7d7f33bd6d92ef3c Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 4 Oct 2018 17:23:17 -0700 Subject: [PATCH 014/215] [tune/core] Use Global State API for resources (#3004) --- python/ray/experimental/state.py | 18 +++++++++++++----- python/ray/tune/ray_trial_executor.py | 7 ++++--- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index d91165637b60..eab71993c60c 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -1310,9 +1310,19 @@ def cluster_resources(self): return dict(resources) + def _live_client_ids(self): + """Returns a set of client IDs corresponding to clients still alive.""" + return { + client["ClientID"] + for client in self.client_table() if client["IsInsertion"] + } + def available_resources(self): """Get the current available cluster resources. + This is different from `cluster_resources` in that this will return + idle (available) resources rather than total resources. + Note that this information can grow stale as tasks start and finish. 
Returns: @@ -1364,6 +1374,7 @@ def available_resources(self): if local_scheduler_id not in local_scheduler_ids: del available_resources_by_id[local_scheduler_id] else: + # TODO(rliaw): Is this a fair assumption? # Assumes the number of Redis clients does not change subscribe_clients = [ redis_client.pubsub(ignore_subscribe_messages=True) @@ -1373,7 +1384,7 @@ def available_resources(self): subscribe_client.subscribe( ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) - client_ids = {client["ClientID"] for client in self.client_table()} + client_ids = self._live_client_ids() while set(available_resources_by_id.keys()) != client_ids: for subscribe_client in subscribe_clients: @@ -1403,10 +1414,7 @@ def available_resources(self): available_resources_by_id[client_id] = dynamic_resources # Update clients in cluster - client_ids = { - client["ClientID"] - for client in self.client_table() - } + client_ids = self._live_client_ids() # Remove disconnected clients for client_id in available_resources_by_id.keys(): diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 86f09cda34d8..acbebb38b4ab 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -213,12 +213,13 @@ def _return_resources(self, resources): assert self._committed_resources.gpu >= 0 def _update_avail_resources(self): - clients = ray.global_state.client_table() if ray.worker.global_worker.use_raylet: # TODO(rliaw): Remove once raylet flag is swapped - num_cpus = sum(cl['Resources']['CPU'] for cl in clients) - num_gpus = sum(cl['Resources'].get('GPU', 0) for cl in clients) + resources = ray.global_state.cluster_resources() + num_cpus = resources["CPU"] + num_gpus = resources["GPU"] else: + clients = ray.global_state.client_table() local_schedulers = [ entry for client in clients.values() for entry in client if (entry['ClientType'] == 'local_scheduler' From ecd8f39580c40da83be4bca34612391f824b23d5 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Fri, 5 Oct 2018 15:24:24 -0700 Subject: [PATCH 015/215] [core] Improve logging message when plasma store is started. (#3029) Improve logging message when plasma store is started. --- python/ray/services.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/services.py b/python/ray/services.py index f66001ba6df9..9b1592e7bc29 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1148,7 +1148,7 @@ def start_plasma_store(node_ip_address, objstore_memory = int(system_memory * 0.8) # Start the Plasma store. logger.info("Starting the Plasma object store with {0:.2f} GB memory." 
-                .format(objstore_memory // 10**9))
+                .format(objstore_memory / 10**9))
     plasma_store_name, p1 = ray.plasma.start_plasma_store(
         plasma_store_memory=objstore_memory,
         use_profiler=RUN_PLASMA_STORE_PROFILER,

From 2d35a97a7628e2a7c1b4f313c690626c1d46c529 Mon Sep 17 00:00:00 2001
From: Kristian Hartikainen
Date: Sat, 6 Oct 2018 00:34:53 -0700
Subject: [PATCH 016/215] Bug/log syncer fails with parentheses (#2653)

* Update rsync command

* Escape rsync locations

* Fix the accidental variable move

* Update rsync to use -s flag
---
 python/ray/tune/log_sync.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/ray/tune/log_sync.py b/python/ray/tune/log_sync.py
index 2e18a8658208..109c11a01707 100644
--- a/python/ray/tune/log_sync.py
+++ b/python/ray/tune/log_sync.py
@@ -107,11 +107,13 @@ def sync_now(self, force=False):
         if not distutils.spawn.find_executable("rsync"):
             logger.error("Log sync requires rsync to be installed.")
             return
+        source = '{}@{}:{}/'.format(ssh_user, self.worker_ip,
+                                    self.local_dir)
+        target = '{}/'.format(self.local_dir)
         worker_to_local_sync_cmd = ((
-            """rsync -avz -e "ssh -i {} -o ConnectTimeout=120s """
-            """-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
-                quote(ssh_key), ssh_user, self.worker_ip,
-                quote(self.local_dir), quote(self.local_dir)))
+            """rsync -savz -e "ssh -i {} -o ConnectTimeout=120s """
+            """-o StrictHostKeyChecking=no" {} {}""").format(
+                quote(ssh_key), quote(source), quote(target)))

         if self.remote_dir:
             if self.remote_dir.startswith(S3_PREFIX):

From 84bf5fc8f35c7ed46283d8c9b11e6b85df41d677 Mon Sep 17 00:00:00 2001
From: Wang Qing
Date: Tue, 9 Oct 2018 04:05:26 +0800
Subject: [PATCH 017/215] [Java] Load driver resources from local path. (#3001)

## What do these changes do?

1. Add a configuration item `driver.resource-path`.
2. Load driver resources from the local path that is specified in `ray.conf`.

Before this change, all driver resources (such as the user's jar package, dependency packages, and config files) had to be added to the `classpath`. After this change, driver resources are placed under the mount path configured in `ray.conf`, and the `classpath` no longer needs to be configured for driver resources.
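For example (a minimal sketch; `/tmp/ray/driver/resource` is the default from `ray.default.conf`, while the driver id and jar name below are illustrative):

```
# In ray.conf:
ray.driver.resource-path: /tmp/ray/driver/resource

# Resulting on-disk layout that FunctionManager loads from
# (<resource-path>/<driverId>/):
/tmp/ray/driver/resource/0123456789012345678901234567890123456789/my-app.jar
```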
## Related issue number

N/A
---
 .../org/ray/runtime/AbstractRayRuntime.java   |  2 +-
 .../org/ray/runtime/config/RayConfig.java     | 13 +++++++
 .../functionmanager/FunctionManager.java      | 35 +++++++++++++++++--
 .../java/org/ray/runtime/util/JarLoader.java  |  7 ++--
 .../src/main/resources/ray.default.conf      | 13 +++++--
 .../functionmanager/FunctionManagerTest.java  | 31 +++++++++++++++-
 .../java/org/ray/api/test/RayConfigTest.java  |  5 +++
 7 files changed, 97 insertions(+), 9 deletions(-)

diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
index ddbb93c70763..407f142d9d16 100644
--- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
+++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java
@@ -48,7 +48,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {

   public AbstractRayRuntime(RayConfig rayConfig) {
     this.rayConfig = rayConfig;
-    functionManager = new FunctionManager();
+    functionManager = new FunctionManager(rayConfig.driverResourcePath);
     worker = new Worker(this);
     workerContext = new WorkerContext(rayConfig.workerMode, rayConfig.driverId);
   }
diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java
index a2ef237e2806..c77d62628f8c 100644
--- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java
+++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java
@@ -51,6 +51,7 @@ public class RayConfig {
   public final String redisModulePath;
   public final String plasmaStoreExecutablePath;
   public final String rayletExecutablePath;
+  public final String driverResourcePath;

   private void validate() {
     if (workerMode == WorkerMode.WORKER) {
@@ -156,6 +157,18 @@ public RayConfig(Config config) {
     plasmaStoreExecutablePath = rayHome + "/build/src/plasma/plasma_store_server";
     rayletExecutablePath = rayHome + "/build/src/ray/raylet/raylet";

+    // driver resource path
+    String localDriverResourcePath;
+    if (config.hasPath("ray.driver.resource-path")) {
+      localDriverResourcePath = config.getString("ray.driver.resource-path");
+    } else {
+      localDriverResourcePath = rayHome + "/driver/resource";
+      LOGGER.warn("Didn't configure ray.driver.resource-path, set it to default value: {}",
+          localDriverResourcePath);
+    }
+
+    driverResourcePath = localDriverResourcePath;
+
     // validate config
     validate();
     LOGGER.debug("Created config: {}", this);
diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java
index 473a1f033203..cf92b0c21472 100644
--- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java
+++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java
@@ -15,13 +15,18 @@
 import org.objectweb.asm.Type;
 import org.ray.api.function.RayFunc;
 import org.ray.api.id.UniqueId;
+import org.ray.runtime.util.JarLoader;
 import org.ray.runtime.util.LambdaUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 /**
  * Manages functions by driver id.
  */
 public class FunctionManager {

+  private static final Logger LOGGER = LoggerFactory.getLogger(FunctionManager.class);
+
   static final String CONSTRUCTOR_NAME = "<init>";

   /**
@@ -36,6 +41,21 @@ public class FunctionManager {
   */
  private Map<UniqueId, DriverFunctionTable> driverFunctionTables = new HashMap<>();

+  /**
+   * The resource path from which we can load the driver's jar resources.
+   */
+  private String driverResourcePath;
+
+  /**
+   * Construct a FunctionManager with the specified driver resource path.
+   *
+   * @param driverResourcePath The specified local path from which
+   *     the driver's resources are loaded.
+   */
+  public FunctionManager(String driverResourcePath) {
+    this.driverResourcePath = driverResourcePath;
+  }
+
   /**
    * Get the RayFunction from a RayFunc instance (a lambda).
    *
@@ -66,8 +86,19 @@ public RayFunction getFunction(UniqueId driverId, RayFunc func) {
   public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDescriptor) {
     DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId);
     if (driverFunctionTable == null) {
-      //TODO(hchen): distinguish class loader by driver id.
-      ClassLoader classLoader = getClass().getClassLoader();
+      String resourcePath = driverResourcePath + "/" + driverId.toString() + "/";
+      ClassLoader classLoader;
+
+      try {
+        classLoader = JarLoader.loadJars(resourcePath, false);
+        LOGGER.info("Succeeded in loading driver({}) resources. Resource path is {}",
+            driverId, resourcePath);
+      } catch (Exception e) {
+        LOGGER.error("Failed to load driver({}) resources. Resource path is {}",
+            driverId, resourcePath);
+        classLoader = getClass().getClassLoader();
+      }
+
       driverFunctionTable = new DriverFunctionTable(classLoader);
       driverFunctionTables.put(driverId, driverFunctionTable);
     }
diff --git a/java/runtime/src/main/java/org/ray/runtime/util/JarLoader.java b/java/runtime/src/main/java/org/ray/runtime/util/JarLoader.java
index 8a66923e3464..c6ab5650c038 100644
--- a/java/runtime/src/main/java/org/ray/runtime/util/JarLoader.java
+++ b/java/runtime/src/main/java/org/ray/runtime/util/JarLoader.java
@@ -14,13 +14,16 @@
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.io.filefilter.DirectoryFileFilter;
 import org.apache.commons.io.filefilter.RegexFileFilter;
-import org.ray.runtime.util.logger.RayLog;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 /**
  * Load and unload jars from a directory.
  */
 public class JarLoader {

+  private static final Logger LOGGER = LoggerFactory.getLogger(JarLoader.class);
+
   public static URLClassLoader loadJars(String dir, boolean explicitLoad) {
     // get all jars
     Collection<File> jars = FileUtils.listFiles(
@@ -42,7 +45,7 @@ private static URLClassLoader loadJar(Collection<File> appJars, boolean explicit

     for (File appJar : appJars) {
       try {
-        RayLog.core.info("load jar " + appJar.getAbsolutePath());
+        LOGGER.info("Succeeded in loading jar {}.", appJar.getAbsolutePath());
         JarFile jar = new JarFile(appJar.getAbsolutePath());
         jars.add(jar);
         urls.add(appJar.toURI().toURL());
diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf
index c20d679a9c59..58a3be2de3d6 100644
--- a/java/runtime/src/main/resources/ray.default.conf
+++ b/java/runtime/src/main/resources/ray.default.conf
@@ -25,9 +25,15 @@ ray {
   // Available resources on this node, for example "CPU:4,GPU:0".
   resources: ""

-  // If worker.mode is DRIVER, specify the driver id.
-  // If not provided, a random id will be used.
-  driver.id: ""
+  // Configuration items about the driver.
+  driver {
+    // If worker.mode is DRIVER, specify the driver id.
+    // If not provided, a random id will be used.
+    id: ""
+    // If worker.mode is WORKER, the worker will load the driver's
+    // resources from this path in order to execute its tasks.
+    resource-path: /tmp/ray/driver/resource
+  }

   // Root dir of log files.
   log-dir: /tmp/ray/logs
@@ -76,4 +82,5 @@ ray {
     // RPC socket name of Raylet
     socket-name: /tmp/ray/sockets/raylet
   }
+
 }
diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java
index 08e4d6415c54..c82ae27af3b2 100644
--- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java
+++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java
@@ -1,5 +1,9 @@
 package org.ray.runtime.functionmanager;

+import java.io.File;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.nio.file.StandardCopyOption;
 import java.util.Map;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
@@ -41,6 +45,8 @@ public Object bar() {
   private static FunctionDescriptor barDescriptor;
   private static FunctionDescriptor barConstructorDescriptor;

+  private final static String resourcePath = "/tmp/ray/test/resource";
+
   private FunctionManager functionManager;

   @BeforeClass
@@ -59,7 +65,7 @@ public static void beforeClass() {

   @Before
   public void before() {
-    functionManager = new FunctionManager();
+    functionManager = new FunctionManager(FunctionManagerTest.resourcePath);
   }

   @Test
@@ -116,4 +122,27 @@ public void testLoadFunctionTableForClass() {
     Assert.assertTrue(res.containsKey(
         ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.typeDescriptor)));
   }
+
+  //TODO(qwang): This is an integration test case; we should move it to the test folder in the future.
+  @Test
+  public void testGetFunctionFromLocalResource() throws Exception {
+    UniqueId driverId = UniqueId.fromHexString("0123456789012345678901234567890123456789");
+
+    //TODO(qwang): We should use an independent app demo instead of `tutorial`.
+ final String srcJarPath = System.getProperty("user.dir") + + "/../tutorial/target/ray-tutorial-0.1-SNAPSHOT.jar"; + final String destJarPath = resourcePath + "/" + driverId.toString() + + "/ray-tutorial-0.1-SNAPSHOT.jar"; + + File file = new File(resourcePath + "/" + driverId.toString()); + file.mkdirs(); + + Files.copy(Paths.get(srcJarPath), Paths.get(destJarPath), StandardCopyOption.REPLACE_EXISTING); + + FunctionDescriptor sayHelloDescriptor = new FunctionDescriptor("org.ray.exercise.Exercise02", + "sayHello", "()Ljava/lang/String;"); + RayFunction func = functionManager.getFunction(driverId, sayHelloDescriptor); + Assert.assertEquals(func.getFunctionDescriptor(), sayHelloDescriptor); + } + } diff --git a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java index fd47e15ab494..ac7e01124632 100644 --- a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java @@ -11,6 +11,7 @@ public class RayConfigTest { @Test public void testCreateRayConfig() { System.setProperty("ray.home", "/path/to/ray"); + System.setProperty("ray.driver.resource-path", "path/to/ray/driver/resource/path"); RayConfig rayConfig = RayConfig.create(); Assert.assertEquals("/path/to/ray", rayConfig.rayHome); @@ -19,8 +20,12 @@ public void testCreateRayConfig() { System.setProperty("ray.home", ""); rayConfig = RayConfig.create(); + Assert.assertEquals(System.getProperty("user.dir"), rayConfig.rayHome); Assert.assertEquals(System.getProperty("user.dir") + "/build/src/common/thirdparty/redis/src/redis-server", rayConfig.redisServerExecutablePath); + + Assert.assertEquals("path/to/ray/driver/resource/path", rayConfig.driverResourcePath); + } } From ef1f2fde956fad8df551aa2ff306d73652092a2c Mon Sep 17 00:00:00 2001 From: Wang Qing Date: Tue, 9 Oct 2018 04:12:14 +0800 Subject: [PATCH 018/215] Fix the uniqueId toString format. 
(#3035) --- .../src/main/java/org/ray/api/id/UniqueId.java | 2 +- .../java/org/ray/api/test/UniqueIdTest.java | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/java/api/src/main/java/org/ray/api/id/UniqueId.java b/java/api/src/main/java/org/ray/api/id/UniqueId.java index 0d32d0f8f3c4..f93bdc737229 100644 --- a/java/api/src/main/java/org/ray/api/id/UniqueId.java +++ b/java/api/src/main/java/org/ray/api/id/UniqueId.java @@ -112,6 +112,6 @@ public boolean equals(Object obj) { @Override public String toString() { - return DatatypeConverter.printHexBinary(id); + return DatatypeConverter.printHexBinary(id).toLowerCase(); } } diff --git a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java index 0a21fc2872bf..95107fc11017 100644 --- a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java +++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java @@ -9,14 +9,13 @@ import org.ray.api.id.UniqueId; import org.ray.runtime.util.UniqueIdHelper; -@RunWith(MyRunner.class) public class UniqueIdTest { @Test public void testConstructUniqueId() { // Test `fromHexString()` UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); - Assert.assertEquals("00000000123456789ABCDEF123456789ABCDEF00", id1.toString()); + Assert.assertEquals("00000000123456789abcdef123456789abcdef00", id1.toString()); Assert.assertFalse(id1.isNil()); try { @@ -40,12 +39,12 @@ public void testConstructUniqueId() { ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 20); UniqueId id4 = UniqueId.fromByteBuffer(byteBuffer); Assert.assertTrue(Arrays.equals(bytes, id4.getBytes())); - Assert.assertEquals("0123456789ABCDEF0123456789ABCDEF01234567", id4.toString()); + Assert.assertEquals("0123456789abcdef0123456789abcdef01234567", id4.toString()); // Test `genNil()` UniqueId id6 = UniqueId.genNil(); - Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", id6.toString()); + Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); Assert.assertTrue(id6.isNil()); } @@ -55,10 +54,10 @@ public void testComputeReturnId() { UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); UniqueId returnId = UniqueIdHelper.computeReturnId(taskId, 1); - Assert.assertEquals("01000000123456789ABCDEF123456789ABCDEF00", returnId.toString()); + Assert.assertEquals("01000000123456789abcdef123456789abcdef00", returnId.toString()); returnId = UniqueIdHelper.computeReturnId(taskId, 0x01020304); - Assert.assertEquals("04030201123456789ABCDEF123456789ABCDEF00", returnId.toString()); + Assert.assertEquals("04030201123456789abcdef123456789abcdef00", returnId.toString()); } @Test @@ -66,7 +65,7 @@ public void testComputeTaskId() { UniqueId objId = UniqueId.fromHexString("34421980123456789ABCDEF123456789ABCDEF00"); UniqueId taskId = UniqueIdHelper.computeTaskId(objId); - Assert.assertEquals("00000000123456789ABCDEF123456789ABCDEF00", taskId.toString()); + Assert.assertEquals("00000000123456789abcdef123456789abcdef00", taskId.toString()); } @Test @@ -75,10 +74,10 @@ public void testComputePutId() { UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); UniqueId putId = UniqueIdHelper.computePutId(taskId, 1); - Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00", putId.toString()); + Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); putId = UniqueIdHelper.computePutId(taskId, 
0x01020304); - Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00", putId.toString()); + Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); } } From 060891a9c94b81a346f8ecb347132bcc3656924e Mon Sep 17 00:00:00 2001 From: Hanwei Jin Date: Thu, 11 Oct 2018 05:33:15 +0800 Subject: [PATCH 019/215] [cmake] avoid to re-build pyarrow (#2963) * bugfix: env exists check error * support to avoid re-build pyarrow in project * bugfix: adapt gtest for centos lib64 * bugfix: check gtest lib exists in the directory * bugfix: find gtest with checking all libs exists * prefix RAY_ to thirdparty env variables to avoid conflicts with other module * arrow use glog from ray * change the glog and gtest install dir --- cmake/Modules/ArrowExternalProject.cmake | 10 +-- cmake/Modules/BoostExternalProject.cmake | 4 +- .../Modules/FlatBuffersExternalProject.cmake | 9 +-- cmake/Modules/GlogExternalProject.cmake | 6 +- cmake/Modules/GtestExternalProject.cmake | 37 ++++++---- cmake/Modules/ThirdpartyToolchain.cmake | 72 +++++++++++-------- 6 files changed, 76 insertions(+), 62 deletions(-) diff --git a/cmake/Modules/ArrowExternalProject.cmake b/cmake/Modules/ArrowExternalProject.cmake index 827673c35b28..da3f27ba9626 100644 --- a/cmake/Modules/ArrowExternalProject.cmake +++ b/cmake/Modules/ArrowExternalProject.cmake @@ -23,11 +23,6 @@ set(ARROW_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/arrow-install) set(ARROW_HOME ${ARROW_INSTALL_PREFIX}) set(ARROW_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/arrow/src/arrow_ep) -# The following is needed because in CentOS, the lib directory is named lib64 -if(EXISTS "/etc/redhat-release" AND CMAKE_SIZEOF_VOID_P EQUAL 8) - set(LIB_SUFFIX 64) -endif() - set(ARROW_INCLUDE_DIR ${ARROW_HOME}/include) set(ARROW_LIBRARY_DIR ${ARROW_HOME}/lib${LIB_SUFFIX}) set(ARROW_SHARED_LIB ${ARROW_LIBRARY_DIR}/libarrow${CMAKE_SHARED_LIBRARY_SUFFIX}) @@ -58,7 +53,8 @@ set(ARROW_CMAKE_ARGS -DARROW_WITH_LZ4=off -DARROW_WITH_ZSTD=off -DFLATBUFFERS_HOME=${FLATBUFFERS_HOME} - -DBOOST_ROOT=${BOOST_ROOT}) + -DBOOST_ROOT=${BOOST_ROOT} + -DGLOG_HOME=${GLOG_HOME}) if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") # PyArrow needs following settings. 
@@ -92,7 +88,7 @@ endif() ExternalProject_Add(arrow_ep PREFIX external/arrow - DEPENDS flatbuffers boost + DEPENDS flatbuffers boost glog GIT_REPOSITORY ${arrow_URL} GIT_TAG ${arrow_TAG} ${ARROW_CONFIGURE} diff --git a/cmake/Modules/BoostExternalProject.cmake b/cmake/Modules/BoostExternalProject.cmake index bab016a02b7a..1fbbb0c0b58e 100644 --- a/cmake/Modules/BoostExternalProject.cmake +++ b/cmake/Modules/BoostExternalProject.cmake @@ -9,9 +9,9 @@ # boost is a stable library in ray, and it supports to find # the boost pre-built in environment to speed up build process -if (DEFINED ENV{BOOST_ROOT} AND EXISTS ENV{BOOST_ROOT}) +if (DEFINED ENV{RAY_BOOST_ROOT} AND EXISTS $ENV{RAY_BOOST_ROOT}) set(Boost_USE_STATIC_LIBS ON) - set(BOOST_ROOT "$ENV{BOOST_ROOT}") + set(BOOST_ROOT "$ENV{RAY_BOOST_ROOT}") message(STATUS "Find BOOST_ROOT: ${BOOST_ROOT}") # find_package(Boost COMPONENTS system filesystem REQUIRED) set(Boost_INCLUDE_DIR ${BOOST_ROOT}/include) diff --git a/cmake/Modules/FlatBuffersExternalProject.cmake b/cmake/Modules/FlatBuffersExternalProject.cmake index 57c2216cecfb..508010afced4 100644 --- a/cmake/Modules/FlatBuffersExternalProject.cmake +++ b/cmake/Modules/FlatBuffersExternalProject.cmake @@ -10,13 +10,8 @@ # - FLATBUFFERS_COMPILER # - FBS_DEPENDS, to keep compatible -# The following is needed because in CentOS, the lib directory is named lib64 -if(EXISTS "/etc/redhat-release" AND CMAKE_SIZEOF_VOID_P EQUAL 8) - set(LIB_SUFFIX 64) -endif() - -if(DEFINED ENV{FLATBUFFERS_HOME} AND EXISTS ENV{FLATBUFFERS_HOME}) - set(FLATBUFFERS_HOME "$ENV{FLATBUFFERS_HOME}") +if(DEFINED ENV{RAY_FLATBUFFERS_HOME} AND EXISTS $ENV{RAY_FLATBUFFERS_HOME}) + set(FLATBUFFERS_HOME "$ENV{RAY_FLATBUFFERS_HOME}") set(FLATBUFFERS_INCLUDE_DIR "${FLATBUFFERS_HOME}/include") set(FLATBUFFERS_STATIC_LIB "${FLATBUFFERS_HOME}/lib${LIB_SUFFIX}/libflatbuffers.a") set(FLATBUFFERS_COMPILER "${FLATBUFFERS_HOME}/bin/flatc") diff --git a/cmake/Modules/GlogExternalProject.cmake b/cmake/Modules/GlogExternalProject.cmake index 47f11fbdbd6a..2900bae4d523 100644 --- a/cmake/Modules/GlogExternalProject.cmake +++ b/cmake/Modules/GlogExternalProject.cmake @@ -6,8 +6,8 @@ # - GLOG_INCLUDE_DIR # - GLOG_STATIC_LIB -if(DEFINED ENV{GLOG_HOME} AND EXISTS ENV{GLOG_HOME}) - set(GLOG_HOME "$ENV{GLOG_HOME}") +if(DEFINED ENV{RAY_GLOG_HOME} AND EXISTS $ENV{RAY_GLOG_HOME}) + set(GLOG_HOME "$ENV{RAY_GLOG_HOME}") set(GLOG_INCLUDE_DIR "${GLOG_HOME}/include") set(GLOG_STATIC_LIB "${GLOG_HOME}/lib/libglog.a") @@ -23,7 +23,7 @@ else() endif() set(GLOG_URL "https://github.com/google/glog/archive/v${GLOG_VERSION}.tar.gz") - set(GLOG_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/external/glog/src/glog_ep") + set(GLOG_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/external/glog-install") set(GLOG_HOME "${GLOG_PREFIX}") set(GLOG_INCLUDE_DIR "${GLOG_PREFIX}/include") set(GLOG_STATIC_LIB "${GLOG_PREFIX}/lib/libglog.a") diff --git a/cmake/Modules/GtestExternalProject.cmake b/cmake/Modules/GtestExternalProject.cmake index 5570066c60fb..66e5a76f1d87 100644 --- a/cmake/Modules/GtestExternalProject.cmake +++ b/cmake/Modules/GtestExternalProject.cmake @@ -7,18 +7,31 @@ # - GTEST_MAIN_STATIC_LIB # - GMOCK_MAIN_STATIC_LIB -if(DEFINED ENV{GTEST_HOME} AND EXISTS ENV{GTEST_HOME}) - set(GTEST_HOME "$ENV{GTEST_HOME}") - set(GTEST_INCLUDE_DIR "${GTEST_HOME}/include") - set(GTEST_STATIC_LIB - "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") - set(GTEST_MAIN_STATIC_LIB - 
"${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}") - set(GMOCK_MAIN_STATIC_LIB - "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock_main${CMAKE_STATIC_LIBRARY_SUFFIX}") +set(GTEST_FOUND FALSE) + +if(DEFINED ENV{RAY_GTEST_HOME} AND EXISTS $ENV{RAY_GTEST_HOME}) + set(GTEST_HOME "$ENV{RAY_GTEST_HOME}") + find_path(GTEST_INCLUDE_DIR NAMES gtest/gtest.h + PATHS ${GTEST_HOME} NO_DEFAULT_PATH + PATH_SUFFIXES "include") + find_library(GTEST_LIBRARIES NAMES gtest gtest_main gmock_main + PATHS ${GTEST_HOME} NO_DEFAULT_PATH + PATH_SUFFIXES "lib") + if(GTEST_INCLUDE_DIR AND GTEST_LIBRARIES) + set(GTEST_FOUND TRUE) + set(GTEST_STATIC_LIB + "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(GTEST_MAIN_STATIC_LIB + "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest_main${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(GMOCK_MAIN_STATIC_LIB + "${GTEST_HOME}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gmock_main${CMAKE_STATIC_LIBRARY_SUFFIX}") + + add_custom_target(googletest_ep) + endif() + +endif() - add_custom_target(googletest_ep) -else() +if(NOT GTEST_FOUND) set(GTEST_VERSION "1.8.0") if(APPLE) @@ -31,7 +44,7 @@ else() endif() set(GTEST_CMAKE_CXX_FLAGS "${EP_CXX_FLAGS} ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}} ${GTEST_CMAKE_CXX_FLAGS}") - set(GTEST_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/external/googletest/src/googletest_ep") + set(GTEST_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/external/googletest-install") set(GTEST_INCLUDE_DIR "${GTEST_PREFIX}/include") set(GTEST_STATIC_LIB "${GTEST_PREFIX}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}gtest${CMAKE_STATIC_LIBRARY_SUFFIX}") diff --git a/cmake/Modules/ThirdpartyToolchain.cmake b/cmake/Modules/ThirdpartyToolchain.cmake index 0e0553483ec2..de06b6b7f594 100644 --- a/cmake/Modules/ThirdpartyToolchain.cmake +++ b/cmake/Modules/ThirdpartyToolchain.cmake @@ -4,6 +4,11 @@ # we have to turn it on for dependencies too set(EP_CXX_FLAGS "${EP_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") +# The following is needed because in CentOS, the lib directory is named lib64 +if(EXISTS "/etc/redhat-release" AND CMAKE_SIZEOF_VOID_P EQUAL 8) + set(LIB_SUFFIX 64) +endif() + if(RAY_BUILD_TESTS OR RAY_BUILD_BENCHMARKS) add_custom_target(unittest ctest -L unittest) @@ -25,18 +30,16 @@ if(RAY_BUILD_TESTS OR RAY_BUILD_BENCHMARKS) add_dependencies(gmock_main googletest_ep) endif() -if(RAY_USE_GLOG) - include(GlogExternalProject) - message(STATUS "Glog home: ${GLOG_HOME}") - message(STATUS "Glog include dir: ${GLOG_INCLUDE_DIR}") - message(STATUS "Glog static lib: ${GLOG_STATIC_LIB}") +include(GlogExternalProject) +message(STATUS "Glog home: ${GLOG_HOME}") +message(STATUS "Glog include dir: ${GLOG_INCLUDE_DIR}") +message(STATUS "Glog static lib: ${GLOG_STATIC_LIB}") - include_directories(${GLOG_INCLUDE_DIR}) - ADD_THIRDPARTY_LIB(glog - STATIC_LIB ${GLOG_STATIC_LIB}) +include_directories(${GLOG_INCLUDE_DIR}) +ADD_THIRDPARTY_LIB(glog + STATIC_LIB ${GLOG_STATIC_LIB}) - add_dependencies(glog glog_ep) -endif() +add_dependencies(glog glog_ep) # boost include(BoostExternalProject) @@ -95,19 +98,6 @@ ADD_THIRDPARTY_LIB(plasma STATIC_LIB ${PLASMA_STATIC_LIB}) add_dependencies(plasma plasma_ep) if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - # pyarrow - find_package(PythonInterp REQUIRED) - message(STATUS "PYTHON_EXECUTABLE for pyarrow: ${PYTHON_EXECUTABLE}") - - set(pyarrow_ENV - "PKG_CONFIG_PATH=${ARROW_LIBRARY_DIR}/pkgconfig" - "PYARROW_WITH_PLASMA=1" - "PYARROW_WITH_TENSORFLOW=1" - "PYARROW_BUNDLE_ARROW_CPP=1" - "PARQUET_HOME=${PARQUET_HOME}" 
- "PYARROW_WITH_PARQUET=1" - ) - # clean the arrow_ep/python/build/lib.xxxxx directory, # or when you build with another python version, it creates multiple lib.xxxx directories set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES "${ARROW_SOURCE_DIR}/python/build/") @@ -115,13 +105,33 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") # here we use externalProject to process pyarrow building # add_custom_command would have problem with setup.py - ExternalProject_Add(pyarrow_ext - PREFIX external/pyarrow - DEPENDS arrow_ep - DOWNLOAD_COMMAND "" - BUILD_IN_SOURCE 1 - CONFIGURE_COMMAND cd ${ARROW_SOURCE_DIR}/python && ${CMAKE_COMMAND} -E env ${pyarrow_ENV} ${PYTHON_EXECUTABLE} setup.py build - BUILD_COMMAND cd ${ARROW_SOURCE_DIR}/python && ${CMAKE_COMMAND} -E env ${pyarrow_ENV} ${PYTHON_EXECUTABLE} setup.py build_ext - INSTALL_COMMAND bash -c "cp -rf \$(find ${ARROW_SOURCE_DIR}/python/build/ -maxdepth 1 -type d -print | grep -m1 'lib')/pyarrow ${CMAKE_SOURCE_DIR}/python/ray/pyarrow_files/") + if(EXISTS ${ARROW_SOURCE_DIR}/python/build/) + # if we did not run `make clean`, skip the rebuild of pyarrow + add_custom_target(pyarrow_ext) + else() + # pyarrow + find_package(PythonInterp REQUIRED) + message(STATUS "PYTHON_EXECUTABLE for pyarrow: ${PYTHON_EXECUTABLE}") + + # PYARROW_PARALLEL= , so it will add -j to pyarrow build + set(pyarrow_ENV + "PKG_CONFIG_PATH=${ARROW_LIBRARY_DIR}/pkgconfig" + "PYARROW_WITH_PLASMA=1" + "PYARROW_WITH_TENSORFLOW=1" + "PYARROW_BUNDLE_ARROW_CPP=1" + "PARQUET_HOME=${PARQUET_HOME}" + "PYARROW_WITH_PARQUET=1" + "PYARROW_PARALLEL=") + + ExternalProject_Add(pyarrow_ext + PREFIX external/pyarrow + DEPENDS arrow_ep + DOWNLOAD_COMMAND "" + BUILD_IN_SOURCE 1 + CONFIGURE_COMMAND cd ${ARROW_SOURCE_DIR}/python && ${CMAKE_COMMAND} -E env ${pyarrow_ENV} ${PYTHON_EXECUTABLE} setup.py build + BUILD_COMMAND cd ${ARROW_SOURCE_DIR}/python && ${CMAKE_COMMAND} -E env ${pyarrow_ENV} ${PYTHON_EXECUTABLE} setup.py build_ext + INSTALL_COMMAND bash -c "cp -rf \$(find ${ARROW_SOURCE_DIR}/python/build/ -maxdepth 1 -type d -print | grep -m1 'lib')/pyarrow ${CMAKE_SOURCE_DIR}/python/ray/pyarrow_files/") + + endif() endif () From 4a2ed47b6c889b32792b2c0bdd83f64174b00144 Mon Sep 17 00:00:00 2001 From: Wang Qing Date: Thu, 11 Oct 2018 08:30:23 +0800 Subject: [PATCH 020/215] [Java] Improve some Java code (#3040) This PR improves some java codes, and removes some duplicated code. 
--- .../org/ray/runtime/AbstractRayRuntime.java | 6 +++--- .../runtime/objectstore/ObjectStoreProxy.java | 19 ++++++------------ .../ray/runtime/raylet/RayletClientImpl.java | 20 ++++++------------- ...{UniqueIdHelper.java => UniqueIdUtil.java} | 19 ++++++++++++++---- .../java/org/ray/api/test/UniqueIdTest.java | 13 ++++++------ 5 files changed, 36 insertions(+), 41 deletions(-) rename java/runtime/src/main/java/org/ray/runtime/util/{UniqueIdHelper.java => UniqueIdUtil.java} (81%) diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 407f142d9d16..330dbe365f15 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -22,7 +22,7 @@ import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.ResourceUtil; -import org.ray.runtime.util.UniqueIdHelper; +import org.ray.runtime.util.UniqueIdUtil; import org.ray.runtime.util.exception.TaskExecutionException; import org.ray.runtime.util.logger.RayLog; @@ -63,7 +63,7 @@ public AbstractRayRuntime(RayConfig rayConfig) { @Override public RayObject put(T obj) { - UniqueId objectId = UniqueIdHelper.computePutId( + UniqueId objectId = UniqueIdUtil.computePutId( workerContext.getCurrentTask().taskId, workerContext.nextPutIndex()); put(objectId, obj); @@ -222,7 +222,7 @@ public RayActor createActor(RayFunc actorFactoryFunc, Object[] args) { private UniqueId[] genReturnIds(UniqueId taskId, int numReturns) { UniqueId[] ret = new UniqueId[numReturns]; for (int i = 0; i < numReturns; i++) { - ret[i] = UniqueIdHelper.computeReturnId(taskId, i + 1); + ret[i] = UniqueIdUtil.computeReturnId(taskId, i + 1); } return ret; } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index b497f5c44b14..3a33d862e4b1 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -7,6 +7,7 @@ import org.ray.api.id.UniqueId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.util.Serializer; +import org.ray.runtime.util.UniqueIdUtil; import org.ray.runtime.util.exception.TaskExecutionException; /** @@ -15,9 +16,10 @@ */ public class ObjectStoreProxy { + private static final int GET_TIMEOUT_MS = 1000; + private final AbstractRayRuntime runtime; private final ObjectStoreLink store; - private final int getTimeoutMs = 1000; public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) { this.runtime = runtime; @@ -26,7 +28,7 @@ public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) { public Pair get(UniqueId objectId, boolean isMetadata) throws TaskExecutionException { - return get(objectId, getTimeoutMs, isMetadata); + return get(objectId, GET_TIMEOUT_MS, isMetadata); } public Pair get(UniqueId id, int timeoutMs, boolean isMetadata) @@ -46,12 +48,12 @@ public Pair get(UniqueId id, int timeoutMs, boolean isMetadata public List> get(List objectIds, boolean isMetadata) throws TaskExecutionException { - return get(objectIds, getTimeoutMs, isMetadata); + return get(objectIds, GET_TIMEOUT_MS, isMetadata); } public List> get(List ids, int timeoutMs, boolean isMetadata) throws TaskExecutionException { - List objs = store.get(getIdBytes(ids), timeoutMs, isMetadata); + 
List objs = store.get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata); List> ret = new ArrayList<>(); for (int i = 0; i < objs.size(); i++) { byte[] obj = objs.get(i); @@ -69,15 +71,6 @@ public List> get(List ids, int timeoutMs, boole return ret; } - private static byte[][] getIdBytes(List objectIds) { - int size = objectIds.size(); - byte[][] ids = new byte[size][]; - for (int i = 0; i < size; i++) { - ids[i] = objectIds.get(i).getBytes(); - } - return ids; - } - public void put(UniqueId id, Object obj, Object metadata) { store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata)); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 2152495045f2..b84fe22db0ac 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -18,7 +18,7 @@ import org.ray.runtime.generated.TaskLanguage; import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.UniqueIdHelper; +import org.ray.runtime.util.UniqueIdUtil; import org.ray.runtime.util.logger.RayLog; public class RayletClientImpl implements RayletClient { @@ -50,7 +50,8 @@ public WaitResult wait(List> waitFor, int numReturns, int ti ids.add(element.getId()); } - boolean[] ready = nativeWaitObject(client, getIdBytes(ids), numReturns, timeoutMs, false); + boolean[] ready = nativeWaitObject(client, UniqueIdUtil.getIdBytes(ids), + numReturns, timeoutMs, false); List> readyList = new ArrayList<>(); List> unreadyList = new ArrayList<>(); @@ -89,9 +90,9 @@ public TaskSpec getTask() { public void reconstructObjects(List objectIds, boolean fetchOnly) { if (RayLog.core.isInfoEnabled()) { RayLog.core.info("Reconstructing objects for task {}, object IDs are {}", - UniqueIdHelper.computeTaskId(objectIds.get(0)), objectIds); + UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds); } - nativeReconstructObjects(client, getIdBytes(objectIds), fetchOnly); + nativeReconstructObjects(client, UniqueIdUtil.getIdBytes(objectIds), fetchOnly); } @Override @@ -107,7 +108,7 @@ public void notifyUnblocked() { @Override public void freePlasmaObjects(List objectIds, boolean localOnly) { - byte[][] objectIdsArray = getIdBytes(objectIds); + byte[][] objectIdsArray = UniqueIdUtil.getIdBytes(objectIds); nativeFreePlasmaObjects(client, objectIdsArray, localOnly); } @@ -242,15 +243,6 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { return buffer; } - private static byte[][] getIdBytes(List objectIds) { - int size = objectIds.size(); - byte[][] ids = new byte[size][]; - for (int i = 0; i < size; i++) { - ids[i] = objectIds.get(i).getBytes(); - } - return ids; - } - public void destroy() { nativeDestroy(client); } diff --git a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdHelper.java b/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java similarity index 81% rename from java/runtime/src/main/java/org/ray/runtime/util/UniqueIdHelper.java rename to java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java index 52d9a7359247..d7b347945792 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdHelper.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java @@ -3,6 +3,8 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; +import java.util.List; + import org.ray.api.id.UniqueId; @@ 
-11,7 +13,7 @@ * Note: any changes to these methods must be synced with C++ helper functions * in src/ray/id.h */ -public class UniqueIdHelper { +public class UniqueIdUtil { public static final int OBJECT_INDEX_POS = 0; public static final int OBJECT_INDEX_LENGTH = 4; @@ -37,7 +39,7 @@ private static UniqueId computeObjectId(UniqueId taskId, int index) { System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueId.LENGTH); ByteBuffer wbb = ByteBuffer.wrap(objId); wbb.order(ByteOrder.LITTLE_ENDIAN); - wbb.putInt(UniqueIdHelper.OBJECT_INDEX_POS, index); + wbb.putInt(UniqueIdUtil.OBJECT_INDEX_POS, index); return new UniqueId(objId); } @@ -63,9 +65,18 @@ public static UniqueId computePutId(UniqueId taskId, int putIndex) { public static UniqueId computeTaskId(UniqueId objectId) { byte[] taskId = new byte[UniqueId.LENGTH]; System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueId.LENGTH); - Arrays.fill(taskId, UniqueIdHelper.OBJECT_INDEX_POS, - UniqueIdHelper.OBJECT_INDEX_POS + UniqueIdHelper.OBJECT_INDEX_LENGTH, (byte) 0); + Arrays.fill(taskId, UniqueIdUtil.OBJECT_INDEX_POS, + UniqueIdUtil.OBJECT_INDEX_POS + UniqueIdUtil.OBJECT_INDEX_LENGTH, (byte) 0); return new UniqueId(taskId); } + + public static byte[][] getIdBytes(List objectIds) { + int size = objectIds.size(); + byte[][] ids = new byte[size][]; + for (int i = 0; i < size; i++) { + ids[i] = objectIds.get(i).getBytes(); + } + return ids; + } } diff --git a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java index 95107fc11017..2fd47057d90d 100644 --- a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java +++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java @@ -5,9 +5,8 @@ import javax.xml.bind.DatatypeConverter; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; import org.ray.api.id.UniqueId; -import org.ray.runtime.util.UniqueIdHelper; +import org.ray.runtime.util.UniqueIdUtil; public class UniqueIdTest { @@ -53,17 +52,17 @@ public void testComputeReturnId() { // Mock a taskId, and the lowest 4 bytes should be 0. UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); - UniqueId returnId = UniqueIdHelper.computeReturnId(taskId, 1); + UniqueId returnId = UniqueIdUtil.computeReturnId(taskId, 1); Assert.assertEquals("01000000123456789abcdef123456789abcdef00", returnId.toString()); - returnId = UniqueIdHelper.computeReturnId(taskId, 0x01020304); + returnId = UniqueIdUtil.computeReturnId(taskId, 0x01020304); Assert.assertEquals("04030201123456789abcdef123456789abcdef00", returnId.toString()); } @Test public void testComputeTaskId() { UniqueId objId = UniqueId.fromHexString("34421980123456789ABCDEF123456789ABCDEF00"); - UniqueId taskId = UniqueIdHelper.computeTaskId(objId); + UniqueId taskId = UniqueIdUtil.computeTaskId(objId); Assert.assertEquals("00000000123456789abcdef123456789abcdef00", taskId.toString()); } @@ -73,10 +72,10 @@ public void testComputePutId() { // Mock a taskId, the lowest 4 bytes should be 0. 
UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); - UniqueId putId = UniqueIdHelper.computePutId(taskId, 1); + UniqueId putId = UniqueIdUtil.computePutId(taskId, 1); Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); - putId = UniqueIdHelper.computePutId(taskId, 0x01020304); + putId = UniqueIdUtil.computePutId(taskId, 0x01020304); Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); } From 828fe24b393a32e62ed958fba6d9c79b5510d0d8 Mon Sep 17 00:00:00 2001 From: Wang Qing Date: Fri, 12 Oct 2018 00:45:21 +0800 Subject: [PATCH 021/215] [Java] Fix loading driver resources issue. (#3046) ## What do these changes do? Fix how we load driver resources from a specified path. This also addresses the comments from the related PR [3044](https://github.com/ray-project/ray/pull/3044). ## Related PRs: [#3044](https://github.com/ray-project/ray/pull/3044) and [#3001](https://github.com/ray-project/ray/pull/3001). --- .../org/ray/runtime/config/RayConfig.java | 9 ++------- .../functionmanager/FunctionManager.java | 6 ++---- .../src/main/resources/ray.default.conf | 19 ++++++++++--------- .../functionmanager/FunctionManagerTest.java | 14 ++++---------- 4 files changed, 18 insertions(+), 30 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index c77d62628f8c..d374d25a577f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -158,17 +158,12 @@ public RayConfig(Config config) { rayletExecutablePath = rayHome + "/build/src/ray/raylet/raylet"; // driver resource path - String localDriverResourcePath; if (config.hasPath("ray.driver.resource-path")) { - localDriverResourcePath = config.getString("ray.driver.resource-path"); + driverResourcePath = config.getString("ray.driver.resource-path"); } else { - localDriverResourcePath = rayHome + "/driver/resource"; - LOGGER.warn("Didn't configure ray.driver.resource-path, set it to default value: {}", - localDriverResourcePath); + driverResourcePath = null; } - driverResourcePath = localDriverResourcePath; - // validate config validate(); LOGGER.debug("Created config: {}", this); } diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java index cf92b0c21472..d7698c22aa7f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java @@ -89,13 +89,11 @@ public RayFunction getFunction(UniqueId driverId, FunctionDescriptor functionDes String resourcePath = driverResourcePath + "/" + driverId.toString() + "/"; ClassLoader classLoader; - try { + if (driverResourcePath != null && !driverResourcePath.isEmpty()) { classLoader = JarLoader.loadJars(resourcePath, false); LOGGER.info("Succeeded to load driver({}) resource. Resource path is {}", driverId, resourcePath); - } catch (Exception e) { - LOGGER.error("Failed to load driver({}) resource.
Resource path is {}", - driverId, resourcePath); + } else { classLoader = getClass().getClassLoader(); } diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index 58a3be2de3d6..892d90c6cc96 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -25,15 +25,16 @@ ray { // Available resources on this node, for example "CPU:4,GPU:0". resources: "" - // Configuration items about driver. - driver { - // If worker.mode is DRIVER, specify the driver id. - // If not provided, a random id will be used. - id: "" - // If worker.mode is WORKER, it means that worker will load - // the resources from this path to execute tasks. - resource-path: /tmp/ray/driver/resource - } + // Configuration items about driver. + driver { + // If worker.mode is DRIVER, specify the driver id. + // If not provided, a random id will be used. + id: "" + // If this config is set, worker will use different paths to loadresources when + // executing tasks from different drivers. E.g. if it's set to '/tm/driver_resources', + // the path for driver 123 will be '/tmp/driver_resources/123'. + resource-path: "" + } // Root dir of log files. log-dir: /tmp/ray/logs diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java index c82ae27af3b2..f5ff1e481a36 100644 --- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java @@ -45,10 +45,6 @@ public Object bar() { private static FunctionDescriptor barDescriptor; private static FunctionDescriptor barConstructorDescriptor; - private final static String resourcePath = "/tmp/ray/test/resource"; - - private FunctionManager functionManager; - @BeforeClass public static void beforeClass() { fooFunc = FunctionManagerTest::foo; @@ -63,13 +59,9 @@ public static void beforeClass() { "()V"); } - @Before - public void before() { - functionManager = new FunctionManager(FunctionManagerTest.resourcePath); - } - @Test public void testGetFunctionFromRayFunc() { + final FunctionManager functionManager = new FunctionManager(null); // Test normal function. RayFunction func = functionManager.getFunction(UniqueId.NIL, fooFunc); Assert.assertFalse(func.isConstructor()); @@ -91,6 +83,7 @@ public void testGetFunctionFromRayFunc() { @Test public void testGetFunctionFromFunctionDescriptor() { + final FunctionManager functionManager = new FunctionManager(null); // Test normal function. RayFunction func = functionManager.getFunction(UniqueId.NIL, fooDescriptor); Assert.assertFalse(func.isConstructor()); @@ -129,6 +122,7 @@ public void testGetFunctionFromLocalResource() throws Exception{ UniqueId driverId = UniqueId.fromHexString("0123456789012345678901234567890123456789"); //TODO(qwang): We should use a independent app demo instead of `tutorial`. 
+ final String resourcePath = "/tmp/ray/test/resource"; final String srcJarPath = System.getProperty("user.dir") + "/../tutorial/target/ray-tutorial-0.1-SNAPSHOT.jar"; final String destJarPath = resourcePath + "/" + driverId.toString() + @@ -136,9 +130,9 @@ public void testGetFunctionFromLocalResource() throws Exception{ File file = new File(resourcePath + "/" + driverId.toString()); file.mkdirs(); - Files.copy(Paths.get(srcJarPath), Paths.get(destJarPath), StandardCopyOption.REPLACE_EXISTING); + final FunctionManager functionManager = new FunctionManager(resourcePath); FunctionDescriptor sayHelloDescriptor = new FunctionDescriptor("org.ray.exercise.Exercise02", "sayHello", "()Ljava/lang/String;"); RayFunction func = functionManager.getFunction(driverId, sayHelloDescriptor); From f9b58d7b0252f312e7b0983ffefabf68d3e7a939 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 11 Oct 2018 23:42:13 -0700 Subject: [PATCH 022/215] [tune] Tweaks to Trainable and Verbosity (#2889) --- doc/source/tune-searchalg.rst | 6 +- doc/source/tune-usage.rst | 6 +- docker/examples/Dockerfile | 1 + python/ray/rllib/agents/agent.py | 6 +- .../tune/examples/async_hyperband_example.py | 2 +- python/ray/tune/examples/hyperband_example.py | 2 +- .../tune/examples/mnist_pytorch_trainable.py | 6 +- python/ray/tune/examples/pbt_example.py | 2 +- .../examples/pbt_tune_cifar10_with_keras.py | 2 +- .../tune/examples/tune_mnist_ray_hyperband.py | 6 +- python/ray/tune/function_runner.py | 4 +- python/ray/tune/result.py | 3 +- python/ray/tune/schedulers/hyperband.py | 5 +- python/ray/tune/suggest/hyperopt.py | 5 +- python/ray/tune/test/trial_runner_test.py | 65 +++++++++++++++++ python/ray/tune/trainable.py | 72 +++++++++++++------ python/ray/tune/trial.py | 14 ++-- 17 files changed, 160 insertions(+), 47 deletions(-) diff --git a/doc/source/tune-searchalg.rst b/doc/source/tune-searchalg.rst index 97e8ce1bc295..e8e5b0fa672e 100644 --- a/doc/source/tune-searchalg.rst +++ b/doc/source/tune-searchalg.rst @@ -25,10 +25,13 @@ By default, Tune uses the `default search space and variant generation process < :noindex: +Note that other search algorithms will not necessarily extend this class and may require a different search space declaration than the default Tune format. + HyperOpt Search (Tree-structured Parzen Estimators) --------------------------------------------------- -The ``HyperOptSearch`` is a SearchAlgorithm that is backed by `HyperOpt `__ to perform sequential model-based hyperparameter optimization. +The ``HyperOptSearch`` is a SearchAlgorithm that is backed by `HyperOpt `__ to perform sequential model-based hyperparameter optimization. Note that this class does not extend ``ray.tune.suggest.BasicVariantGenerator``, so you will not be able to use Tune's default variant generation/search space declaration when using HyperOptSearch. + In order to use this search algorithm, you will need to install HyperOpt via the following command: .. code-block:: bash @@ -47,7 +50,6 @@ An example of this can be found in `hyperopt_example.py `__. 
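To make that distinction concrete, here is a minimal sketch (added for illustration, not part of this patch) of pairing ``HyperOptSearch`` with a HyperOpt-format space. The trainable name ``my_trainable``, the space contents, and the experiment settings are placeholders; the ``HyperOptSearch`` arguments mirror the docstring updated later in this patch:

.. code-block:: python

    from hyperopt import hp
    from ray.tune import run_experiments
    from ray.tune.suggest.hyperopt import HyperOptSearch

    # The space is declared in HyperOpt's format, not Tune's default format.
    space = {
        "lr": hp.loguniform("lr", -10, -1),
        "momentum": hp.uniform("momentum", 0.1, 0.9),
    }
    algo = HyperOptSearch(space, max_concurrent=4, reward_attr="neg_mean_loss")

    # "my_trainable" is assumed to be a registered trainable.
    run_experiments(
        {"hyperopt_exp": {"run": "my_trainable", "num_samples": 10}},
        search_alg=algo)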
Sampling Multiple Times diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile index d4e6c34b2217..80685b7d3154 100644 --- a/docker/examples/Dockerfile +++ b/docker/examples/Dockerfile @@ -6,5 +6,6 @@ FROM ray-project/deploy RUN conda install -y numpy RUN apt-get install -y zlib1g-dev RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras +RUN pip install -U h5py # Mutes FutureWarnings RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN conda install pytorch-cpu torchvision-cpu -c pytorch diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 030ae64248d8..8041eba06b88 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -234,10 +234,10 @@ def train(self): return Trainable.train(self) - def _setup(self): + def _setup(self, config): env = self._env_id if env: - self.config["env"] = env + config["env"] = env if _global_registry.contains(ENV_CREATOR, env): self.env_creator = _global_registry.get(ENV_CREATOR, env) else: @@ -248,7 +248,7 @@ def _setup(self): # Merge the supplied config with the class default merged_config = self._default_config.copy() - merged_config = deep_update(merged_config, self.config, + merged_config = deep_update(merged_config, config, self._allow_unknown_configs, self._allow_unknown_subkeys) self.config = merged_config diff --git a/python/ray/tune/examples/async_hyperband_example.py b/python/ray/tune/examples/async_hyperband_example.py index 2c368b4e3d05..e07f11b325a8 100644 --- a/python/ray/tune/examples/async_hyperband_example.py +++ b/python/ray/tune/examples/async_hyperband_example.py @@ -23,7 +23,7 @@ class MyTrainableClass(Trainable): maximum reward value reached. """ - def _setup(self): + def _setup(self, config): self.timestep = 0 def _train(self): diff --git a/python/ray/tune/examples/hyperband_example.py b/python/ray/tune/examples/hyperband_example.py index 94f603e8206c..baf133b411bf 100755 --- a/python/ray/tune/examples/hyperband_example.py +++ b/python/ray/tune/examples/hyperband_example.py @@ -23,7 +23,7 @@ class MyTrainableClass(Trainable): maximum reward value reached. 
""" - def _setup(self): + def _setup(self, config): self.timestep = 0 def _train(self): diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index 0d23c0cc2130..2c0c68bceb8b 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -80,9 +80,9 @@ def forward(self, x): class TrainMNIST(Trainable): - def _setup(self): - args = self.config.pop("args") - vars(args).update(self.config) + def _setup(self, config): + args = config.pop("args") + vars(args).update(config) args.cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) diff --git a/python/ray/tune/examples/pbt_example.py b/python/ray/tune/examples/pbt_example.py index c958d2512e83..3433e82f94ee 100755 --- a/python/ray/tune/examples/pbt_example.py +++ b/python/ray/tune/examples/pbt_example.py @@ -18,7 +18,7 @@ class MyTrainableClass(Trainable): """Fake agent whose learning rate is determined by dummy factors.""" - def _setup(self): + def _setup(self, config): self.timestep = 0 self.current_value = 0.0 diff --git a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py index 28575f546682..63e3d00e8d1f 100755 --- a/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py +++ b/python/ray/tune/examples/pbt_tune_cifar10_with_keras.py @@ -105,7 +105,7 @@ def _build_model(self, input_shape): model = Model(inputs=x, outputs=y, name="model1") return model - def _setup(self): + def _setup(self, config): self.train_data, self.test_data = self._read_data() x_train = self.train_data[0] model = self._build_model(x_train.shape[1:]) diff --git a/python/ray/tune/examples/tune_mnist_ray_hyperband.py b/python/ray/tune/examples/tune_mnist_ray_hyperband.py index 29939ff24308..9dbc46775232 100755 --- a/python/ray/tune/examples/tune_mnist_ray_hyperband.py +++ b/python/ray/tune/examples/tune_mnist_ray_hyperband.py @@ -128,7 +128,7 @@ def bias_variable(shape): class TrainMNIST(Trainable): """Example MNIST trainable.""" - def _setup(self): + def _setup(self, config): global activation_fn self.timestep = 0 @@ -148,7 +148,7 @@ def _setup(self): self.x = tf.placeholder(tf.float32, [None, 784]) self.y_ = tf.placeholder(tf.float32, [None, 10]) - activation_fn = getattr(tf.nn, self.config['activation']) + activation_fn = getattr(tf.nn, config['activation']) # Build the graph for the deep net y_conv, self.keep_prob = setupCNN(self.x) @@ -160,7 +160,7 @@ def _setup(self): with tf.name_scope('adam_optimizer'): train_step = tf.train.AdamOptimizer( - self.config['learning_rate']).minimize(cross_entropy) + config['learning_rate']).minimize(cross_entropy) self.train_step = train_step diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index d1704b6aa94f..1b93d3b6c300 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -90,10 +90,10 @@ class FunctionRunner(Trainable): _name = "func" _default_config = DEFAULT_CONFIG - def _setup(self): + def _setup(self, config): entrypoint = self._trainable_func() self._status_reporter = StatusReporter() - scrubbed_config = self.config.copy() + scrubbed_config = config.copy() for k in self._default_config: if k in scrubbed_config: del scrubbed_config[k] diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index ec307eaed8fb..5b7ade11fe0e 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -39,7 +39,8 @@ 
TRAINING_ITERATION = "training_iteration" # Where Tune writes result files by default -DEFAULT_RESULTS_DIR = os.path.expanduser("~/ray_results") +DEFAULT_RESULTS_DIR = (os.environ.get("TUNE_RESULT_DIR") + or os.path.expanduser("~/ray_results")) # Meta file about status under each experiment directory, can be # parsed by automlboard if exists. diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index 7e2f8f27e278..71c69b3063a2 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -50,7 +50,10 @@ class HyperBandScheduler(FIFOScheduler): For example, to limit trials to 10 minutes and early stop based on the `episode_mean_reward` attr, construct: - ``HyperBand('time_total_s', 'episode_reward_mean', 600)`` + ``HyperBand('time_total_s', 'episode_reward_mean', max_t=600)`` + + Note that Tune's stopping criteria will be applied in conjunction with + HyperBand's early stopping mechanisms. See also: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py index 45fe9753e0ea..9173b56cc372 100644 --- a/python/ray/tune/suggest/hyperopt.py +++ b/python/ray/tune/suggest/hyperopt.py @@ -4,7 +4,11 @@ import numpy as np import copy +import logging + try: + hyperopt_logger = logging.getLogger("hyperopt") + hyperopt_logger.setLevel(logging.WARNING) import hyperopt as hpo except Exception as e: hpo = None @@ -47,7 +51,6 @@ class HyperOptSearch(SuggestionAlgorithm): >>> } >>> algo = HyperOptSearch( >>> space, max_concurrent=4, reward_attr="neg_mean_loss") - >>> algo.add_configurations(config) """ def __init__(self, diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 1e4c0509dc15..65b8fbe36f62 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -433,6 +433,71 @@ def train3(config, reporter): self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 5) self.assertEqual(trial3.last_result["timesteps_this_iter"], 0) + def testCheckpointDict(self): + class TestTrain(Trainable): + def _setup(self, config): + self.state = {"hi": 1} + + def _train(self): + return dict(timesteps_this_iter=1, done=True) + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + test_trainable = TestTrain() + result = test_trainable.save() + test_trainable.state["hi"] = 2 + test_trainable.restore(result) + self.assertEqual(test_trainable.state["hi"], 1) + + trials = run_experiments({ + "foo": { + "run": TestTrain, + "checkpoint_at_end": True + } + }) + for trial in trials: + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertTrue(trial.has_checkpoint()) + + def testMultipleCheckpoints(self): + class TestTrain(Trainable): + def _setup(self, config): + self.state = {"hi": 1, "iter": 0} + + def _train(self): + self.state["iter"] += 1 + return dict(timesteps_this_iter=1, done=True) + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + test_trainable = TestTrain() + checkpoint_1 = test_trainable.save() + test_trainable.train() + checkpoint_2 = test_trainable.save() + self.assertNotEqual(checkpoint_1, checkpoint_2) + test_trainable.restore(checkpoint_2) + self.assertEqual(test_trainable.state["iter"], 1) + test_trainable.restore(checkpoint_1) + self.assertEqual(test_trainable.state["iter"], 0) + + trials = run_experiments({ + "foo": { + "run": TestTrain, + 
"checkpoint_at_end": True + } + }) + for trial in trials: + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertTrue(trial.has_checkpoint()) + class RunExperimentTest(unittest.TestCase): def setUp(self): diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 1e537d26d953..6c8b02cf0afa 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -4,6 +4,7 @@ from datetime import datetime +import copy import gzip import io import logging @@ -83,7 +84,7 @@ def __init__(self, config=None, logger_creator=None): self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = False - self._setup() + self._setup(copy.deepcopy(self.config)) self._local_ip = ray.services.get_node_ip_address() @classmethod @@ -143,6 +144,8 @@ def train(self): start = time.time() result = self._train() + assert isinstance(result, dict), "_train() needs to return a dict." + result = result.copy() self._iteration += 1 @@ -211,11 +214,27 @@ def save(self, checkpoint_dir=None): Checkpoint path that may be passed to restore(). """ - checkpoint_path = self._save(checkpoint_dir or self.logdir) - pickle.dump([ - self._experiment_id, self._iteration, self._timesteps_total, - self._time_total, self._episodes_total - ], open(checkpoint_path + ".tune_metadata", "wb")) + checkpoint_path = tempfile.mkdtemp( + prefix="checkpoint_{}".format(self._iteration), + dir=checkpoint_dir or self.logdir) + checkpoint = self._save(checkpoint_path) + saved_as_dict = False + if isinstance(checkpoint, str): + checkpoint_path = checkpoint + elif isinstance(checkpoint, dict): + saved_as_dict = True + pickle.dump(checkpoint, open(checkpoint_path + ".tune_state", + "wb")) + else: + raise ValueError("Return value from `_save` must be dict or str.") + pickle.dump({ + "experiment_id": self._experiment_id, + "iteration": self._iteration, + "timesteps_total": self._timesteps_total, + "time_total": self._time_total, + "episodes_total": self._episodes_total, + "saved_as_dict": saved_as_dict + }, open(checkpoint_path + ".tune_metadata", "wb")) return checkpoint_path def save_to_object(self): @@ -259,13 +278,19 @@ def restore(self, checkpoint_path): This method restores additional metadata saved with the checkpoint. """ - self._restore(checkpoint_path) metadata = pickle.load(open(checkpoint_path + ".tune_metadata", "rb")) - self._experiment_id = metadata[0] - self._iteration = metadata[1] - self._timesteps_total = metadata[2] - self._time_total = metadata[3] - self._episodes_total = metadata[4] + self._experiment_id = metadata["experiment_id"] + self._iteration = metadata["iteration"] + self._timesteps_total = metadata["timesteps_total"] + self._time_total = metadata["time_total"] + self._episodes_total = metadata["episodes_total"] + saved_as_dict = metadata["saved_as_dict"] + if saved_as_dict: + with open(checkpoint_path + ".tune_state", "rb") as loaded_state: + checkpoint_dict = pickle.load(loaded_state) + self._restore(checkpoint_dict) + else: + self._restore(checkpoint_path) self._restored = True def restore_from_object(self, obj): @@ -321,27 +346,34 @@ def _save(self, checkpoint_dir): can be stored. Returns: - Checkpoint path that may be passed to restore(). Typically - would default to `checkpoint_dir`. + checkpoint (str | dict): If string, the return value is + expected to be the checkpoint path that will be passed to + `_restore()`. If dict, the return value will be automatically + serialized by Tune and passed to `_restore()`. 
+ + Examples: + >>> checkpoint_data = trainable._save(checkpoint_dir) + >>> trainable2._restore(checkpoint_data) """ raise NotImplementedError - def _restore(self, checkpoint_path): + def _restore(self, checkpoint): """Subclasses should override this to implement restore(). Args: - checkpoint_path (str): The directory where the checkpoint - is stored. + checkpoint (str | dict): Value as returned by `_save`. + If a string, then it is the checkpoint path. """ raise NotImplementedError - def _setup(self): + def _setup(self, config): """Subclasses should override this for custom initialization. - Subclasses can access the hyperparameter configuration via - ``self.config``. + Args: + config (dict): Hyperparameters and other configs given. + Copy of `self.config`. """ pass diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 98fcbc6d55e6..59559ebbe2c2 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -8,6 +8,7 @@ import time import tempfile import os +from numbers import Number import ray from ray.tune import TuneError @@ -33,12 +34,14 @@ class Resources( namedtuple("Resources", ["cpu", "gpu", "extra_cpu", "extra_gpu"])): """Ray resources required to schedule a trial. + TODO: Custom resources. + Attributes: - cpu (int): Number of CPUs to allocate to the trial. - gpu (int): Number of GPUs to allocate to the trial. - extra_cpu (int): Extra CPUs to reserve in case the trial needs to + cpu (float): Number of CPUs to allocate to the trial. + gpu (float): Number of GPUs to allocate to the trial. + extra_cpu (float): Extra CPUs to reserve in case the trial needs to launch additional Ray actors that use CPUs. - extra_gpu (int): Extra GPUs to reserve in case the trial needs to + extra_gpu (float): Extra GPUs to reserve in case the trial needs to launch additional Ray actors that use GPUs. """ @@ -46,6 +49,9 @@ class Resources( __slots__ = () def __new__(cls, cpu, gpu, extra_cpu=0, extra_gpu=0): + for entry in [cpu, gpu, extra_cpu, extra_gpu]: + assert isinstance(entry, Number), "Improper resource value." + assert entry >= 0, "Resource cannot be negative." 
return super(Resources, cls).__new__(cls, cpu, gpu, extra_cpu, extra_gpu) From 87639b9e260d2831bdbd0fda2642c3cfea3ccca2 Mon Sep 17 00:00:00 2001 From: Hanwei Jin Date: Sat, 13 Oct 2018 00:03:30 +0800 Subject: [PATCH 023/215] move make clean before cmake command, avoid always running mvn install plasma java lib (#3047) --- build.sh | 5 ++++- cmake/Modules/ArrowExternalProject.cmake | 13 +++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/build.sh b/build.sh index 496bbdddb575..e3e81c72a476 100755 --- a/build.sh +++ b/build.sh @@ -101,12 +101,15 @@ fi pushd "$BUILD_DIR" +# Tolerate failure here so the script does not exit when there is nothing to clean; +# cmake will check some directories to determine whether some targets were already built. +make clean || true + cmake -DCMAKE_BUILD_TYPE=$CBUILD_TYPE \ -DCMAKE_RAY_LANG_JAVA=$RAY_BUILD_JAVA \ -DCMAKE_RAY_LANG_PYTHON=$RAY_BUILD_PYTHON \ -DRAY_USE_NEW_GCS=$RAY_USE_NEW_GCS \ -DPYTHON_EXECUTABLE:FILEPATH=$PYTHON_EXECUTABLE $ROOT_DIR -make clean make -j${PARALLEL} popd diff --git a/cmake/Modules/ArrowExternalProject.cmake b/cmake/Modules/ArrowExternalProject.cmake index da3f27ba9626..07c48d97f0cc 100644 --- a/cmake/Modules/ArrowExternalProject.cmake +++ b/cmake/Modules/ArrowExternalProject.cmake @@ -9,6 +9,7 @@ # - ARROW_INCLUDE_DIR # - ARROW_SHARED_LIB # - ARROW_STATIC_LIB +# - ARROW_LIBRARY_DIR # - PLASMA_INCLUDE_DIR # - PLASMA_STATIC_LIB # - PLASMA_SHARED_LIB @@ -95,12 +96,16 @@ ExternalProject_Add(arrow_ep BUILD_BYPRODUCTS "${ARROW_SHARED_LIB}" "${ARROW_STATIC_LIB}") if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") - ExternalProject_Add_Step(arrow_ep arrow_ep_install_java_lib - COMMAND bash -c "cd ${ARROW_SOURCE_DIR}/java && mvn clean install -pl plasma -am -Dmaven.test.skip > /dev/null" - DEPENDEES build) + set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES "${ARROW_SOURCE_DIR}/java/target/") + + if(NOT EXISTS ${ARROW_SOURCE_DIR}/java/target/) + ExternalProject_Add_Step(arrow_ep arrow_ep_install_java_lib + COMMAND bash -c "cd ${ARROW_SOURCE_DIR}/java && mvn clean install -pl plasma -am -Dmaven.test.skip > /dev/null" + DEPENDEES build) + endif() # add install of library plasma_java, it is not configured in plasma CMakeLists.txt ExternalProject_Add_Step(arrow_ep arrow_ep_install_plasma_java - COMMAND bash -c "cp ${CMAKE_CURRENT_BINARY_DIR}/external/arrow/src/arrow_ep-build/release/libplasma_java.* ${ARROW_LIBRARY_DIR}/" + COMMAND bash -c "cp -rf ${CMAKE_CURRENT_BINARY_DIR}/external/arrow/src/arrow_ep-build/release/libplasma_java.* ${ARROW_LIBRARY_DIR}/" DEPENDEES install) endif () From 473ee4eb3fd41b555c6682e5dd946adfe4c6df15 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 13 Oct 2018 00:03:52 -0700 Subject: [PATCH 024/215] [rllib] Add unit test and some better error messages for custom policy states (#3032) --- python/ray/rllib/evaluation/sampler.py | 4 +++ .../ray/rllib/evaluation/tf_policy_graph.py | 17 ++++++++---- python/ray/rllib/test/test_multi_agent_env.py | 26 +++++++++++++++++++ 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index ec0f13c4e445..78043b3d4a4b 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -364,6 +364,10 @@ def new_episode(): # Record the policy eval results for policy_id, eval_data in to_eval.items(): actions, rnn_out_cols, pi_info_cols = eval_results[policy_id] + if len(rnn_in_cols[policy_id]) != len(rnn_out_cols): + raise ValueError( + "Length of RNN in did not match RNN out, got: " +
"{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols)) # Add RNN state info for f_i, column in enumerate(rnn_in_cols[policy_id]): pi_info_cols["state_in_{}".format(f_i)] = column diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index e9119c87527b..09a84981ee83 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -95,11 +95,18 @@ def __init__(self, self._variables = ray.experimental.TensorFlowVariables( self._loss, self._sess) - assert len(self._state_inputs) == len(self._state_outputs) == \ - len(self.get_initial_state()), \ - (self._state_inputs, self._state_outputs, self.get_initial_state()) - if self._state_inputs: - assert self._seq_lens is not None + if len(self._state_inputs) != len(self._state_outputs): + raise ValueError( + "Number of state input and output tensors must match, got: " + "{} vs {}".format(self._state_inputs, self._state_outputs)) + if len(self.get_initial_state()) != len(self._state_inputs): + raise ValueError( + "Length of initial state must match number of state inputs, " + "got: {} vs {}".format(self.get_initial_state(), + self._state_inputs)) + if self._state_inputs and self._seq_lens is None: + raise ValueError( + "seq_lens tensor must be given if state inputs are defined") def build_compute_actions(self, builder, diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index 96eaabaf1dff..493b338cf8e7 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -15,6 +15,7 @@ from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \ MockPolicyGraph from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -306,6 +307,31 @@ def testMultiAgentSampleRoundRobin(self): self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10], [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]) + def testCustomRNNStateValues(self): + h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}} + + class StatefulPolicyGraph(PolicyGraph): + def compute_actions(self, + obs_batch, + state_batches, + is_training=False, + episodes=None): + return [0] * len(obs_batch), [[h] * len(obs_batch)], {} + + def get_initial_state(self): + return [{}] # empty dict + + ev = PolicyEvaluator( + env_creator=lambda _: gym.make("CartPole-v0"), + policy_graph=StatefulPolicyGraph, + batch_steps=5) + batch = ev.sample() + self.assertEqual(batch.count, 5) + self.assertEqual(batch["state_in_0"][0], {}) + self.assertEqual(batch["state_out_0"][0], h) + self.assertEqual(batch["state_in_0"][1], h) + self.assertEqual(batch["state_out_0"][1], h) + def testReturningModelBasedRolloutsData(self): class ModelBasedPolicyGraph(PGPolicyGraph): def compute_actions(self, From 866c7a574c632ff0f2f63cde0a6343bf89cbeea1 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 13 Oct 2018 19:50:23 -0700 Subject: [PATCH 025/215] [rllib] Don't crash printing out error message (#3054) * fix er * update --- python/ray/rllib/agents/agent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 8041eba06b88..28b423417063 100644 --- a/python/ray/rllib/agents/agent.py +++ 
b/python/ray/rllib/agents/agent.py @@ -3,7 +3,6 @@ from __future__ import print_function import copy -import json import os import pickle import tempfile @@ -175,7 +174,7 @@ def resource_help(cls, config): return ("\n\nYou can adjust the resource requests of RLlib agents by " "setting `num_workers` and other configs. See the " "DEFAULT_CONFIG defined by each agent for more info.\n\n" - "The config of this agent is: " + json.dumps(config)) + "The config of this agent is: {}".format(config)) def __init__(self, config=None, env=None, logger_creator=None): """Initialize an RLLib agent. From 4dc78b735bf40df91825febbe3a8ea9fc535ab4d Mon Sep 17 00:00:00 2001 From: Marlon Date: Mon, 15 Oct 2018 07:25:39 +0200 Subject: [PATCH 026/215] [tune] Fix misleading comment (#3058) --- python/ray/tune/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 5b7ade11fe0e..c75ba77e101b 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -25,7 +25,7 @@ # Number of timesteps in this iteration. TIMESTEPS_THIS_ITER = "timesteps_this_iter" -# (Optional/Auto-filled) Accumulated time in seconds for this experiment. +# (Auto-filled) Accumulated number of timesteps for this entire experiment. TIMESTEPS_TOTAL = "timesteps_total" # (Auto-filled) Time in seconds this iteration took to run. From 3c891c6ecec973a305e753fd8dcf8d374291e828 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Oct 2018 11:02:50 -0700 Subject: [PATCH 027/215] [rllib] Parallel-data loading and multi-gpu support for IMPALA (#2766) --- doc/source/rllib-algorithms.rst | 2 +- python/ray/rllib/agents/impala/impala.py | 24 +- .../agents/impala/vtrace_policy_graph.py | 81 +++++-- .../ray/rllib/agents/ppo/ppo_policy_graph.py | 35 ++- python/ray/rllib/examples/cartpole_lstm.py | 20 +- .../optimizers/async_replay_optimizer.py | 10 +- .../optimizers/async_samples_optimizer.py | 229 ++++++++++++++++-- python/ray/rllib/optimizers/multi_gpu_impl.py | 87 ++++--- .../rllib/optimizers/multi_gpu_optimizer.py | 4 +- python/ray/rllib/optimizers/replay_buffer.py | 12 +- .../ray/rllib/test/test_supported_spaces.py | 2 +- test/jenkins_tests/run_multi_node_tests.sh | 23 +- 12 files changed, 417 insertions(+), 112 deletions(-) diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index d764fc7ad8ea..be39b61bb500 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -43,7 +43,7 @@ Importance Weighted Actor-Learner Architecture (IMPALA) `[paper] `__ `[implementation] `__ -In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code `__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model `__. +In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib's IMPALA implementation uses DeepMind's reference `V-trace code `__. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a `custom model `__. Multiple learner GPUs and experience replay are also supported. 
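As a rough usage sketch (illustrative values, not a tuned configuration; the config keys come from the ``DEFAULT_CONFIG`` changes in this patch, while the environment and worker counts are placeholders):

.. code-block:: python

    import ray
    from ray.rllib.agents.impala import ImpalaAgent

    ray.init()
    agent = ImpalaAgent(
        env="PongNoFrameskip-v4",
        config={
            "num_workers": 32,
            "num_gpus": 2,                   # GPUs for the central SGD learner
            "num_parallel_data_loaders": 2,  # load data into GPUs in parallel
            "replay_proportion": 0.5,        # replay 0.5x as much data as new samples
            "replay_buffer_num_slots": 100,  # sample batches kept for replay
        })
    result = agent.train()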
Tuned examples: `PongNoFrameskip-v4 `__, `vectorized configuration `__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 `__ diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index cfa55bd735c8..1ad2b673f429 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -11,8 +11,16 @@ from ray.tune.trial import Resources OPTIMIZER_SHARED_CONFIGS = [ + "lr", + "num_envs_per_worker", + "num_gpus", "sample_batch_size", "train_batch_size", + "replay_buffer_num_slots", + "replay_proportion", + "num_parallel_data_loaders", + "grad_clip", + "max_sample_requests_in_flight_per_worker", ] DEFAULT_CONFIG = with_common_config({ @@ -25,10 +33,22 @@ "sample_batch_size": 50, "train_batch_size": 500, "min_iter_time_s": 10, - "gpu": True, "num_workers": 2, "num_cpus_per_worker": 1, "num_gpus_per_worker": 0, + # number of GPUs the learner should use. + "num_gpus": 1, + # set >1 to load data into GPUs in parallel. Increases GPU memory usage + # proportionally with the number of loaders. + "num_parallel_data_loaders": 1, + # level of queuing for sampling. + "max_sample_requests_in_flight_per_worker": 2, + # set >0 to enable experience replay. Saved samples will be replayed with + # a p:1 proportion to new data samples. + "replay_proportion": 0.0, + # number of sample batches to store for replay. The number of transitions + # saved total will be (replay_buffer_num_slots * sample_batch_size). + "replay_buffer_num_slots": 100, # Learning params. "grad_clip": 40.0, @@ -65,7 +85,7 @@ def default_resource_request(cls, config): cf = dict(cls._default_config, **config) return Resources( cpu=1, - gpu=cf["gpu"] and cf["gpu_fraction"] or 0, + gpu=cf["num_gpus"] and cf["num_gpus"] * cf["gpu_fraction"] or 0, extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"], extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index f6984687166c..2ba018b91d56 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -31,6 +31,7 @@ def __init__(self, rewards, values, bootstrap_value, + valid_mask, vf_loss_coeff=0.5, entropy_coeff=-0.01, clip_rho_threshold=1.0, @@ -52,6 +53,7 @@ def __init__(self, rewards: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. bootstrap_value: A float32 tensor of shape [B]. + valid_mask: A bool tensor of valid RNN input elements (#2992). """ # Compute vtrace on the CPU for better perf. 
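For intuition on the ``valid_mask`` added above: it is derived from the RNN sequence lengths and applied with ``tf.boolean_mask`` so that zero-padded timesteps contribute nothing to the loss. A minimal standalone sketch of the pattern (not code from this patch):

.. code-block:: python

    import tensorflow as tf

    seq_lens = tf.constant([3, 1])        # two sequences, padded out to length 3
    mask = tf.sequence_mask(seq_lens, 3)  # [[True, True, True], [True, False, False]]
    mask = tf.reshape(mask, [-1])         # flatten to line up with [B * T] tensors
    losses = tf.constant([1., 2., 3., 4., 5., 6.])
    # the padded entries 5. and 6. are dropped before the reduction
    masked_loss = tf.reduce_sum(tf.boolean_mask(losses, mask))

The same masking pattern appears in both the V-trace and PPO losses in this patch.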
@@ -70,14 +72,16 @@ def __init__(self, # The policy gradients loss self.pi_loss = -tf.reduce_sum( - actions_logp * self.vtrace_returns.pg_advantages) + tf.boolean_mask(actions_logp * self.vtrace_returns.pg_advantages, + valid_mask)) # The baseline loss - delta = values - self.vtrace_returns.vs + delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask) self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta)) # The entropy loss - self.entropy = tf.reduce_sum(actions_entropy) + self.entropy = tf.reduce_sum( + tf.boolean_mask(actions_entropy, valid_mask)) # The summed weighted loss self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff + @@ -85,20 +89,49 @@ def __init__(self, class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph): - def __init__(self, observation_space, action_space, config): + def __init__(self, + observation_space, + action_space, + config, + existing_inputs=None): config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config) assert config["batch_mode"] == "truncate_episodes", \ "Must use `truncate_episodes` batch mode with V-trace." self.config = config self.sess = tf.get_default_session() + # Create input placeholders + if existing_inputs: + actions, dones, behaviour_logits, rewards, observations = \ + existing_inputs[:5] + existing_state_in = existing_inputs[5:-1] + existing_seq_lens = existing_inputs[-1] + else: + if isinstance(action_space, gym.spaces.Discrete): + ac_size = action_space.n + actions = tf.placeholder(tf.int64, [None], name="ac") + else: + raise UnsupportedSpaceException( + "Action space {} is not supported for IMPALA.".format( + action_space)) + dones = tf.placeholder(tf.bool, [None], name="dones") + rewards = tf.placeholder(tf.float32, [None], name="rewards") + behaviour_logits = tf.placeholder( + tf.float32, [None, ac_size], name="behaviour_logits") + observations = tf.placeholder( + tf.float32, [None] + list(observation_space.shape)) + existing_state_in = None + existing_seq_lens = None + # Setup the policy - self.observations = tf.placeholder( - tf.float32, [None] + list(observation_space.shape)) dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) - self.model = ModelCatalog.get_model(self.observations, logit_dim, - self.config["model"]) + self.model = ModelCatalog.get_model( + observations, + logit_dim, + self.config["model"], + state_in=existing_state_in, + seq_lens=existing_seq_lens) action_dist = dist_class(self.model.outputs) values = tf.reshape( linear(self.model.last_layer, 1, "value", normc_initializer(1.0)), @@ -106,19 +139,6 @@ def __init__(self, observation_space, action_space, config): self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) - # Setup the policy loss - if isinstance(action_space, gym.spaces.Discrete): - ac_size = action_space.n - actions = tf.placeholder(tf.int64, [None], name="ac") - else: - raise UnsupportedSpaceException( - "Action space {} is not supported for IMPALA.".format( - action_space)) - dones = tf.placeholder(tf.bool, [None], name="dones") - rewards = tf.placeholder(tf.float32, [None], name="rewards") - behaviour_logits = tf.placeholder( - tf.float32, [None, ac_size], name="behaviour_logits") - def to_batches(tensor): if self.config["model"]["use_lstm"]: B = tf.shape(self.model.seq_lens)[0] @@ -135,6 +155,13 @@ def to_batches(tensor): rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) + if self.model.state_in: + max_seq_len = tf.reduce_max(self.model.seq_lens) - 1 + mask = 
tf.sequence_mask(self.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like(rewards) + # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. self.loss = VTraceLoss( actions=to_batches(actions)[:-1], @@ -147,6 +174,7 @@ def to_batches(tensor): rewards=to_batches(rewards)[:-1], values=to_batches(values)[:-1], bootstrap_value=to_batches(values)[-1], + valid_mask=to_batches(mask)[:-1], vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.config["entropy_coeff"], clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], @@ -158,7 +186,7 @@ def to_batches(tensor): ("dones", dones), ("behaviour_logits", behaviour_logits), ("rewards", rewards), - ("obs", self.observations), + ("obs", observations), ] LearningRateSchedule.__init__(self, self.config["lr"], self.config["lr_schedule"]) @@ -167,7 +195,7 @@ def to_batches(tensor): observation_space, action_space, self.sess, - obs_input=self.observations, + obs_input=observations, action_sampler=action_dist.sample(), loss=self.loss.total_loss, loss_inputs=loss_in, @@ -218,3 +246,10 @@ def postprocess_trajectory(self, sample_batch, other_agent_batches=None): def get_initial_state(self): return self.model.state_init + + def copy(self, existing_inputs): + return VTracePolicyGraph( + self.observation_space, + self.action_space, + self.config, + existing_inputs=existing_inputs) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index 9456ebe944cc..638ae0eb858b 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -24,6 +24,7 @@ def __init__(self, curr_action_dist, value_fn, cur_kl_coeff, + valid_mask, entropy_coeff=0, clip_param=0.1, vf_clip_param=0.1, @@ -48,28 +49,33 @@ def __init__(self, value_fn (Tensor): Current value function output Tensor. cur_kl_coeff (Variable): Variable holding the current PPO KL coefficient. + valid_mask (Tensor): A bool mask of valid input elements (#2992). entropy_coeff (float): Coefficient of the entropy regularizer. clip_param (float): Clip parameter vf_clip_param (float): Clip parameter for the value function vf_loss_coeff (float): Coefficient of the value function loss use_gae (bool): If true, use the Generalized Advantage Estimator. """ + + def reduce_mean_valid(t): + return tf.reduce_mean(tf.boolean_mask(t, valid_mask)) + dist_cls, _ = ModelCatalog.get_action_dist(action_space, {}) prev_dist = dist_cls(logits) # Make loss functions. 
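        # The ratio pi_theta(a|s) / pi_theta_old(a|s) below is computed in
        # log space for numerical stability, i.e. exp(logp_new - logp_old).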
logp_ratio = tf.exp( curr_action_dist.logp(actions) - prev_dist.logp(actions)) action_kl = prev_dist.kl(curr_action_dist) - self.mean_kl = tf.reduce_mean(action_kl) + self.mean_kl = reduce_mean_valid(action_kl) curr_entropy = curr_action_dist.entropy() - self.mean_entropy = tf.reduce_mean(curr_entropy) + self.mean_entropy = reduce_mean_valid(curr_entropy) surrogate_loss = tf.minimum( advantages * logp_ratio, advantages * tf.clip_by_value(logp_ratio, 1 - clip_param, 1 + clip_param)) - self.mean_policy_loss = tf.reduce_mean(-surrogate_loss) + self.mean_policy_loss = reduce_mean_valid(-surrogate_loss) if use_gae: vf_loss1 = tf.square(value_fn - value_targets) @@ -77,14 +83,15 @@ def __init__(self, value_fn - vf_preds, -vf_clip_param, vf_clip_param) vf_loss2 = tf.square(vf_clipped - value_targets) vf_loss = tf.maximum(vf_loss1, vf_loss2) - self.mean_vf_loss = tf.reduce_mean(vf_loss) - loss = tf.reduce_mean(-surrogate_loss + cur_kl_coeff * action_kl + - vf_loss_coeff * vf_loss - - entropy_coeff * curr_entropy) + self.mean_vf_loss = reduce_mean_valid(vf_loss) + loss = reduce_mean_valid( + -surrogate_loss + cur_kl_coeff * action_kl + + vf_loss_coeff * vf_loss - entropy_coeff * curr_entropy) else: self.mean_vf_loss = tf.constant(0.0) - loss = tf.reduce_mean(-surrogate_loss + cur_kl_coeff * action_kl - - entropy_coeff * curr_entropy) + loss = reduce_mean_valid(-surrogate_loss + + cur_kl_coeff * action_kl - + entropy_coeff * curr_entropy) self.loss = loss @@ -179,6 +186,13 @@ def __init__(self, else: self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1]) + if self.model.state_in: + max_seq_len = tf.reduce_max(self.model.seq_lens) + mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like(adv_ph) + self.loss_obj = PPOLoss( action_space, value_targets_ph, @@ -189,6 +203,7 @@ def __init__(self, curr_action_dist, self.value_function, self.kl_coeff, + mask, entropy_coeff=self.config["entropy_coeff"], clip_param=self.config["clip_param"], vf_clip_param=self.config["vf_clip_param"], @@ -227,7 +242,7 @@ def __init__(self, def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" return PPOPolicyGraph( - None, + self.observation_space, self.action_space, self.config, existing_inputs=existing_inputs) diff --git a/python/ray/rllib/examples/cartpole_lstm.py b/python/ray/rllib/examples/cartpole_lstm.py index e3d0ddc4c570..67fd35d28dcf 100644 --- a/python/ray/rllib/examples/cartpole_lstm.py +++ b/python/ray/rllib/examples/cartpole_lstm.py @@ -14,6 +14,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--stop", type=int, default=200) +parser.add_argument("--run", type=str, default="PPO") class CartPoleStatelessEnv(gym.Env): @@ -163,18 +164,29 @@ def close(self): tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv()) ray.init() + + configs = { + "PPO": { + "num_sgd_iter": 5, + }, + "IMPALA": { + "num_workers": 2, + "num_gpus": 0, + "vf_loss_coeff": 0.01, + }, + } + tune.run_experiments({ "test": { "env": "cartpole_stateless", - "run": "PPO", + "run": args.run, "stop": { "episode_reward_mean": args.stop }, - "config": { - "num_sgd_iter": 5, + "config": dict(configs[args.run], **{ "model": { "use_lstm": True, }, - }, + }), } }) diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py index 3ed5f37d390f..c48fd1860197 100644 --- a/python/ray/rllib/optimizers/async_replay_optimizer.py +++ 
b/python/ray/rllib/optimizers/async_replay_optimizer.py @@ -87,14 +87,14 @@ def update_priorities(self, batch_indexes, td_errors): new_priorities = (np.abs(td_errors) + self.prioritized_replay_eps) self.replay_buffer.update_priorities(batch_indexes, new_priorities) - def stats(self): + def stats(self, debug=False): stat = { "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3), "replay_time_ms": round(1000 * self.replay_timer.mean, 3), "update_priorities_time_ms": round( 1000 * self.update_priorities_timer.mean, 3), } - stat.update(self.replay_buffer.stats()) + stat.update(self.replay_buffer.stats(debug=debug)) return stat @@ -274,7 +274,7 @@ def _step(self): return sample_timesteps, train_timesteps def stats(self): - replay_stats = ray.get(self.replay_actors[0].stats.remote()) + replay_stats = ray.get(self.replay_actors[0].stats.remote(self.debug)) timing = { "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3) for k in self.timers @@ -288,13 +288,13 @@ def stats(self): 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, + "learner_queue": self.learner.learner_queue_size.stats(), + "replay_shard_0": replay_stats, } debug_stats = { - "replay_shard_0": replay_stats, "timing_breakdown": timing, "pending_sample_tasks": self.sample_tasks.count, "pending_replay_tasks": self.replay_tasks.count, - "learner_queue": self.learner.learner_queue_size.stats(), } if self.debug: stats.update(debug_stats) diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index 3b6bb861b482..69f5e849b542 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -6,19 +6,22 @@ from __future__ import division from __future__ import print_function +import numpy as np +import random import time import threading from six.moves import queue import ray +from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils.actors import TaskPool from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat -SAMPLE_QUEUE_DEPTH = 2 LEARNER_QUEUE_MAX_SIZE = 16 +NUM_DATA_LOAD_THREADS = 16 class LearnerThread(threading.Thread): @@ -38,8 +41,10 @@ def __init__(self, local_evaluator): self.outqueue = queue.Queue() self.queue_timer = TimerStat() self.grad_timer = TimerStat() + self.load_timer = TimerStat() + self.load_wait_timer = TimerStat() self.daemon = True - self.weights_updated = 0 + self.weights_updated = False self.stats = {} def run(self): @@ -48,18 +53,129 @@ def run(self): def step(self): with self.queue_timer: - ra, batch = self.inqueue.get() - - if batch is not None: - with self.grad_timer: - fetches = self.local_evaluator.compute_apply(batch) - self.weights_updated += 1 - if "stats" in fetches: - self.stats = fetches["stats"] - self.outqueue.put(batch.count) + batch = self.inqueue.get() + + with self.grad_timer: + fetches = self.local_evaluator.compute_apply(batch) + self.weights_updated = True + self.stats = fetches.get("stats", {}) + + self.outqueue.put(batch.count) + self.learner_queue_size.push(self.inqueue.qsize()) + + +class TFMultiGPULearner(LearnerThread): + """Learner that can use multiple GPUs and parallel loading.""" + + def __init__(self, + local_evaluator, + num_gpus=1, + lr=0.0005, + train_batch_size=500, + grad_clip=40, + num_parallel_data_loaders=1): + # 
Multi-GPU requires TensorFlow to function. + import tensorflow as tf + + LearnerThread.__init__(self, local_evaluator) + self.lr = lr + self.train_batch_size = train_batch_size + if not num_gpus: + self.devices = ["/cpu:0"] + else: + self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)] + print("TFMultiGPULearner devices", self.devices) + assert self.train_batch_size % len(self.devices) == 0 + assert self.train_batch_size >= len(self.devices), "batch too small" + self.policy = self.local_evaluator.policy_map["default"] + + # per-GPU graph copies created below must share vars with the policy + # reuse is set to AUTO_REUSE because Adam nodes are created after + # all of the device copies are created. + self.par_opt = [] + with self.local_evaluator.tf_sess.graph.as_default(): + with self.local_evaluator.tf_sess.as_default(): + with tf.variable_scope("default", reuse=tf.AUTO_REUSE): + if self.policy._state_inputs: + rnn_inputs = self.policy._state_inputs + [ + self.policy._seq_lens + ] + else: + rnn_inputs = [] + adam = tf.train.AdamOptimizer(self.lr) + for _ in range(num_parallel_data_loaders): + self.par_opt.append( + LocalSyncParallelOptimizer( + adam, + self.devices, + [v for _, v in self.policy.loss_inputs()], + rnn_inputs, + 999999, # it will get rounded down + self.policy.copy, + grad_norm_clipping=grad_clip)) + + self.sess = self.local_evaluator.tf_sess + self.sess.run(tf.global_variables_initializer()) + + self.idle_optimizers = queue.Queue() + self.ready_optimizers = queue.Queue() + for opt in self.par_opt: + self.idle_optimizers.put(opt) + for i in range(NUM_DATA_LOAD_THREADS): + self.loader_thread = _LoaderThread(self, share_stats=(i == 0)) + self.loader_thread.start() + + def step(self): + assert self.loader_thread.is_alive() + with self.load_wait_timer: + opt = self.ready_optimizers.get() + + with self.grad_timer: + fetches = opt.optimize(self.sess, 0) + self.weights_updated = True + self.stats = fetches.get("stats", {}) + + self.idle_optimizers.put(opt) + self.outqueue.put(self.train_batch_size) self.learner_queue_size.push(self.inqueue.qsize()) +class _LoaderThread(threading.Thread): + def __init__(self, learner, share_stats): + threading.Thread.__init__(self) + self.learner = learner + self.daemon = True + if share_stats: + self.queue_timer = learner.queue_timer + self.load_timer = learner.load_timer + else: + self.queue_timer = TimerStat() + self.load_timer = TimerStat() + + def run(self): + while True: + self.step() + + def step(self): + s = self.learner + with self.queue_timer: + batch = s.inqueue.get() + + opt = s.idle_optimizers.get() + + with self.load_timer: + tuples = s.policy._get_loss_inputs_dict(batch) + data_keys = [ph for _, ph in s.policy.loss_inputs()] + if s.policy._state_inputs: + state_keys = s.policy._state_inputs + [s.policy._seq_lens] + else: + state_keys = [] + opt.load_data(s.sess, [tuples[k] for k in data_keys], + [tuples[k] for k in state_keys]) + + s.ready_optimizers.put(opt) + + class AsyncSamplesOptimizer(PolicyOptimizer): """Main event loop of the IMPALA architecture. @@ -67,13 +183,38 @@ class AsyncSamplesOptimizer(PolicyOptimizer): and remote evaluators (IMPALA actors). 
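    The approximate flow: remote evaluators sample asynchronously; completed
    sample batches (optionally mixed with replayed ones) are buffered until
    `train_batch_size` steps are collected and then enqueued for the learner
    thread, which applies the gradient update; evaluators pull fresh weights
    whenever the learner has updated them.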
""" - def _init(self, train_batch_size=512, sample_batch_size=50, debug=False): - - self.debug = debug + def _init(self, + train_batch_size=500, + sample_batch_size=50, + num_envs_per_worker=1, + num_gpus=0, + lr=0.0005, + grad_clip=40, + replay_buffer_num_slots=0, + replay_proportion=0.0, + num_parallel_data_loaders=1, + max_sample_requests_in_flight_per_worker=2): self.learning_started = False self.train_batch_size = train_batch_size + self.sample_batch_size = sample_batch_size - self.learner = LearnerThread(self.local_evaluator) + if num_gpus > 1 or num_parallel_data_loaders > 1: + print( + "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( + num_gpus, num_parallel_data_loaders)) + if train_batch_size // max(1, num_gpus) % ( + sample_batch_size // num_envs_per_worker) != 0: + raise ValueError( + "Sample batches must evenly divide across GPUs.") + self.learner = TFMultiGPULearner( + self.local_evaluator, + lr=lr, + num_gpus=num_gpus, + train_batch_size=train_batch_size, + grad_clip=grad_clip, + num_parallel_data_loaders=num_parallel_data_loaders) + else: + self.learner = LearnerThread(self.local_evaluator) self.learner.start() assert len(self.remote_evaluators) > 0 @@ -85,6 +226,7 @@ def _init(self, train_batch_size=512, sample_batch_size=50, debug=False): ["put_weights", "enqueue", "sample_processing", "train", "sample"] } self.num_weight_syncs = 0 + self.num_replayed = 0 self.learning_started = False # Kick off async background sampling @@ -92,11 +234,19 @@ def _init(self, train_batch_size=512, sample_batch_size=50, debug=False): weights = self.local_evaluator.get_weights() for ev in self.remote_evaluators: ev.set_weights.remote(weights) - for _ in range(SAMPLE_QUEUE_DEPTH): + for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) self.batch_buffer = [] + if replay_proportion: + assert replay_buffer_num_slots > 0 + assert (replay_buffer_num_slots * sample_batch_size > + train_batch_size) + self.replay_proportion = replay_proportion + self.replay_buffer_num_slots = replay_buffer_num_slots + self.replay_batches = [] + def step(self): assert self.learner.is_alive() start = time.time() @@ -112,23 +262,52 @@ def step(self): self.num_steps_sampled += sample_timesteps self.num_steps_trained += train_timesteps + def _augment_with_replay(self, sample_futures): + def can_replay(): + num_needed = int( + np.ceil(self.train_batch_size / self.sample_batch_size)) + return len(self.replay_batches) > num_needed + + for ev, sample_batch in sample_futures: + sample_batch = ray.get(sample_batch) + yield ev, sample_batch + + if can_replay(): + f = self.replay_proportion + while random.random() < f: + f -= 1 + replay_batch = random.choice(self.replay_batches) + self.num_replayed += replay_batch.count + yield None, replay_batch + def _step(self): sample_timesteps, train_timesteps = 0, 0 weights = None with self.timers["sample_processing"]: - for ev, sample_batch in self.sample_tasks.completed_prefetch(): - sample_batch = ray.get(sample_batch) - sample_timesteps += sample_batch.count + for ev, sample_batch in self._augment_with_replay( + self.sample_tasks.completed_prefetch()): self.batch_buffer.append(sample_batch) if sum(b.count for b in self.batch_buffer) >= self.train_batch_size: train_batch = self.batch_buffer[0].concat_samples( self.batch_buffer) with self.timers["enqueue"]: - self.learner.inqueue.put((ev, train_batch)) + self.learner.inqueue.put(train_batch) self.batch_buffer = [] + # If the batch was replayed, skip the update below. 
+ if ev is None: + continue + + sample_timesteps += sample_batch.count + + # Put in replay buffer if enabled + if self.replay_buffer_num_slots > 0: + self.replay_batches.append(sample_batch) + if len(self.replay_batches) > self.replay_buffer_num_slots: + self.replay_batches.pop(0) + # Note that it's important to pull new weights once # updated to avoid excessive correlation between actors if weights is None or self.learner.weights_updated: @@ -154,6 +333,10 @@ def stats(self): } timing["learner_grad_time_ms"] = round( 1000 * self.learner.grad_timer.mean, 3) + timing["learner_load_time_ms"] = round( + 1000 * self.learner.load_timer.mean, 3) + timing["learner_load_wait_time_ms"] = round( + 1000 * self.learner.load_wait_timer.mean, 3) timing["learner_dequeue_time_ms"] = round( 1000 * self.learner.queue_timer.mean, 3) stats = { @@ -161,14 +344,10 @@ def stats(self): 3), "train_throughput": round(self.timers["train"].mean_throughput, 3), "num_weight_syncs": self.num_weight_syncs, - } - debug_stats = { + "num_steps_replayed": self.num_replayed, "timing_breakdown": timing, - "pending_sample_tasks": self.sample_tasks.count, "learner_queue": self.learner.learner_queue_size.stats(), } - if self.debug: - stats.update(debug_stats) if self.learner.stats: stats["learner"] = self.learner.stats return dict(PolicyOptimizer.stats(self), **stats) diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 7233e37e9380..1affe8df395e 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -36,13 +36,13 @@ class LocalSyncParallelOptimizer(object): to define the per-device loss ops. rnn_inputs: Extra input placeholders for RNN inputs. These will have shape [BATCH_SIZE // MAX_SEQ_LEN, ...]. - per_device_batch_size: Number of tuples to optimize over at a time per - device. In each call to `optimize()`, + max_per_device_batch_size: Number of tuples to optimize over at a time + per device. In each call to `optimize()`, `len(devices) * per_device_batch_size` tuples of data will be - processed. + processed. If this is larger than the total data size, it will be + clipped. build_graph: Function that takes the specified inputs and returns a TF Policy Graph instance. - logdir: Directory to place debugging output in. 
grad_norm_clipping: None, or a max global norm (float); if given, gradients are clipped so that their global norm does not exceed it """ @@ -51,18 +51,14 @@ def __init__(self, devices, input_placeholders, rnn_inputs, - per_device_batch_size, + max_per_device_batch_size, build_graph, - logdir, grad_norm_clipping=None): - # TODO(rliaw): remove logdir self.optimizer = optimizer self.devices = devices - self.batch_size = per_device_batch_size * len(devices) - self.per_device_batch_size = per_device_batch_size + self.max_per_device_batch_size = max_per_device_batch_size self.loss_inputs = input_placeholders + rnn_inputs self.build_graph = build_graph - self.logdir = logdir # First initialize the shared loss network with tf.name_scope(TOWER_SCOPE_NAME): @@ -71,6 +67,11 @@ def __init__(self, # Then setup the per-device loss graphs that use the shared weights self._batch_index = tf.placeholder(tf.int32, name="batch_index") + # Dynamic batch size, which may be shrunk if there isn't enough data + self._per_device_batch_size = tf.placeholder( + tf.int32, name="per_device_batch_size") + self._loaded_per_device_batch_size = max_per_device_batch_size + # When loading RNN input, we dynamically determine the max seq len self._max_seq_len = tf.placeholder(tf.int32, name="max_seq_len") self._loaded_max_seq_len = 1 @@ -88,9 +89,12 @@ def __init__(self, avg = average_gradients([t.grads for t in self._towers]) if grad_norm_clipping: + clipped = [] + for grad, _ in avg: + clipped.append(grad) + clipped, _ = tf.clip_by_global_norm(clipped, grad_norm_clipping) for i, (grad, var) in enumerate(avg): - if grad is not None: - avg[i] = (tf.clip_by_norm(grad, grad_norm_clipping), var) + avg[i] = (clipped[i], var) self._train_op = self.optimizer.apply_gradients(avg) def load_data(self, sess, inputs, state_inputs): @@ -117,44 +121,64 @@ def load_data(self, sess, inputs, state_inputs): assert len(self.loss_inputs) == len(inputs + state_inputs), \ (self.loss_inputs, inputs, state_inputs) - # The RNN truncation case is more complicated + # Let's suppose we have the following input data, and 2 devices: + # 1 2 3 4 5 6 7 <- state inputs shape + # A A A B B B C C C D D D E E E F F F G G G <- inputs shape + # The data is truncated and split across devices as follows: + # |---| seq len = 3 + # |---------------------------------| seq batch size = 6 seqs + # |----------------| per device batch size = 9 tuples if len(state_inputs) > 0: + smallest_array = state_inputs[0] seq_len = len(inputs[0]) // len(state_inputs[0]) self._loaded_max_seq_len = seq_len - assert len(state_inputs[0]) * seq_len == len(inputs[0]) - # Make sure the shorter state inputs arrays are evenly divisible + else: + smallest_array = inputs[0] + self._loaded_max_seq_len = 1 + + seq_batch_size = (self.max_per_device_batch_size // + self._loaded_max_seq_len * len(self.devices)) + if len(smallest_array) < seq_batch_size: + # Dynamically shrink the batch size if insufficient data + seq_batch_size = make_divisible_by( + len(smallest_array), len(self.devices)) + if seq_batch_size < len(self.devices): + raise ValueError("Must load at least 1 tuple sequence per device, " + "got only {} total.".format(len(smallest_array))) + self._loaded_per_device_batch_size = ( + seq_batch_size // len(self.devices) * self._loaded_max_seq_len) + + if len(state_inputs) > 0: + # First truncate the RNN state arrays to the seq_batch_size state_inputs = [ - make_divisible_by(arr, self.batch_size) for arr in state_inputs + make_divisible_by(arr, seq_batch_size) for arr in state_inputs ] # Then truncate the data inputs to match inputs = [arr[:len(state_inputs[0]) *
seq_len] for arr in inputs] - assert len(state_inputs[0]) * seq_len == len(inputs[0]) - assert len(state_inputs[0]) % self.batch_size == 0 + assert len(state_inputs[0]) * seq_len == len(inputs[0]), \ + (len(state_inputs[0]), seq_batch_size, seq_len, len(inputs[0])) for ph, arr in zip(self.loss_inputs, inputs + state_inputs): feed_dict[ph] = arr truncated_len = len(inputs[0]) else: for ph, arr in zip(self.loss_inputs, inputs + state_inputs): - truncated_arr = make_divisible_by(arr, self.batch_size) + truncated_arr = make_divisible_by(arr, seq_batch_size) feed_dict[ph] = truncated_arr truncated_len = len(truncated_arr) sess.run([t.init_op for t in self._towers], feed_dict=feed_dict) tuples_per_device = truncated_len / len(self.devices) - assert tuples_per_device > 0, \ - "Too few tuples per batch, trying increasing the training " \ - "batch size or decreasing the sgd batch size. Tried to split up " \ - "{} rows {}-ways in batches of {} (total across devices).".format( - len(arr), len(self.devices), self.batch_size) - assert tuples_per_device % self.per_device_batch_size == 0 + assert tuples_per_device > 0, "No data loaded?" + assert tuples_per_device % self._loaded_per_device_batch_size == 0 return tuples_per_device def optimize(self, sess, batch_index): """Run a single step of SGD. Runs a SGD step over a slice of the preloaded batch with size given by - self.per_device_batch_size and offset given by the batch_index + self._loaded_per_device_batch_size and offset given by the batch_index argument. Updates shared model weights based on the averaged per-device @@ -164,13 +188,14 @@ def optimize(self, sess, batch_index): sess: TensorFlow session. batch_index: Offset into the preloaded data. This value must be between `0` and `tuples_per_device`. The amount of data to - process is always fixed to `per_device_batch_size`. + process is at most `max_per_device_batch_size`. Returns: The outputs of extra_ops evaluated over the batch. 
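        (Typically `batch_index` is advanced from 0 up to `tuples_per_device`,
        as returned by `load_data()`, in increments of the loaded per-device
        batch size.)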
""" feed_dict = { self._batch_index: batch_index, + self._per_device_batch_size: self._loaded_per_device_batch_size, self._max_seq_len: self._loaded_max_seq_len, } for tower in self._towers: @@ -213,7 +238,7 @@ def _setup_device(self, device, device_input_placeholders, num_data_in): current_batch, ([self._batch_index // scale * granularity] + [0] * len(ph.shape[1:])), - ([self.per_device_batch_size // scale * granularity] + + ([self._per_device_batch_size // scale * granularity] + [-1] * len(ph.shape[1:]))) current_slice.set_shape(ph.shape) device_input_slices.append(current_slice) @@ -229,8 +254,10 @@ def _setup_device(self, device, device_input_placeholders, num_data_in): Tower = namedtuple("Tower", ["init_op", "grads", "loss_graph"]) -def make_divisible_by(array, n): - return array[0:array.shape[0] - array.shape[0] % n] +def make_divisible_by(a, n): + if type(a) is int: + return a - a % n + return a[0:a.shape[0] - a.shape[0] % n] def average_gradients(tower_grads): diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index e47457036393..4595415a1eb0 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -4,7 +4,6 @@ import numpy as np from collections import defaultdict -import os import tensorflow as tf import ray @@ -81,8 +80,7 @@ def _init(self, self.par_opt = LocalSyncParallelOptimizer( self.policy.optimizer(), self.devices, [v for _, v in self.policy.loss_inputs()], rnn_inputs, - self.per_device_batch_size, self.policy.copy, - os.getcwd()) + self.per_device_batch_size, self.policy.copy) self.sess = self.local_evaluator.tf_sess self.sess.run(tf.global_variables_initializer()) diff --git a/python/ray/rllib/optimizers/replay_buffer.py b/python/ray/rllib/optimizers/replay_buffer.py index 77d954345668..cd5ec732848e 100644 --- a/python/ray/rllib/optimizers/replay_buffer.py +++ b/python/ray/rllib/optimizers/replay_buffer.py @@ -93,14 +93,15 @@ def sample(self, batch_size): self._num_sampled += batch_size return self._encode_sample(idxes) - def stats(self): + def stats(self, debug=False): data = { "added_count": self._num_added, "sampled_count": self._num_sampled, "est_size_bytes": self._est_size_bytes, "num_entries": len(self._storage), } - data.update(self._evicted_hit_stats.stats()) + if debug: + data.update(self._evicted_hit_stats.stats()) return data @@ -233,7 +234,8 @@ def update_priorities(self, idxes, priorities): self._max_priority = max(self._max_priority, priority) - def stats(self): - parent = ReplayBuffer.stats(self) - parent.update(self._prio_change_stats.stats()) + def stats(self, debug=False): + parent = ReplayBuffer.stats(self, debug) + if debug: + parent.update(self._prio_change_stats.stats()) return parent diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index 2ced3402a78a..16bdd485f9a9 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -94,7 +94,7 @@ class ModelSupportedSpaces(unittest.TestCase): def testAll(self): ray.init() stats = {} - check_support("IMPALA", {"gpu": False}, stats) + check_support("IMPALA", {"num_gpus": 0}, stats) check_support("DDPG", {"timesteps_per_iteration": 1}, stats) check_support("DQN", {"timesteps_per_iteration": 1}, stats) check_support("A3C", { diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 43815f470cff..0a6db03b90cd 100755 --- 
a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -189,14 +189,28 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ - --config '{"gpu": false, "num_workers": 2, "min_iter_time_s": 1}' + --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1}' docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ - --config '{"gpu": false, "num_workers": 2, "min_iter_time_s": 1, "model": {"use_lstm": true}}' + --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "model": {"use_lstm": true}}' + +docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env CartPole-v0 \ + --run IMPALA \ + --stop '{"training_iteration": 2}' \ + --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "num_parallel_data_loaders": 2, "replay_proportion": 1.0}' + +docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/train.py \ + --env CartPole-v0 \ + --run IMPALA \ + --stop '{"training_iteration": 2}' \ + --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "num_parallel_data_loaders": 2, "replay_proportion": 1.0, "model": {"use_lstm": true}}' docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ @@ -295,7 +309,10 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/multiagent_two_trainers.py --num-iters=2 docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ - python /ray/python/ray/rllib/examples/cartpole_lstm.py --stop=200 + python /ray/python/ray/rllib/examples/cartpole_lstm.py --run=PPO --stop=200 + +docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/cartpole_lstm.py --run=IMPALA --stop=100 docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 From 6240ccbc6edeafd2e3223be1edb1448f3ea5d0ca Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Oct 2018 13:42:56 -0700 Subject: [PATCH 028/215] [rllib] Add more warnings when multi-agent envs might not be set up right (#3061) --- python/ray/rllib/env/async_vector_env.py | 5 +++++ python/ray/rllib/evaluation/policy_evaluator.py | 14 +++++++++++++- python/ray/rllib/evaluation/sampler.py | 2 +- python/ray/tune/suggest/variant_generator.py | 5 +++++ 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/ray/rllib/env/async_vector_env.py b/python/ray/rllib/env/async_vector_env.py index c2e5ab1d3086..71876df61191 100644 --- a/python/ray/rllib/env/async_vector_env.py +++ b/python/ray/rllib/env/async_vector_env.py @@ -268,12 +268,17 @@ def send_actions(self, action_dict): raise ValueError("Env {} is already done".format(env_id)) env = self.envs[env_id] obs, rewards, dones, infos = env.step(agent_dict) + assert isinstance(obs, dict), "Not a multi-agent obs" + assert isinstance(rewards, dict), "Not a multi-agent reward" + assert isinstance(dones, dict), "Not a multi-agent return" + assert isinstance(infos, dict), "Not a multi-agent info" if dones["__all__"]: self.dones.add(env_id) self.env_states[env_id].observe(obs, 
rewards, dones, infos) def try_reset(self, env_id): obs = self.env_states[env_id].reset() + assert isinstance(obs, dict), "Not a multi-agent obs" if obs is not None and env_id in self.dones: self.dones.remove(env_id) return obs diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index 1152aab82b49..ac276ee3a846 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -168,6 +168,11 @@ def __init__(self, model_config = model_config or {} policy_mapping_fn = (policy_mapping_fn or (lambda agent_id: DEFAULT_POLICY_ID)) + if not callable(policy_mapping_fn): + raise ValueError( + "Policy mapping function not callable. If you're using Tune, " + "make sure to escape the function with tune.function() " + "to prevent it from being evaluated as an expression.") self.env_creator = env_creator self.sample_batch_size = batch_steps * num_envs self.batch_mode = batch_mode @@ -230,7 +235,14 @@ def make_env(vector_index): self.policy_map = self._build_policy_map(policy_dict, policy_config) - self.multiagent = self.policy_map.keys() != {DEFAULT_POLICY_ID} + self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID} + if self.multiagent: + if not (isinstance(self.env, MultiAgentEnv) + or isinstance(self.env, AsyncVectorEnv)): + raise ValueError( + "Have multiple policy graphs {}, but the env ".format( + self.policy_map) + + "{} is not a subclass of MultiAgentEnv?".format(self.env)) self.filters = { policy_id: get_filter(observation_filter, diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 78043b3d4a4b..64999d3638ee 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -218,7 +218,7 @@ def _env_runner(async_vector_env, horizon = ( async_vector_env.get_unwrapped()[0].spec.max_episode_steps) except Exception: - print("Warning, no horizon specified, assuming infinite") + print("*** WARNING ***: no episode horizon specified, assuming inf") if not horizon: horizon = float("inf") diff --git a/python/ray/tune/suggest/variant_generator.py b/python/ray/tune/suggest/variant_generator.py index 98b830754093..c33e7925167d 100644 --- a/python/ray/tune/suggest/variant_generator.py +++ b/python/ray/tune/suggest/variant_generator.py @@ -155,6 +155,11 @@ def _resolve_lambda_vars(spec, lambda_vars): value = fn(_UnresolvedAccessGuard(spec)) except RecursiveDependencyError as e: error = e + except Exception: + raise ValueError( + "Failed to evaluate expression: {}: {}".format(path, fn) + + ". If you meant to pass this as a function literal, use " "tune.function() to escape it.") else: _assign_value(spec, path, value) resolved[path] = value From 4d8cfc0bf50b5830cfab04ad5c247e85ffba9ed8 Mon Sep 17 00:00:00 2001 From: Praveen Palanisamy <4770482+praveen-palanisamy@users.noreply.github.com> Date: Tue, 16 Oct 2018 14:07:53 -0400 Subject: [PATCH 029/215] [tune] Fix (some more) misleading comments in tune/results.py (#3068) ## What do these changes do? Fix the misleading comments in code for: - `EPISODES_THIS_ITER` - `EPISODES_TOTAL` I had noted this before and planned to fix it along with some other changes, but it seemed relevant to keep this next to #3058, so I'm sending it now.
--- python/ray/tune/result.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index c75ba77e101b..a30124536b5e 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -16,10 +16,10 @@ # (Auto-filled) The pid of the training process. PID = "pid" -# Number of timesteps in this iteration. +# Number of episodes in this iteration. EPISODES_THIS_ITER = "episodes_this_iter" -# (Optional/Auto-filled) Accumulated time in seconds for this experiment. +# (Optional/Auto-filled) Accumulated number of episodes for this experiment. EPISODES_TOTAL = "episodes_total" # Number of timesteps in this iteration. @@ -35,7 +35,7 @@ # (Auto-filled) Accumulated time in seconds for this entire experiment. TIME_TOTAL_S = "time_total_s" -# (Auto-filled) The index of thistraining iteration. +# (Auto-filled) The index of this training iteration. TRAINING_ITERATION = "training_iteration" # Where Tune writes result files by default From 64e5eb305e40903035566f5ce3e40e67edf61475 Mon Sep 17 00:00:00 2001 From: Wang Qing Date: Wed, 17 Oct 2018 06:03:18 +0800 Subject: [PATCH 030/215] [Java] Add jvm-parameters in Config. (#3065) --- .../src/main/java/org/ray/runtime/config/RayConfig.java | 8 ++++++++ .../src/main/java/org/ray/runtime/runner/RunManager.java | 6 ++++-- java/runtime/src/main/resources/ray.default.conf | 3 +++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index d374d25a577f..b172a41114e0 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -6,6 +6,7 @@ import com.typesafe.config.Config; import com.typesafe.config.ConfigException; import com.typesafe.config.ConfigFactory; + import java.util.List; import java.util.Map; import org.ray.api.id.UniqueId; @@ -35,6 +36,7 @@ public class RayConfig { public final boolean redirectOutput; public final List libraryPath; public final List classpath; + public final List jvmParameters; private String redisAddress; private String redisIp; @@ -127,6 +129,12 @@ public RayConfig(Config config) { List customLibraryPath = config.getStringList("ray.library.path"); // custom classpath classpath = config.getStringList("ray.classpath"); + // custom worker jvm parameters + if (config.hasPath("ray.worker.jvm-parameters")) { + jvmParameters = config.getStringList("ray.worker.jvm-parameters"); + } else { + jvmParameters = ImmutableList.of(); + } // redis configurations String redisAddress = config.getString("ray.redis.address"); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 03429c963a3a..3be219dca75c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -205,8 +205,8 @@ private String buildWorkerCommandRaylet() { // Generate classpath based on current classpath + user-defined classpath. 
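        // (User-supplied classpath entries are placed first so that they
        // take precedence over classes on the default JVM classpath.)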
String classpath = concatPath(Stream.concat( - Stream.of(System.getProperty("java.class.path").split(":")), - rayConfig.classpath.stream() + rayConfig.classpath.stream(), + Stream.of(System.getProperty("java.class.path").split(":")) )); cmd.add(classpath); @@ -227,6 +227,8 @@ private String buildWorkerCommandRaylet() { // Config overwrite cmd.add("-Dray.redis.address=" + rayConfig.getRedisAddress()); + cmd.addAll(rayConfig.jvmParameters); + // Main class cmd.add(WORKER_CLASS); String command = Joiner.on(" ").join(cmd); diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index 892d90c6cc96..b45d7dc6376d 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -43,6 +43,9 @@ ray { // Otherwise, output will be printed to console. redirect-output: true + // Custom worker jvm parameters. + worker.jvm-parameters: [] + // Custom `java.library.path` // Note, do not use `dir1:dir2` format, put each dir as a list item. library.path: [] From a9e454f6fdaa4dca2cb27e748805953f8721c6f4 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 16 Oct 2018 15:55:11 -0700 Subject: [PATCH 031/215] [rllib] Include config dicts in the sphinx docs (#3064) --- doc/source/rllib-algorithms.rst | 63 +++++++++++++ doc/source/rllib-models.rst | 9 ++ doc/source/rllib-training.rst | 51 +++++++---- doc/source/rllib.rst | 1 + python/ray/rllib/agents/a3c/a3c.py | 26 +----- python/ray/rllib/agents/agent.py | 16 ++-- python/ray/rllib/agents/ars/ars.py | 3 +- python/ray/rllib/agents/ddpg/apex.py | 2 +- python/ray/rllib/agents/ddpg/ddpg.py | 3 + python/ray/rllib/agents/dqn/apex.py | 5 +- .../ray/rllib/agents/dqn/common/wrappers.py | 2 +- python/ray/rllib/agents/dqn/dqn.py | 3 + python/ray/rllib/agents/es/es.py | 8 +- python/ray/rllib/agents/impala/impala.py | 10 +-- python/ray/rllib/agents/pg/pg.py | 10 +-- python/ray/rllib/agents/ppo/ppo.py | 10 +-- .../ray/rllib/evaluation/policy_evaluator.py | 7 +- python/ray/rllib/models/__init__.py | 14 ++- python/ray/rllib/models/catalog.py | 88 ++++++++++++------- python/ray/rllib/models/fcnet.py | 4 +- python/ray/rllib/models/lstm.py | 2 +- python/ray/rllib/models/model.py | 2 +- python/ray/rllib/models/preprocessors.py | 18 ++-- python/ray/rllib/models/pytorch/visionnet.py | 4 +- python/ray/rllib/models/visionnet.py | 4 +- .../ray/rllib/test/test_supported_spaces.py | 4 - 26 files changed, 234 insertions(+), 135 deletions(-) diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index be39b61bb500..9a7c535395b8 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -38,6 +38,13 @@ SpaceInvaders 646 ~300 Ape-X using 32 workers in RLlib vs vanilla DQN (orange) and A3C (blue) on PongNoFrameskip-v4. +**Ape-X specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/dqn/apex.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Importance Weighted Actor-Learner Architecture (IMPALA) ------------------------------------------------------- @@ -73,6 +80,13 @@ SpaceInvaders 843 ~300 IMPALA solves Atari several times faster than A2C / A3C, with similar sample efficiency. Here IMPALA scales from 16 to 128 workers to solve PongNoFrameskip-v4 in ~8 minutes. +**IMPALA-specific configs** (see also `common configs `__): + +.. 
literalinclude:: ../../python/ray/rllib/agents/impala/impala.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Gradient-based ~~~~~~~~~~~~~~ @@ -97,6 +111,13 @@ Qbert 3620 ~1000 SpaceInvaders 692 ~600 ============= ======================== ============================== +**A3C-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/a3c/a3c.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Deep Deterministic Policy Gradients (DDPG) ------------------------------------------ `[paper] `__ `[implementation] `__ @@ -104,6 +125,13 @@ DDPG is implemented similarly to DQN (below). The algorithm can be scaled by inc Tuned examples: `Pendulum-v0 `__, `MountainCarContinuous-v0 `__, `HalfCheetah-v2 `__ +**DDPG-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/ddpg/ddpg.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Deep Q Networks (DQN, Rainbow) ------------------------------ `[paper] `__ `[implementation] `__ @@ -125,12 +153,26 @@ Qbert 3921 7968 15780 SpaceInvaders 650 1001 1025 ~500 ============= ======================== ============================= ============================== =============================== +**DQN-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/dqn/dqn.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Policy Gradients ---------------- `[paper] `__ `[implementation] `__ We include a vanilla policy gradients implementation as an example algorithm. This is usually outperformed by PPO. Tuned examples: `CartPole-v0 `__ +**PG-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/pg/pg.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Proximal Policy Optimization (PPO) ---------------------------------- `[paper] `__ `[implementation] `__ @@ -158,6 +200,13 @@ SpaceInvaders 671 944 ~800 RLlib's multi-GPU PPO scales to multiple GPUs and hundreds of CPUs on solving the Humanoid-v1 task. Here we compare against a reference MPI-based implementation. +**PPO-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/ppo/ppo.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Derivative-free ~~~~~~~~~~~~~~~ @@ -168,6 +217,13 @@ ARS is a random search method for training linear policies for continuous contro Tuned examples: `CartPole-v0 `__, `Swimmer-v2 `__ +**ARS-specific configs** (see also `common configs `__): + +.. literalinclude:: ../../python/ray/rllib/agents/ars/ars.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + Evolution Strategies -------------------- `[paper] `__ `[implementation] `__ @@ -181,3 +237,10 @@ Tuned examples: `Humanoid-v1 `__): + +.. 
literalinclude:: ../../python/ray/rllib/agents/es/es.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 5b3f88cf0e36..6efc4abb8e7a 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -17,6 +17,15 @@ In addition, if you set ``"model": {"use_lstm": true}``, then the model output w For preprocessors, RLlib tries to pick one of its built-in preprocessor based on the environment's observation space. Discrete observations are one-hot encoded, Atari observations downscaled, and Tuple observations flattened (there isn't native tuple support yet, but you can reshape the flattened observation in a custom model). Note that for Atari, RLlib defaults to using the `DeepMind preprocessors `__, which are also used by the OpenAI baselines library. +Built-in Model Parameters +~~~~~~~~~~~~~~~~~~~~~~~~~ + +The following is a list of the built-in model hyperparameters: + +.. literalinclude:: ../../python/ray/rllib/models/catalog.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ Custom Models ------------- diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 6d3a142db154..583a0ccae368 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -37,6 +37,29 @@ with ``--env`` (any OpenAI gym environment including ones registered by the user can be used) and for choosing the algorithm with ``--run`` (available options are ``PPO``, ``PG``, ``A2C``, ``A3C``, ``IMPALA``, ``ES``, ``DDPG``, ``DQN``, ``APEX``, and ``APEX_DDPG``). +Evaluating Trained Agents +~~~~~~~~~~~~~~~~~~~~~~~~~ + +In order to save checkpoints from which to evaluate agents, +set ``--checkpoint-freq`` (number of training iterations between checkpoints) +when running ``train.py``. + + +An example of evaluating a previously trained DQN agent is as follows: + +.. code-block:: bash + + python ray/python/ray/rllib/rollout.py \ + ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1 \ + --run DQN --env CartPole-v0 + +The ``rollout.py`` helper script reconstructs a DQN agent from the checkpoint +located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1`` +and renders its behavior in the environment specified by ``--env``. + +Configuration +------------- + Specifying Parameters ~~~~~~~~~~~~~~~~~~~~~ @@ -55,27 +78,17 @@ In an example below, we train A2C by specifying 8 workers through the config fla Specifying Resources ~~~~~~~~~~~~~~~~~~~~ -You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. Many agents also provide a ``num_gpus`` or ``gpu`` option. In addition, you can allocate a fraction of a GPU by setting ``gpu_fraction: f``. For example, with DQN you can pack five agents onto one GPU by setting ``gpu_fraction: 0.2``. Note that fractional GPU support requires enabling the experimental Xray backend by setting the environment variable ``RAY_USE_XRAY=1``. - -Evaluating Trained Agents -~~~~~~~~~~~~~~~~~~~~~~~~~ - -In order to save checkpoints from which to evaluate agents, -set ``--checkpoint-freq`` (number of training iterations between checkpoints) -when running ``train.py``. +You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most agents. Many agents also provide a ``num_gpus`` or ``gpu`` option. In addition, you can allocate a fraction of a GPU by setting ``gpu_fraction: f``. 
For example, with DQN you can pack five agents onto one GPU by setting ``gpu_fraction: 0.2``. Note that fractional GPU support requires enabling the experimental X-ray backend by setting the environment variable ``RAY_USE_XRAY=1``. +Common Parameters +~~~~~~~~~~~~~~~~~ -An example of evaluating a previously trained DQN agent is as follows: +The following is a list of the common agent hyperparameters: -.. code-block:: bash - - python ray/python/ray/rllib/rollout.py \ - ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1 \ - --run DQN --env CartPole-v0 - -The ``rollout.py`` helper script reconstructs a DQN agent from the checkpoint -located at ``~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint-1`` -and renders its behavior in the environment specified by ``--env``. +.. literalinclude:: ../../python/ray/rllib/agents/agent.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ Tuned Examples ~~~~~~~~~~~~~~ @@ -154,7 +167,7 @@ Tune will schedule the trials to run in parallel on your Ray cluster: == Status == Using FIFO scheduling algorithm. Resources requested: 4/4 CPUs, 0/0 GPUs - Result logdir: /home/eric/ray_results/my_experiment + Result logdir: ~/ray_results/my_experiment PENDING trials: - PPO_CartPole-v0_2_sgd_stepsize=0.0001: PENDING RUNNING trials: diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index ba011d08c45e..8979e702aec6 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -27,6 +27,7 @@ You might also want to clone the Ray repo for convenient access to RLlib helper Training APIs ------------- * `Command-line `__ +* `Configuration `__ * `Python API `__ * `REST API `__ diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index afda9506248d..55f179adecd5 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -10,6 +10,7 @@ from ray.rllib.utils import merge_dicts from ray.tune.trial import Resources +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # Size of rollout batch "sample_batch_size": 10, @@ -34,31 +35,10 @@ # Workers sample async. Note that this increases the effective # sample_batch_size by up to 5x due to async buffering of batches. "sample_async": True, - # Model and preprocessor options - "model": { - # Use LSTM model. Requires TF. - "use_lstm": False, - # Max seq length for LSTM training. 
- "max_seq_len": 20, - # (Image statespace) - Converts image to Channels = 1 - "grayscale": True, - # (Image statespace) - Each pixel - "zero_mean": False, - # (Image statespace) - Converts image to (dim, dim, C) - "dim": 84, - # (Image statespace) - Converts image shape to (C, dim, dim) - "channel_major": False, - }, - # Configure TF for single-process operation - "tf_session_args": { - "intra_op_parallelism_threads": 1, - "inter_op_parallelism_threads": 1, - "gpu_options": { - "allow_growth": True, - }, - }, }) +# __sphinx_doc_end__ + class A3CAgent(Agent): """A3C implementations in TensorFlow and PyTorch.""" diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 28b423417063..01defdbc328c 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -10,6 +10,7 @@ import tensorflow as tf import ray +from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils import FilterManager, deep_update, merge_dicts @@ -18,10 +19,11 @@ from ray.tune.logger import UnifiedLogger from ray.tune.result import DEFAULT_RESULTS_DIR +# __sphinx_doc_begin__ COMMON_CONFIG = { # Discount factor of the MDP "gamma": 0.99, - # Number of steps after which the rollout gets cut + # Number of steps after which the episode is forced to terminate "horizon": None, # Number of environments to evaluate vectorwise per worker. "num_envs_per_worker": 1, @@ -36,7 +38,7 @@ "batch_mode": "truncate_episodes", # Whether to use a background thread for sampling (slightly off-policy) "sample_async": False, - # Which observation filter to apply to the observation + # Element-wise observation filter, either "NoFilter" or "MeanStdFilter" "observation_filter": "NoFilter", # Whether to synchronize the statistics of remote filters. "synchronize_filters": True, @@ -50,14 +52,12 @@ # Environment name can also be passed via config "env": None, # Arguments to pass to model - "model": { - "use_lstm": False, - "max_seq_len": 20, - }, - # Arguments to pass to the rllib optimizer + "model": MODEL_DEFAULTS, + # Arguments to pass to the policy optimizer. These vary by optimizer. 
"optimizer": {}, # Configure TF for single-process operation by default "tf_session_args": { + # note: parallelism_threads is set to auto for the local evaluator "intra_op_parallelism_threads": 1, "inter_op_parallelism_threads": 1, "gpu_options": { @@ -88,6 +88,8 @@ }, } +# __sphinx_doc_end__ + def with_common_config(extra_config): """Returns the given config dict merged with common agent confs.""" diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index 5984e2e01882..27c6fe87928a 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -24,6 +24,7 @@ "eval_returns", "eval_lengths" ]) +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ "noise_stdev": 0.02, # std deviation of parameter noise "num_rollouts": 32, # number of perturbs to try @@ -34,9 +35,9 @@ "noise_size": 250000000, "eval_prob": 0.03, # probability of evaluating the parameter rewards "report_length": 10, # how many of the last rewards we average over - "env_config": {}, "offset": 0, }) +# __sphinx_doc_end__ @ray.remote diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index c2276d0a9a55..c9053ca8a00a 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -7,7 +7,7 @@ from ray.tune.trial import Resources APEX_DDPG_DEFAULT_CONFIG = merge_dicts( - DDPG_CONFIG, + DDPG_CONFIG, # see also the options in ddpg.py, which are also supported { "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index b475e297a247..c35fdaa71d17 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -13,6 +13,7 @@ "train_batch_size", "learning_starts" ] +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # === Model === # Hidden layer sizes of the policy network @@ -108,6 +109,8 @@ "min_iter_time_s": 1, }) +# __sphinx_doc_end__ + class DDPGAgent(DQNAgent): """DDPG implementation in TensorFlow.""" diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index e6058b41f9af..052d0fd3e957 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -6,8 +6,9 @@ from ray.rllib.utils import merge_dicts from ray.tune.trial import Resources +# __sphinx_doc_begin__ APEX_DEFAULT_CONFIG = merge_dicts( - DQN_CONFIG, + DQN_CONFIG, # see also the options in dqn.py, which are also supported { "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( @@ -31,6 +32,8 @@ }, ) +# __sphinx_doc_end__ + class ApexAgent(DQNAgent): """DQN variant that uses the Ape-X distributed policy optimizer. diff --git a/python/ray/rllib/agents/dqn/common/wrappers.py b/python/ray/rllib/agents/dqn/common/wrappers.py index eb6a6c0d5b5c..97f839abfddd 100644 --- a/python/ray/rllib/agents/dqn/common/wrappers.py +++ b/python/ray/rllib/agents/dqn/common/wrappers.py @@ -13,7 +13,7 @@ def wrap_dqn(env, options): # Override atari default to use the deepmind wrappers. # TODO(ekl) this logic should be pushed to the catalog. 
- if is_atari and "custom_preprocessor" not in options: + if is_atari and not options.get("custom_preprocessor"): return wrap_deepmind(env, dim=options.get("dim", 84)) return ModelCatalog.get_preprocessor_as_wrapper(env, options) diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index 25320fd6a982..f86b286ce5a1 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -20,6 +20,7 @@ "learning_starts" ] +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # === Model === # Number of atoms for representing the distribution of return. When @@ -116,6 +117,8 @@ "min_iter_time_s": 1, }) +# __sphinx_doc_end__ + class DQNAgent(Agent): """DQN implementation in TensorFlow.""" diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index 392f98f1d8f2..d5526c0da37f 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -24,6 +24,7 @@ "eval_returns", "eval_lengths" ]) +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ "l2_coeff": 0.005, "noise_stdev": 0.02, @@ -36,10 +37,8 @@ "observation_filter": "MeanStdFilter", "noise_size": 250000000, "report_length": 10, - "env": None, - "env_config": {}, - "model": {}, }) +# __sphinx_doc_end__ @ray.remote @@ -77,7 +76,8 @@ def __init__(self, self.env = env_creator(config["env_config"]) from ray.rllib import models - self.preprocessor = models.ModelCatalog.get_preprocessor(self.env) + self.preprocessor = models.ModelCatalog.get_preprocessor( + self.env, config["model"]) self.sess = utils.make_session(single_threaded=True) self.policy = policies.GenericPolicy( diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index 1ad2b673f429..a303643f55ee 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -23,6 +23,7 @@ "max_sample_requests_in_flight_per_worker", ] +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # V-trace params (see vtrace.py). "vtrace": True, @@ -63,15 +64,10 @@ # balancing the three losses "vf_loss_coeff": 0.5, "entropy_coeff": -0.01, - - # Model and preprocessor options. - "model": { - "use_lstm": False, - "max_seq_len": 20, - "dim": 84, - }, }) +# __sphinx_doc_end__ + class ImpalaAgent(Agent): """IMPALA implementation using DeepMind's V-trace.""" diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index e1766e7744f2..edc24ca1b05b 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -8,20 +8,16 @@ from ray.rllib.utils import merge_dicts from ray.tune.trial import Resources +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # No remote workers by default "num_workers": 0, # Learning rate "lr": 0.0004, - # Override model config - "model": { - # Use LSTM model. - "use_lstm": False, - # Max seq length for LSTM training. - "max_seq_len": 20, - }, }) +# __sphinx_doc_end__ + class PGAgent(Agent): """Simple policy gradient agent. 
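A note on the recurring ``options.get`` changes in this patch: once ``MODEL_DEFAULTS`` pre-populates every model key (often with ``None``), a membership test like ``"custom_preprocessor" not in options`` can no longer detect the unset case, so the checks are switched to truthiness. A tiny standalone illustration:

.. code-block:: python

    # With MODEL_DEFAULTS merged in, the key is always present but may be None.
    options = {"custom_preprocessor": None}

    print("custom_preprocessor" in options)          # True: presence check misfires
    print(bool(options.get("custom_preprocessor")))  # False: truthiness is correct
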
diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index d2a991929dfa..ea09dfe59448 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -8,6 +8,7 @@ from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer from ray.tune.trial import Resources +# __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # If true, use the Generalized Advantage Estimator (GAE) # with a value function, see https://arxiv.org/pdf/1506.02438.pdf. @@ -53,15 +54,10 @@ "observation_filter": "MeanStdFilter", # Use the sync samples optimizer instead of the multi-gpu one "simple_optimizer": False, - # Override model config - "model": { - # Whether to use LSTM model - "use_lstm": False, - # Max seq length for LSTM training. - "max_seq_len": 20, - }, }) +# __sphinx_doc_end__ + class PPOAgent(Agent): """Multi-GPU optimized implementation of PPO in TensorFlow.""" diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index ac276ee3a846..548b65806bad 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -187,7 +187,7 @@ def __init__(self, def wrap(env): return env # we can't auto-wrap these env types elif is_atari(self.env) and \ - "custom_preprocessor" not in model_config and \ + not model_config.get("custom_preprocessor") and \ preprocessor_pref == "deepmind": if clip_rewards is None: @@ -196,9 +196,8 @@ def wrap(env): def wrap(env): env = wrap_deepmind( env, - dim=model_config.get("dim", 84), - framestack=not model_config.get("use_lstm") - and not model_config.get("no_framestack")) + dim=model_config.get("dim"), + framestack=model_config.get("framestack")) if monitor_path: env = _monitor(env, monitor_path) return env diff --git a/python/ray/rllib/models/__init__.py b/python/ray/rllib/models/__init__.py index ddfdd16b8ba1..52e47e807b3f 100644 --- a/python/ray/rllib/models/__init__.py +++ b/python/ray/rllib/models/__init__.py @@ -1,4 +1,4 @@ -from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS from ray.rllib.models.action_dist import (ActionDistribution, Categorical, DiagGaussian, Deterministic) from ray.rllib.models.model import Model @@ -7,6 +7,14 @@ from ray.rllib.models.lstm import LSTM __all__ = [ - "ActionDistribution", "Categorical", "DiagGaussian", "Deterministic", - "ModelCatalog", "Model", "Preprocessor", "FullyConnectedNetwork", "LSTM" + "ActionDistribution", + "Categorical", + "DiagGaussian", + "Deterministic", + "ModelCatalog", + "Model", + "Preprocessor", + "FullyConnectedNetwork", + "LSTM", + "MODEL_DEFAULTS", ] diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 370429c43f3c..d2038f55f888 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -18,29 +18,52 @@ from ray.rllib.models.visionnet import VisionNetwork from ray.rllib.models.lstm import LSTM -MODEL_CONFIGS = [ +# __sphinx_doc_begin__ +MODEL_DEFAULTS = { # === Built-in options === # Filter config. 
List of [out_channels, kernel, stride] for each filter - "conv_filters", - "conv_activation", # Nonlinearity for built-in convnet - "fcnet_activation", # Nonlinearity for fully connected net (tanh, relu) - "fcnet_hiddens", # Number of hidden layers for fully connected net - "dim", # Dimension for ATARI - "grayscale", # Converts ATARI frame to 1 Channel Grayscale image - "zero_mean", # Changes frame to range from [-1, 1] if true - "extra_frameskip", # (int) for number of frames to skip - "free_log_std", # Documented in ray.rllib.models.Model - "channel_major", # Pytorch conv requires images to be channel-major - "squash_to_range", # Whether to squash the action output to space range - "use_lstm", # Whether to wrap the model with a LSTM - "max_seq_len", # Max seq len for training the LSTM, defaults to 20 - "lstm_cell_size", # Size of the LSTM cell + "conv_filters": None, + # Nonlinearity for built-in convnet + "conv_activation": "relu", + # Nonlinearity for fully connected net (tanh, relu) + "fcnet_activation": "tanh", + # Number of hidden layers for fully connected net + "fcnet_hiddens": [256, 256], + # For control envs, documented in ray.rllib.models.Model + "free_log_std": False, + # Whether to squash the action output to space range + "squash_to_range": False, + + # == LSTM == + # Whether to wrap the model with a LSTM + "use_lstm": False, + # Max seq len for training the LSTM, defaults to 20 + "max_seq_len": 20, + # Size of the LSTM cell + "lstm_cell_size": 256, + + # == Atari == + # Whether to enable framestack for Atari envs + "framestack": True, + # Final resized frame dimension + "dim": 84, + # Pytorch conv requires images to be channel-major + "channel_major": False, + # (deprecated) Converts ATARI frame to 1 Channel Grayscale image + "grayscale": False, + # (deprecated) Changes frame to range from [-1, 1] if true + "zero_mean": True, # === Options for custom models === - "custom_preprocessor", # Name of a custom preprocessor to use - "custom_model", # Name of a custom model to use - "custom_options", # Extra options to pass to the custom classes -] + # Name of a custom preprocessor to use + "custom_preprocessor": None, + # Name of a custom model to use + "custom_model": None, + # Extra options to pass to the custom classes + "custom_options": {}, +} + +# __sphinx_doc_end__ class ModelCatalog(object): @@ -71,10 +94,7 @@ def get_action_dist(action_space, config, dist_type=None): dist_dim (int): The size of the input vector to the distribution. """ - # TODO(ekl) are list spaces valid? - if isinstance(action_space, list): - action_space = gym.spaces.Tuple(action_space) - config = config or {} + config = config or MODEL_DEFAULTS if isinstance(action_space, gym.spaces.Box): if dist_type is None: dist = DiagGaussian @@ -82,7 +102,7 @@ def get_action_dist(action_space, config, dist_type=None): dist = squash_to_range(dist, action_space.low, action_space.high) return dist, action_space.shape[0] * 2 - elif dist_type == 'deterministic': + elif dist_type == "deterministic": return Deterministic, action_space.shape[0] elif isinstance(action_space, gym.spaces.Discrete): return Categorical, action_space.n @@ -154,6 +174,7 @@ def get_model(inputs, num_outputs, options, state_in=None, seq_lens=None): model (Model): Neural network model. 
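Because every model option above now carries an explicit default, a user config only needs to name the keys it changes. A sketch of such an override, assuming the usual deep merge of the ``model`` dict into ``MODEL_DEFAULTS``:

.. code-block:: python

    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()

    # Only these three keys are overridden; all other entries keep the
    # MODEL_DEFAULTS values listed above.
    agent = PPOAgent(
        env="CartPole-v0",
        config={
            "model": {
                "fcnet_hiddens": [64, 64],
                "use_lstm": True,
                "lstm_cell_size": 128,
            },
        })
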
""" + options = options or MODEL_DEFAULTS model = ModelCatalog._get_model(inputs, num_outputs, options, state_in, seq_lens) @@ -165,7 +186,7 @@ def get_model(inputs, num_outputs, options, state_in=None, seq_lens=None): @staticmethod def _get_model(inputs, num_outputs, options, state_in, seq_lens): - if "custom_model" in options: + if options.get("custom_model"): model = options["custom_model"] print("Using custom model {}".format(model)) return _global_registry.get(RLLIB_MODEL, model)( @@ -183,7 +204,7 @@ def _get_model(inputs, num_outputs, options, state_in, seq_lens): return FullyConnectedNetwork(inputs, num_outputs, options) @staticmethod - def get_torch_model(input_shape, num_outputs, options={}): + def get_torch_model(input_shape, num_outputs, options=None): """Returns a PyTorch suitable model. This is currently only supported in A3C. @@ -200,7 +221,8 @@ def get_torch_model(input_shape, num_outputs, options={}): from ray.rllib.models.pytorch.visionnet import (VisionNetwork as PyTorchVisionNet) - if "custom_model" in options: + options = options or MODEL_DEFAULTS + if options.get("custom_model"): model = options["custom_model"] print("Using custom torch model {}".format(model)) return _global_registry.get(RLLIB_MODEL, model)( @@ -217,7 +239,7 @@ def get_torch_model(input_shape, num_outputs, options={}): return PyTorchFCNet(input_shape[0], num_outputs, options) @staticmethod - def get_preprocessor(env, options={}): + def get_preprocessor(env, options=None): """Returns a suitable processor for the given environment. Args: @@ -227,12 +249,13 @@ def get_preprocessor(env, options={}): Returns: preprocessor (Preprocessor): Preprocessor for the env observations. """ + options = options or MODEL_DEFAULTS for k in options.keys(): - if k not in MODEL_CONFIGS: + if k not in MODEL_DEFAULTS: raise Exception("Unknown config key `{}`, all keys: {}".format( - k, MODEL_CONFIGS)) + k, list(MODEL_DEFAULTS))) - if "custom_preprocessor" in options: + if options.get("custom_preprocessor"): preprocessor = options["custom_preprocessor"] print("Using custom preprocessor {}".format(preprocessor)) return _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)( @@ -242,7 +265,7 @@ def get_preprocessor(env, options={}): return preprocessor(env.observation_space, options) @staticmethod - def get_preprocessor_as_wrapper(env, options={}): + def get_preprocessor_as_wrapper(env, options=None): """Returns a preprocessor as a gym observation wrapper. Args: @@ -253,6 +276,7 @@ def get_preprocessor_as_wrapper(env, options={}): wrapper (gym.ObservationWrapper): Preprocessor in wrapper form. 
""" + options = options or MODEL_DEFAULTS preprocessor = ModelCatalog.get_preprocessor(env, options) return _RLlibPreprocessorWrapper(env, preprocessor) diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index 11aee2c0da8f..e703fb0a080d 100644 --- a/python/ray/rllib/models/fcnet.py +++ b/python/ray/rllib/models/fcnet.py @@ -13,8 +13,8 @@ class FullyConnectedNetwork(Model): """Generic fully connected network.""" def _build_layers(self, inputs, num_outputs, options): - hiddens = options.get("fcnet_hiddens", [256, 256]) - activation = get_activation_fn(options.get("fcnet_activation", "tanh")) + hiddens = options.get("fcnet_hiddens") + activation = get_activation_fn(options.get("fcnet_activation")) with tf.name_scope("fc_net"): i = 1 diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index 581569f0eff0..b8dea3ede95c 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -135,7 +135,7 @@ class LSTM(Model): """ def _build_layers(self, inputs, num_outputs, options): - cell_size = options.get("lstm_cell_size", 256) + cell_size = options.get("lstm_cell_size") last_layer = add_time_dimension(inputs, self.seq_lens) # Setup the LSTM cell diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index 00d6575e6210..168f29c74625 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -55,7 +55,7 @@ def __init__(self, self.seq_lens = tf.placeholder( dtype=tf.int32, shape=[None], name="seq_lens") - if options.get("free_log_std", False): + if options.get("free_log_std"): assert num_outputs % 2 == 0 num_outputs = num_outputs // 2 self.outputs, self.last_layer = self._build_layers( diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py index c400dd9805d3..cd72d1922dcb 100644 --- a/python/ray/rllib/models/preprocessors.py +++ b/python/ray/rllib/models/preprocessors.py @@ -30,12 +30,18 @@ def transform(self, observation): raise NotImplementedError -class AtariPixelPreprocessor(Preprocessor): +class GenericPixelPreprocessor(Preprocessor): + """Generic image preprocessor. + + Note: for Atari games, use config {"preprocessor_pref": "deepmind"} + instead for deepmind-style Atari preprocessing. 
+ """ + def _init(self): - self._grayscale = self._options.get("grayscale", False) - self._zero_mean = self._options.get("zero_mean", True) - self._dim = self._options.get("dim", 84) - self._channel_major = self._options.get("channel_major", False) + self._grayscale = self._options.get("grayscale") + self._zero_mean = self._options.get("zero_mean") + self._dim = self._options.get("dim") + self._channel_major = self._options.get("channel_major") if self._grayscale: self.shape = (self._dim, self._dim, 1) else: @@ -130,7 +136,7 @@ def get_preprocessor(space): if isinstance(space, gym.spaces.Discrete): preprocessor = OneHotPreprocessor elif obs_shape == ATARI_OBS_SHAPE: - preprocessor = AtariPixelPreprocessor + preprocessor = GenericPixelPreprocessor elif obs_shape == ATARI_RAM_OBS_SHAPE: preprocessor = AtariRamPreprocessor elif isinstance(space, gym.spaces.Tuple): diff --git a/python/ray/rllib/models/pytorch/visionnet.py b/python/ray/rllib/models/pytorch/visionnet.py index 94ac8291d79a..e54c51897f2c 100644 --- a/python/ray/rllib/models/pytorch/visionnet.py +++ b/python/ray/rllib/models/pytorch/visionnet.py @@ -18,11 +18,11 @@ def _build_layers(self, inputs, num_outputs, options): inputs (tuple): (channels, rows/height, cols/width) num_outputs (int): logits size """ - filters = options.get("conv_filters", [ + filters = options.get("conv_filters") or [ [16, [8, 8], 4], [32, [4, 4], 2], [512, [11, 11], 1], - ]) + ] layers = [] in_channels, in_size = inputs[0], inputs[1:] diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index 805d2e9e5ebe..902addb6a31c 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -17,7 +17,7 @@ def _build_layers(self, inputs, num_outputs, options): if not filters: filters = get_filter_config(options) - activation = get_activation_fn(options.get("conv_activation", "relu")) + activation = get_activation_fn(options.get("conv_activation")) with tf.name_scope("vision_net"): for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1): @@ -57,7 +57,7 @@ def get_filter_config(options): [32, [4, 4], 2], [256, [11, 11], 1], ] - dim = options.get("dim", 84) + dim = options.get("dim") if dim == 84: return filters_84x84 elif dim == 42: diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index 16bdd485f9a9..20ef872ae86e 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -19,10 +19,6 @@ Box(0.0, 1.0, (5, ), dtype=np.float32), Box(0.0, 1.0, (5, ), dtype=np.float32) ]), - "implicit_tuple": [ - Box(0.0, 1.0, (5, ), dtype=np.float32), - Box(0.0, 1.0, (5, ), dtype=np.float32) - ], "mixed_tuple": Tuple( [Discrete(2), Discrete(3), From a41bbc10ef5c3e0a8a3e8a36dbbe18171b556ada Mon Sep 17 00:00:00 2001 From: Peter Schafhalter Date: Tue, 16 Oct 2018 22:48:30 -0700 Subject: [PATCH 032/215] Add password authentication to Redis ports (#2952) * Implement Redis authentication * Throw exception for legacy Ray * Add test * Formatting * Fix bugs in CLI * Fix bugs in Raylet * Move default password to constants.h * Use pytest.fixture * Fix bug * Authenticate using formatted strings * Add missing passwords * Add test * Improve authentication of async contexts * Disable Redis authentication for credis * Update test for credis * Fix rebase artifacts * Fix formatting * Add workaround for issue #3045 * Increase timeout for test * Improve C++ readability * Fixes for CLI * Add security docs * Address 
comments * Address comments * Address comments * Use ray.get * Fix lint
---
 .travis.yml                                    |   2 +
 doc/source/index.rst                           |   1 +
 doc/source/security.rst                        |  55 +++++
 python/ray/experimental/state.py               |   9 +-
 python/ray/log_monitor.py                      |  21 +-
 python/ray/monitor.py                          |  27 ++-
 python/ray/scripts/scripts.py                  |  39 +++-
 python/ray/services.py                         | 188 +++++++++++++-----
 python/ray/test/test_ray_init.py               |  65 ++++++
 python/ray/worker.py                           |  47 ++++-
 python/ray/workers/default_worker.py           |  12 +-
 src/ray/gcs/client.cc                          |  30 ++-
 src/ray/gcs/client.h                           |   7 +-
 src/ray/gcs/redis_context.cc                   |  31 ++-
 src/ray/gcs/redis_context.h                    |   3 +-
 src/ray/raylet/main.cc                         |  10 +-
 src/ray/raylet/monitor.cc                      |   4 +-
 src/ray/raylet/monitor.h                       |   2 +-
 src/ray/raylet/monitor_main.cc                 |   5 +-
 .../raylet/object_manager_integration_test.cc  |   4 +-
 src/ray/raylet/raylet.cc                       |  10 +-
 src/ray/raylet/raylet.h                        |   5 +-
 22 files changed, 462 insertions(+), 115 deletions(-)
 create mode 100644 doc/source/security.rst
 create mode 100644 python/ray/test/test_ray_init.py

diff --git a/.travis.yml b/.travis.yml
index 35743b764ea1..8416fa138d8c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -133,6 +133,7 @@ matrix:
   - python -m pytest -v python/ray/test/test_global_state.py
   - python -m pytest -v python/ray/test/test_queue.py
+  - python -m pytest -v python/ray/test/test_ray_init.py
   - python -m pytest -v test/xray_test.py
   - python -m pytest -v test/runtest.py
@@ -208,6 +209,7 @@ script:
   - python -m pytest -v python/ray/test/test_global_state.py
   - python -m pytest -v python/ray/test/test_queue.py
+  - python -m pytest -v python/ray/test/test_ray_init.py
  - python -m pytest -v test/xray_test.py
  - python -m pytest -v test/runtest.py

diff --git a/doc/source/index.rst b/doc/source/index.rst
index d8870bbaa054..5268054b77eb 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -135,6 +135,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin
    troubleshooting.rst
    user-profiling.rst
+   security.rst
    development.rst
    profiling.rst
    contact.rst

diff --git a/doc/source/security.rst b/doc/source/security.rst
new file mode 100644
index 000000000000..6b636c66858e
--- /dev/null
+++ b/doc/source/security.rst
@@ -0,0 +1,55 @@
+Security
+========
+
+This document describes best security practices for using Ray.
+
+Intended Use and Threat Model
+-----------------------------
+
+Ray instances should run on a secure network without public-facing ports.
+The most common threat for Ray instances is unauthorized access to Redis,
+which can be exploited to gain shell access and run arbitrary code.
+The best fix is to run Ray instances on a secure, trusted network.
+
+Running Ray on a secured network is not always feasible, so Ray
+provides some basic security features:
+
+
+Redis Port Authentication
+-------------------------
+
+To prevent exploits via unauthorized Redis access, Ray provides the option to
+password-protect Redis ports. While this is not a replacement for running Ray
+behind a firewall, this feature is useful for instances exposed to the internet
+where configuring a firewall is not possible. Because Redis is very fast at
+serving queries, an attacker can try short passwords at a high rate, so the
+chosen password should be long.
+
+Redis authentication is only supported on the raylet code path.
+
+To add authentication via the Python API, start Ray using:
+
+.. code-block:: python
+
+  ray.init(redis_password="password")
+
+To add authentication via the CLI, or connect to an existing Ray instance with
+password-protected Redis ports:
+
+.. 
code-block:: bash + + ray start [--head] --redis-password="password" + +While Redis port authentication may protect against external attackers, +Ray does not encrypt traffic between nodes so man-in-the-middle attacks are +possible for clusters on untrusted networks. + +Cloud Security +-------------- + +Launching Ray clusters on AWS or GCP using the ``ray up`` command +automatically configures security groups that prevent external Redis access. + +References +---------- + +- The `Redis security documentation ` diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index eab71993c60c..906d650d2866 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -78,6 +78,7 @@ def _check_connected(self): def _initialize_global_state(self, redis_ip_address, redis_port, + redis_password=None, timeout=20): """Initialize the GlobalState object by connecting to Redis. @@ -89,9 +90,10 @@ def _initialize_global_state(self, redis_ip_address: The IP address of the node that the Redis server lives on. redis_port: The port that the Redis server is listening on. + redis_password: The password of the redis server. """ self.redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=redis_password) start_time = time.time() @@ -143,7 +145,10 @@ def _initialize_global_state(self, for ip_address_port in ip_address_ports: shard_address, shard_port = ip_address_port.split(b":") self.redis_clients.append( - redis.StrictRedis(host=shard_address, port=shard_port)) + redis.StrictRedis( + host=shard_address, + port=shard_port, + password=redis_password)) def _execute_command(self, key, *args): """Execute a Redis command on the appropriate Redis shard based on key. diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index 13a62a98a322..2cd6fc40a0f5 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -35,11 +35,15 @@ class LogMonitor(object): handle for that file. """ - def __init__(self, redis_ip_address, redis_port, node_ip_address): + def __init__(self, + redis_ip_address, + redis_port, + node_ip_address, + redis_password=None): """Initialize the log monitor object.""" self.node_ip_address = node_ip_address self.redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=redis_password) self.log_files = {} self.log_file_handles = {} self.files_to_ignore = set() @@ -130,6 +134,12 @@ def run(self): required=True, type=str, help="The IP address of the node this process is on.") + parser.add_argument( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") parser.add_argument( "--logging-level", required=False, @@ -151,6 +161,9 @@ def run(self): redis_ip_address = get_ip_address(args.redis_address) redis_port = get_port(args.redis_address) - log_monitor = LogMonitor(redis_ip_address, redis_port, - args.node_ip_address) + log_monitor = LogMonitor( + redis_ip_address, + redis_port, + args.node_ip_address, + redis_password=args.redis_password) log_monitor.run() diff --git a/python/ray/monitor.py b/python/ray/monitor.py index e5c2279b7233..6212de23e694 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -70,13 +70,18 @@ class Monitor(object): managers that were up at one point and have died since then. 
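The monitor above illustrates how helper processes attach to an authenticated cluster. The same connection can be made by hand through the global state API (an internal, underscore-prefixed interface, so subject to change); the address and password here are placeholders:

.. code-block:: python

    import ray.experimental.state as state

    # Mirrors the connection pattern used by monitor.py above.
    global_state = state.GlobalState()
    global_state._initialize_global_state(
        "127.0.0.1", 6379, redis_password="password")
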
""" - def __init__(self, redis_address, redis_port, autoscaling_config): + def __init__(self, + redis_address, + redis_port, + autoscaling_config, + redis_password=None): # Initialize the Redis clients. self.state = ray.experimental.state.GlobalState() - self.state._initialize_global_state(redis_address, redis_port) + self.state._initialize_global_state( + redis_address, redis_port, redis_password=redis_password) self.use_raylet = self.state.use_raylet self.redis = redis.StrictRedis( - host=redis_address, port=redis_port, db=0) + host=redis_address, port=redis_port, db=0, password=redis_password) # Setup subscriptions to the primary Redis server and the Redis shards. self.primary_subscribe_client = self.redis.pubsub( ignore_subscribe_messages=True) @@ -118,7 +123,9 @@ def __init__(self, redis_address, redis_port, autoscaling_config): else: addr_port = addr_port[0].split(b":") self.redis_shard = redis.StrictRedis( - host=addr_port[0], port=addr_port[1]) + host=addr_port[0], + port=addr_port[1], + password=redis_password) try: self.redis_shard.execute_command("HEAD.FLUSH 0") except redis.exceptions.ResponseError as e: @@ -773,6 +780,12 @@ def run(self): required=False, type=str, help="the path to the autoscaling config file") + parser.add_argument( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") parser.add_argument( "--logging-level", required=False, @@ -798,7 +811,11 @@ def run(self): else: autoscaling_config = None - monitor = Monitor(redis_ip_address, redis_port, autoscaling_config) + monitor = Monitor( + redis_ip_address, + redis_port, + autoscaling_config, + redis_password=args.redis_password) try: monitor.run() diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index f8e0c5484f75..dfbbee272a27 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -89,6 +89,11 @@ def cli(logging_level, logging_format): type=int, help=("If provided, attempt to configure Redis with this " "maximum number of clients.")) +@click.option( + "--redis-password", + required=False, + type=str, + help="If provided, secure Redis ports with this password") @click.option( "--redis-shard-ports", required=False, @@ -190,10 +195,11 @@ def cli(logging_level, logging_format): default=None, help="manually specify the root temporary dir of the Ray process") def start(node_ip_address, redis_address, redis_port, num_redis_shards, - redis_max_clients, redis_shard_ports, object_manager_port, - object_store_memory, num_workers, num_cpus, num_gpus, resources, - head, no_ui, block, plasma_directory, huge_pages, autoscaling_config, - use_raylet, no_redirect_worker_output, no_redirect_output, + redis_max_clients, redis_password, redis_shard_ports, + object_manager_port, object_store_memory, num_workers, num_cpus, + num_gpus, resources, head, no_ui, block, plasma_directory, + huge_pages, autoscaling_config, use_raylet, + no_redirect_worker_output, no_redirect_output, plasma_store_socket_name, raylet_socket_name, temp_dir): # Convert hostnames to numerical IP address. if node_ip_address is not None: @@ -205,6 +211,11 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, # This environment variable is used in our testing setup. logger.info("Detected environment variable 'RAY_USE_XRAY'.") use_raylet = True + if not use_raylet and redis_password is not None: + raise Exception("Setting the 'redis-password' argument is not " + "supported in legacy Ray. 
To run Ray with " + "password-protected Redis ports, pass " + "the '--use-raylet' flag.") try: resources = json.loads(resources) @@ -269,6 +280,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, redis_protected_mode=False, + redis_password=redis_password, include_webui=(not no_ui), plasma_directory=plasma_directory, huge_pages=huge_pages, @@ -281,16 +293,20 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, logger.info( "\nStarted Ray on this node. You can add additional nodes to " "the cluster by calling\n\n" - " ray start --redis-address {}\n\n" + " ray start --redis-address {}{}{}\n\n" "from the node you wish to add. You can connect a driver to the " "cluster from Python by running\n\n" " import ray\n" - " ray.init(redis_address=\"{}\")\n\n" + " ray.init(redis_address=\"{}{}{}\")\n\n" "If you have trouble connecting from a different machine, check " "that your firewall is configured properly. If you wish to " "terminate the processes that have been started, run\n\n" - " ray stop".format(address_info["redis_address"], - address_info["redis_address"])) + " ray stop".format( + address_info["redis_address"], " --redis-password " + if redis_password else "", redis_password if redis_password + else "", address_info["redis_address"], "\", redis_password=\"" + if redis_password else "", redis_password + if redis_password else "")) else: # Start Ray on a non-head node. if redis_port is not None: @@ -315,10 +331,12 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, # Wait for the Redis server to be started. And throw an exception if we # can't connect to it. - services.wait_for_redis_to_start(redis_ip_address, int(redis_port)) + services.wait_for_redis_to_start( + redis_ip_address, int(redis_port), password=redis_password) # Create a Redis client. - redis_client = services.create_redis_client(redis_address) + redis_client = services.create_redis_client( + redis_address, password=redis_password) # Check that the verion information on this node matches the version # information that the cluster was started with. @@ -339,6 +357,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, object_manager_ports=[object_manager_port], num_workers=num_workers, object_store_memory=object_store_memory, + redis_password=redis_password, cleanup=False, redirect_worker_output=not no_redirect_worker_output, redirect_output=not no_redirect_output, diff --git a/python/ray/services.py b/python/ray/services.py index 9b1592e7bc29..e572b657f277 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -261,7 +261,10 @@ def get_node_ip_address(address="8.8.8.8:53"): return node_ip_address -def record_log_files_in_redis(redis_address, node_ip_address, log_files): +def record_log_files_in_redis(redis_address, + node_ip_address, + log_files, + password=None): """Record in Redis that a new log file has been created. This is used so that each log monitor can check Redis and figure out which @@ -273,23 +276,24 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files): on. log_files: A list of file handles for the log files. If one of the file handles is None, we ignore it. + password (str): The password of the redis server. 
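The status message built above threads the password into both the CLI hint and the Python hint. The conditional ``str.format`` pattern is easier to read in isolation (all values hypothetical):

.. code-block:: python

    redis_address = "192.168.0.1:6379"  # hypothetical head node
    redis_password = "password"         # None would disable both inserts

    message = "ray start --redis-address {}{}{}".format(
        redis_address,
        " --redis-password " if redis_password else "",
        redis_password if redis_password else "")
    print(message)
    # ray start --redis-address 192.168.0.1:6379 --redis-password password
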
""" for log_file in log_files: if log_file is not None: redis_ip_address, redis_port = redis_address.split(":") redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=password) # The name of the key storing the list of log filenames for this IP # address. log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address) redis_client.rpush(log_file_list_key, log_file.name) -def create_redis_client(redis_address): +def create_redis_client(redis_address, password=None): """Create a Redis client. Args: - The IP address and port of the Redis server. + The IP address, port, and password of the Redis server. Returns: A Redis client. @@ -297,10 +301,14 @@ def create_redis_client(redis_address): redis_ip_address, redis_port = redis_address.split(":") # For this command to work, some other client (on the same machine # as Redis) must have run "CONFIG SET protected-mode no". - return redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) + return redis.StrictRedis( + host=redis_ip_address, port=int(redis_port), password=password) -def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): +def wait_for_redis_to_start(redis_ip_address, + redis_port, + password=None, + num_retries=5): """Wait for a Redis server to be available. This is accomplished by creating a Redis client and sending a random @@ -309,13 +317,15 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): Args: redis_ip_address (str): The IP address of the redis server. redis_port (int): The port of the redis server. + password (str): The password of the redis server. num_retries (int): The number of times to try connecting with redis. The client will sleep for one second between attempts. Raises: Exception: An exception is raised if we could not connect with Redis. """ - redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) + redis_client = redis.StrictRedis( + host=redis_ip_address, port=redis_port, password=password) # Wait for the Redis server to start. counter = 0 while counter < num_retries: @@ -425,6 +435,7 @@ def start_redis(node_ip_address, redirect_worker_output=False, cleanup=True, protected_mode=False, + password=None, use_credis=None): """Start the Redis global state store. @@ -451,6 +462,8 @@ def start_redis(node_ip_address, then all Redis processes started by this method will be killed by services.cleanup() when the Python process that imported services exits. + password (str): Prevents external clients without the password + from connecting to Redis if provided. use_credis: If True, additionally load the chain-replicated libraries into the redis servers. Defaults to None, which means its value is set by the presence of "RAY_USE_NEW_GCS" in os.environ. @@ -469,6 +482,13 @@ def start_redis(node_ip_address, if use_credis is None: use_credis = ("RAY_USE_NEW_GCS" in os.environ) + if use_credis and password is not None: + # TODO(pschafhalter) remove this once credis supports + # authenticating Redis ports + raise Exception("Setting the `redis_password` argument is not " + "supported in credis. 
To run Ray with " + "password-protected Redis ports, ensure that " + "the environment variable `RAY_USE_NEW_GCS=off`.") if not use_credis: assigned_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -477,7 +497,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode) + protected_mode=protected_mode, + password=password) else: assigned_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -491,20 +512,23 @@ def start_redis(node_ip_address, # It is important to load the credis module BEFORE the ray module, # as the latter contains an extern declaration that the former # supplies. - modules=[CREDIS_MASTER_MODULE, REDIS_MODULE]) + modules=[CREDIS_MASTER_MODULE, REDIS_MODULE], + password=password) if port is not None: assert assigned_port == port port = assigned_port redis_address = address(node_ip_address, port) - redis_client = redis.StrictRedis(host=node_ip_address, port=port) + redis_client = redis.StrictRedis( + host=node_ip_address, port=port, password=password) # Store whether we're using the raylet code path or not. redis_client.set("UseRaylet", 1 if use_raylet else 0) # Register the number of Redis shards in the primary shard, so that clients # know how many redis shards to expect under RedisShards. - primary_redis_client = redis.StrictRedis(host=node_ip_address, port=port) + primary_redis_client = redis.StrictRedis( + host=node_ip_address, port=port, password=password) primary_redis_client.set("NumRedisShards", str(num_redis_shards)) # Put the redirect_worker_output bool in the Redis shard so that workers @@ -529,7 +553,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode) + protected_mode=protected_mode, + password=password) else: assert num_redis_shards == 1, \ "For now, RAY_USE_NEW_GCS supports 1 shard, and credis "\ @@ -542,6 +567,7 @@ def start_redis(node_ip_address, stderr_file=redis_stderr_file, cleanup=cleanup, protected_mode=protected_mode, + password=password, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray # module, as the latter contains an extern declaration that the @@ -557,7 +583,7 @@ def start_redis(node_ip_address, if use_credis: shard_client = redis.StrictRedis( - host=node_ip_address, port=redis_shard_port) + host=node_ip_address, port=redis_shard_port, password=password) # Configure the chain state. primary_redis_client.execute_command("MASTER.ADD", node_ip_address, redis_shard_port) @@ -591,6 +617,7 @@ def _start_redis_instance(node_ip_address="127.0.0.1", stderr_file=None, cleanup=True, protected_mode=False, + password=None, executable=REDIS_EXECUTABLE, modules=None): """Start a single Redis server. @@ -614,6 +641,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1", mode. This will prevent clients on other machines from connecting and is only used when the Redis servers are started via ray.init() as opposed to ray start. + password (str): Prevents external clients without the password + from connecting to Redis if provided. executable (str): Full path tho the redis-server executable. modules (list of str): A list of pathnames, pointing to the redis module(s) that will be loaded in this redis server. 
If None, load @@ -654,6 +683,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1", command = [executable] if protected_mode: command += [redis_config_filename] + if password: + command += ["--requirepass", password] command += ( ["--port", str(port), "--loglevel", "warning"] + load_module_args) @@ -672,9 +703,10 @@ def _start_redis_instance(node_ip_address="127.0.0.1", stdout_file.name, stderr_file.name)) # Create a Redis client just for configuring Redis. - redis_client = redis.StrictRedis(host="127.0.0.1", port=port) + redis_client = redis.StrictRedis( + host="127.0.0.1", port=port, password=password) # Wait for the Redis server to start. - wait_for_redis_to_start("127.0.0.1", port) + wait_for_redis_to_start("127.0.0.1", port, password=password) # Configure Redis to generate keyspace notifications. TODO(rkn): Change # this to only generate notifications for the export keys. redis_client.config_set("notify-keyspace-events", "Kl") @@ -719,8 +751,9 @@ def _start_redis_instance(node_ip_address="127.0.0.1", redis_client.set("redis_start_time", time.time()) # Record the log files in Redis. record_log_files_in_redis( - address(node_ip_address, port), node_ip_address, - [stdout_file, stderr_file]) + address(node_ip_address, port), + node_ip_address, [stdout_file, stderr_file], + password=password) return port, p @@ -728,7 +761,8 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None, stderr_file=None, - cleanup=cleanup): + cleanup=cleanup, + redis_password=None): """Start a log monitor process. Args: @@ -742,27 +776,31 @@ def start_log_monitor(redis_address, cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by services.cleanup() when the Python process that imported services exits. + redis_password (str): The password of the redis server. """ log_monitor_filepath = os.path.join( os.path.dirname(os.path.abspath(__file__)), "log_monitor.py") - p = subprocess.Popen( - [ - sys.executable, "-u", log_monitor_filepath, "--redis-address", - redis_address, "--node-ip-address", node_ip_address - ], - stdout=stdout_file, - stderr=stderr_file) + command = [ + sys.executable, "-u", log_monitor_filepath, "--redis-address", + redis_address, "--node-ip-address", node_ip_address + ] + if redis_password: + command += ["--redis-password", redis_password] + p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_LOG_MONITOR].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) def start_global_scheduler(redis_address, node_ip_address, stdout_file=None, stderr_file=None, - cleanup=True): + cleanup=True, + redis_password=None): """Start a global scheduler process. Args: @@ -776,6 +814,7 @@ def start_global_scheduler(redis_address, cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by services.cleanup() when the Python process that imported services exits. + redis_password (str): The password of the redis server. 
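Since ``_start_redis_instance`` forwards the password to ``redis-server`` via ``--requirepass``, unauthenticated clients are rejected per command rather than at connect time. A small ``redis-py`` sketch of that behavior, assuming a server already running with ``--requirepass password`` on the default port:

.. code-block:: python

    import redis

    client = redis.StrictRedis(host="127.0.0.1", port=6379)
    try:
        client.ping()  # rejected: no password was supplied
    except redis.ResponseError as e:
        print("unauthenticated ping failed:", e)

    client = redis.StrictRedis(
        host="127.0.0.1", port=6379, password="password")
    assert client.ping()  # authenticated clients work normally
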
""" p = global_scheduler.start_global_scheduler( redis_address, @@ -784,8 +823,10 @@ def start_global_scheduler(redis_address, stderr_file=stderr_file) if cleanup: all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): @@ -911,7 +952,8 @@ def start_local_scheduler(redis_address, stderr_file=None, cleanup=True, resources=None, - num_workers=0): + num_workers=0, + redis_password=None): """Start a local scheduler process. Args: @@ -935,6 +977,7 @@ def start_local_scheduler(redis_address, quantity of that resource. num_workers (int): The number of workers that the local scheduler should start. + redis_password (str): The password of the redis server. Return: The name of the local scheduler socket. @@ -957,8 +1000,10 @@ def start_local_scheduler(redis_address, num_workers=num_workers) if cleanup: all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) return local_scheduler_name @@ -973,7 +1018,8 @@ def start_raylet(redis_address, use_profiler=False, stdout_file=None, stderr_file=None, - cleanup=True): + cleanup=True, + redis_password=None): """Start a raylet, which is a combined local scheduler and object manager. Args: @@ -996,6 +1042,7 @@ def start_raylet(redis_address, cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. + redis_password (str): The password of the redis server. Returns: The raylet socket name. @@ -1029,6 +1076,8 @@ def start_raylet(redis_address, sys.executable, worker_path, node_ip_address, plasma_store_name, raylet_name, redis_address, get_temp_root())) + if redis_password: + start_worker_command += " --redis-password {}".format(redis_password) command = [ RAYLET_EXECUTABLE, @@ -1042,6 +1091,7 @@ def start_raylet(redis_address, resource_argument, start_worker_command, "", # Worker command for Java, not needed for Python. + redis_password or "", ] if use_valgrind: @@ -1063,8 +1113,10 @@ def start_raylet(redis_address, if cleanup: all_processes[PROCESS_TYPE_RAYLET].append(pid) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) return raylet_name @@ -1081,7 +1133,8 @@ def start_plasma_store(node_ip_address, plasma_directory=None, huge_pages=False, use_raylet=False, - plasma_store_socket_name=None): + plasma_store_socket_name=None, + redis_password=None): """This method starts an object store process. Args: @@ -1111,6 +1164,7 @@ def start_plasma_store(node_ip_address, Store with hugetlbfs support. Requires plasma_directory. use_raylet: True if the new raylet code path should be used. This is not supported yet. + redis_password (str): The password of the redis server. 
Return: A tuple of the Plasma store socket name, the Plasma manager socket @@ -1186,8 +1240,10 @@ def start_plasma_store(node_ip_address, if cleanup: all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1) - record_log_files_in_redis(redis_address, node_ip_address, - [store_stdout_file, store_stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [store_stdout_file, store_stderr_file], + password=redis_password) if not use_raylet: if cleanup: all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2) @@ -1248,7 +1304,8 @@ def start_monitor(redis_address, stdout_file=None, stderr_file=None, cleanup=True, - autoscaling_config=None): + autoscaling_config=None, + redis_password=None): """Run a process to monitor the other processes. Args: @@ -1264,6 +1321,7 @@ def start_monitor(redis_address, Python process that imported services exits. This is True by default. autoscaling_config: path to autoscaling config file. + redis_password (str): The password of the redis server. """ monitor_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "monitor.py") @@ -1273,17 +1331,22 @@ def start_monitor(redis_address, ] if autoscaling_config: command.append("--autoscaling-config=" + str(autoscaling_config)) + if redis_password: + command.append("--redis-password=" + redis_password) p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_MONITOR].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) def start_raylet_monitor(redis_address, stdout_file=None, stderr_file=None, - cleanup=True): + cleanup=True, + redis_password=None): """Run a process to monitor the other processes. Args: @@ -1296,8 +1359,10 @@ def start_raylet_monitor(redis_address, then this process will be killed by services.cleanup() when the Python process that imported services exits. This is True by default. + redis_password (str): The password of the redis server. """ gcs_ip_address, gcs_port = redis_address.split(":") + redis_password = redis_password or "" command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port] p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: @@ -1314,6 +1379,7 @@ def start_ray_processes(address_info=None, num_redis_shards=1, redis_max_clients=None, redis_protected_mode=False, + redis_password=None, worker_path=None, cleanup=True, redirect_worker_output=False, @@ -1359,6 +1425,8 @@ def start_ray_processes(address_info=None, redis_protected_mode: True if we should start Redis in protected mode. This will prevent clients from other machines from connecting and is only done when Redis is started via ray.init(). + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. worker_path (str): The path of the source code that will be run by the worker. 
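The optional-flag pattern used by ``start_monitor`` above leaves the default command line untouched when no password is configured; in isolation it looks like this:

.. code-block:: python

    import sys

    redis_password = "password"  # hypothetical; often None
    command = [sys.executable, "-u", "monitor.py",
               "--redis-address=127.0.0.1:6379"]
    if redis_password:
        command.append("--redis-password=" + redis_password)
    print(command)
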
cleanup (bool): If cleanup is true, then the processes started here @@ -1444,7 +1512,8 @@ def start_ray_processes(address_info=None, redirect_output=True, redirect_worker_output=redirect_worker_output, cleanup=cleanup, - protected_mode=redis_protected_mode) + protected_mode=redis_protected_mode, + password=redis_password) address_info["redis_address"] = redis_address time.sleep(0.1) @@ -1457,18 +1526,20 @@ def start_ray_processes(address_info=None, stdout_file=monitor_stdout_file, stderr_file=monitor_stderr_file, cleanup=cleanup, - autoscaling_config=autoscaling_config) + autoscaling_config=autoscaling_config, + redis_password=redis_password) if use_raylet: start_raylet_monitor( redis_address, stdout_file=monitor_stdout_file, stderr_file=monitor_stderr_file, - cleanup=cleanup) + cleanup=cleanup, + redis_password=redis_password) if redis_shards == []: # Get redis shards from primary redis instance. redis_ip_address, redis_port = redis_address.split(":") redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=redis_password) redis_shards = redis_client.lrange("RedisShards", start=0, end=-1) redis_shards = [ray.utils.decode(shard) for shard in redis_shards] address_info["redis_shards"] = redis_shards @@ -1482,7 +1553,8 @@ def start_ray_processes(address_info=None, node_ip_address, stdout_file=log_monitor_stdout_file, stderr_file=log_monitor_stderr_file, - cleanup=cleanup) + cleanup=cleanup, + redis_password=redis_password) # Start the global scheduler, if necessary. if include_global_scheduler and not use_raylet: @@ -1493,7 +1565,8 @@ def start_ray_processes(address_info=None, node_ip_address, stdout_file=global_scheduler_stdout_file, stderr_file=global_scheduler_stderr_file, - cleanup=cleanup) + cleanup=cleanup, + redis_password=redis_password) # Initialize with existing services. if "object_store_addresses" not in address_info: @@ -1537,7 +1610,8 @@ def start_ray_processes(address_info=None, plasma_directory=plasma_directory, huge_pages=huge_pages, use_raylet=use_raylet, - plasma_store_socket_name=plasma_store_socket_name) + plasma_store_socket_name=plasma_store_socket_name, + redis_password=redis_password) object_store_addresses.append(object_store_address) time.sleep(0.1) @@ -1575,7 +1649,8 @@ def start_ray_processes(address_info=None, stderr_file=local_scheduler_stderr_file, cleanup=cleanup, resources=resources[i], - num_workers=num_local_scheduler_workers) + num_workers=num_local_scheduler_workers, + redis_password=redis_password) local_scheduler_socket_names.append(local_scheduler_name) # Make sure that we have exactly num_local_schedulers instances of @@ -1599,7 +1674,8 @@ def start_ray_processes(address_info=None, num_workers=workers_per_local_scheduler[i], stdout_file=raylet_stdout_file, stderr_file=raylet_stderr_file, - cleanup=cleanup)) + cleanup=cleanup, + redis_password=redis_password)) if not use_raylet: # Start any workers that the local scheduler has not already started. @@ -1645,6 +1721,7 @@ def start_ray_node(node_ip_address, num_workers=0, num_local_schedulers=1, object_store_memory=None, + redis_password=None, worker_path=None, cleanup=True, redirect_worker_output=False, @@ -1673,6 +1750,8 @@ def start_ray_node(node_ip_address, start. object_store_memory (int): The maximum amount of memory (in bytes) to let the plasma store use. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. 
worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here @@ -1711,6 +1790,7 @@ def start_ray_node(node_ip_address, num_workers=num_workers, num_local_schedulers=num_local_schedulers, object_store_memory=object_store_memory, + redis_password=redis_password, worker_path=worker_path, include_log_monitor=True, cleanup=cleanup, @@ -1741,6 +1821,7 @@ def start_ray_head(address_info=None, num_redis_shards=None, redis_max_clients=None, redis_protected_mode=False, + redis_password=None, include_webui=True, plasma_directory=None, huge_pages=False, @@ -1792,6 +1873,8 @@ def start_ray_head(address_info=None, redis_protected_mode: True if we should start Redis in protected mode. This will prevent clients from other machines from connecting and is only done when Redis is started via ray.init(). + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. include_webui: True if the UI should be started and false otherwise. plasma_directory: A directory where the Plasma memory mapped files will be created. @@ -1832,6 +1915,7 @@ def start_ray_head(address_info=None, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, redis_protected_mode=redis_protected_mode, + redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, diff --git a/python/ray/test/test_ray_init.py b/python/ray/test/test_ray_init.py new file mode 100644 index 000000000000..62d581003b2e --- /dev/null +++ b/python/ray/test/test_ray_init.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pytest +import redis + +import ray + + +@pytest.fixture +def password(): + random_bytes = os.urandom(128) + if hasattr(random_bytes, "hex"): + return random_bytes.hex() # Python 3 + return random_bytes.encode("hex") # Python 2 + + +@pytest.fixture +def shutdown_only(): + yield None + # The code after the yield will run as teardown code. 
+    ray.shutdown()
+
+
+class TestRedisPassword(object):
+    @pytest.mark.skipif(
+        os.environ.get("RAY_USE_NEW_GCS") != "on"
+        and os.environ.get("RAY_USE_XRAY"),
+        reason="Redis authentication works for raylet and old GCS.")
+    def test_exceptions(self, password, shutdown_only):
+        with pytest.raises(Exception):
+            ray.init(redis_password=password)
+
+    @pytest.mark.skipif(
+        os.environ.get("RAY_USE_NEW_GCS") == "on",
+        reason="New GCS API doesn't support Redis authentication yet.")
+    @pytest.mark.skipif(
+        not os.environ.get("RAY_USE_XRAY"),
+        reason="Redis authentication is not supported in legacy Ray.")
+    def test_redis_password(self, password, shutdown_only):
+        # Workaround for https://github.com/ray-project/ray/issues/3045
+        @ray.remote
+        def f():
+            return 1
+
+        info = ray.init(redis_password=password)
+        redis_address = info["redis_address"]
+        redis_ip, redis_port = redis_address.split(":")
+
+        # Check that we can run a task
+        object_id = f.remote()
+        ray.get(object_id)
+
+        # Check that Redis connections require a password
+        redis_client = redis.StrictRedis(
+            host=redis_ip, port=redis_port, password=None)
+        with pytest.raises(redis.ResponseError):
+            redis_client.ping()
+
+        # Check that we can connect to Redis using the provided password
+        redis_client = redis.StrictRedis(
+            host=redis_ip, port=redis_port, password=password)
+        assert redis_client.ping()
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 4739b2e7cc7c..f30b464488ec 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -1223,12 +1223,13 @@ def actor_handle_deserializer(serialized_obj):
 
 def get_address_info_from_redis_helper(redis_address,
                                        node_ip_address,
-                                       use_raylet=False):
+                                       use_raylet=False,
+                                       redis_password=None):
     redis_ip_address, redis_port = redis_address.split(":")
     # For this command to work, some other client (on the same machine as
     # Redis) must have run "CONFIG SET protected-mode no".
     redis_client = redis.StrictRedis(
-        host=redis_ip_address, port=int(redis_port))
+        host=redis_ip_address, port=int(redis_port), password=redis_password)
 
     if not use_raylet:
         # The client table prefix must be kept in sync with the file
@@ -1332,12 +1333,16 @@ def get_address_info_from_redis_helper(redis_address,
 def get_address_info_from_redis(redis_address,
                                 node_ip_address,
                                 num_retries=5,
-                                use_raylet=False):
+                                use_raylet=False,
+                                redis_password=None):
     counter = 0
     while True:
         try:
             return get_address_info_from_redis_helper(
-                redis_address, node_ip_address, use_raylet=use_raylet)
+                redis_address,
+                node_ip_address,
+                use_raylet=use_raylet,
+                redis_password=redis_password)
         except Exception:
             if counter == num_retries:
                 raise
@@ -1405,6 +1410,7 @@ def _init(address_info=None,
           resources=None,
           num_redis_shards=None,
           redis_max_clients=None,
+          redis_password=None,
           plasma_directory=None,
           huge_pages=False,
           include_webui=True,
@@ -1460,6 +1466,8 @@
         the primary Redis shard.
         redis_max_clients: If provided, attempt to configure Redis with this
             maxclients number.
+        redis_password (str): Prevents external clients without the password
+            from connecting to Redis if provided.
         plasma_directory: A directory where the Plasma memory mapped files
             will be created. 
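Beyond starting a fresh authenticated cluster as the test above does, the same option lets a driver join an existing protected cluster; the address below is hypothetical and the raylet code path is assumed:

.. code-block:: python

    import ray

    # Head node previously started with:
    #   ray start --head --redis-password="password"
    ray.init(
        redis_address="192.168.0.1:6379",
        redis_password="password")
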
huge_pages: Boolean flag indicating whether to start the Object @@ -1544,6 +1552,7 @@ def _init(address_info=None, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, + redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, include_webui=include_webui, @@ -1596,7 +1605,10 @@ def _init(address_info=None, node_ip_address = services.get_node_ip_address(redis_address) # Get the address info of the processes to connect to from Redis. address_info = get_address_info_from_redis( - redis_address, node_ip_address, use_raylet=use_raylet) + redis_address, + node_ip_address, + use_raylet=use_raylet, + redis_password=redis_password) # Connect this driver to Redis, the object store, and the local scheduler. # Choose the first object store and local scheduler if there are multiple. @@ -1628,7 +1640,8 @@ def _init(address_info=None, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker, - use_raylet=use_raylet) + use_raylet=use_raylet, + redis_password=redis_password) return address_info @@ -1647,6 +1660,7 @@ def init(redis_address=None, ignore_reinit_error=False, num_redis_shards=None, redis_max_clients=None, + redis_password=None, plasma_directory=None, huge_pages=False, include_webui=True, @@ -1709,6 +1723,8 @@ def init(redis_address=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. plasma_directory: A directory where the Plasma memory mapped files will be created. huge_pages: Boolean flag indicating whether to start the Object @@ -1750,6 +1766,11 @@ def init(redis_address=None, # This environment variable is used in our testing setup. logger.info("Detected environment variable 'RAY_USE_XRAY'.") use_raylet = True + if not use_raylet and redis_password is not None: + raise Exception("Setting the 'redis_password' argument is not " + "supported in legacy Ray. To run Ray with " + "password-protected Redis ports, set " + "'use_raylet=True'.") # Convert hostnames to numerical IP address. if node_ip_address is not None: @@ -1772,6 +1793,7 @@ def init(redis_address=None, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, + redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, include_webui=include_webui, @@ -1975,7 +1997,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, - use_raylet=False): + use_raylet=False, + redis_password=None): """Connect this worker to the local scheduler, to Plasma, and to Redis. Args: @@ -1986,6 +2009,8 @@ def connect(info, mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. use_raylet: True if the new raylet code path should be used. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. """ # Do some basic checking to make sure we didn't call ray.init twice. error_message = "Perhaps you called ray.init twice by accident?" @@ -2019,7 +2044,10 @@ def connect(info, # Create a Redis client. 
redis_ip_address, redis_port = info["redis_address"].split(":") worker.redis_client = thread_safe_client( - redis.StrictRedis(host=redis_ip_address, port=int(redis_port))) + redis.StrictRedis( + host=redis_ip_address, + port=int(redis_port), + password=redis_password)) # For driver's check that the version information matches the version # information that the Ray cluster was started with. @@ -2060,7 +2088,8 @@ def connect(info, [log_stdout_file, log_stderr_file]) # Create an object for interfacing with the global state. - global_state._initialize_global_state(redis_ip_address, int(redis_port)) + global_state._initialize_global_state( + redis_ip_address, int(redis_port), redis_password=redis_password) # Register the worker with Redis. if mode == SCRIPT_MODE: diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 72679722fa88..670ee092d0e5 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -24,6 +24,12 @@ required=True, type=str, help="the address to use for Redis") +parser.add_argument( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") parser.add_argument( "--object-store-name", required=True, @@ -67,6 +73,7 @@ info = { "node_ip_address": args.node_ip_address, "redis_address": args.redis_address, + "redis_password": args.redis_password, "store_socket_name": args.object_store_name, "manager_socket_name": args.object_store_manager_name, "local_scheduler_socket_name": args.local_scheduler_name, @@ -81,7 +88,10 @@ tempfile_services.set_temp_root(args.temp_dir) ray.worker.connect( - info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None)) + info, + mode=ray.WORKER_MODE, + use_raylet=(args.raylet_name is not None), + redis_password=args.redis_password) error_explanation = """ This error is unexpected and should not have happened. Somehow a worker diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index 182c44a8a8cf..3acd5623af17 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -71,10 +71,12 @@ namespace gcs { AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, const ClientID &client_id, CommandType command_type, - bool is_test_client = false) { + bool is_test_client = false, + const std::string &password = "") { primary_context_ = std::make_shared(); - RAY_CHECK_OK(primary_context_->Connect(address, port, /*sharding=*/true)); + RAY_CHECK_OK( + primary_context_->Connect(address, port, /*sharding=*/true, /*password=*/password)); if (!is_test_client) { // Moving sharding into constructor defaultly means that sharding = true. 
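One detail worth noting in the AsyncGcsClient changes around this point: the single password is reused for the primary Redis context and for every shard context, so a cluster with sharded Redis still needs only one credential. A rough Python analogue of that fan-out using the redis-py package (a sketch for illustration, not part of the patch):

    import redis

    def connect_gcs_shards(primary_address, shard_addresses, password=None):
        """Open one authenticated client per Redis server, reusing the same
        password for the primary and all shards, as the GCS client does."""
        def connect(address):
            host, port = address.split(":")
            return redis.StrictRedis(host=host, port=int(port), password=password)

        primary = connect(primary_address)
        shards = [connect(address) for address in shard_addresses]
        return primary, shards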
@@ -94,12 +96,13 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, RAY_CHECK(shard_contexts_.size() == addresses.size()); for (size_t i = 0; i < addresses.size(); ++i) { - RAY_CHECK_OK( - shard_contexts_[i]->Connect(addresses[i], ports[i], /*sharding=*/true)); + RAY_CHECK_OK(shard_contexts_[i]->Connect(addresses[i], ports[i], /*sharding=*/true, + /*password=*/password)); } } else { shard_contexts_.push_back(std::make_shared()); - RAY_CHECK_OK(shard_contexts_[0]->Connect(address, port, /*sharding=*/true)); + RAY_CHECK_OK(shard_contexts_[0]->Connect(address, port, /*sharding=*/true, + /*password=*/password)); } client_table_.reset(new ClientTable({primary_context_}, this, client_id)); @@ -126,12 +129,16 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, // Use of kChain currently only applies to Table::Add which affects only the // task table, and when RAY_USE_NEW_GCS is set at compile time. AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - const ClientID &client_id, bool is_test_client = false) - : AsyncGcsClient(address, port, client_id, CommandType::kChain, is_test_client) {} + const ClientID &client_id, bool is_test_client = false, + const std::string &password = "") + : AsyncGcsClient(address, port, client_id, CommandType::kChain, is_test_client, + password) {} #else AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - const ClientID &client_id, bool is_test_client = false) - : AsyncGcsClient(address, port, client_id, CommandType::kRegular, is_test_client) {} + const ClientID &client_id, bool is_test_client = false, + const std::string &password = "") + : AsyncGcsClient(address, port, client_id, CommandType::kRegular, is_test_client, + password) {} #endif // RAY_USE_NEW_GCS AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, @@ -143,8 +150,9 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, : AsyncGcsClient(address, port, ClientID::from_random(), command_type, is_test_client) {} -AsyncGcsClient::AsyncGcsClient(const std::string &address, int port) - : AsyncGcsClient(address, port, ClientID::from_random()) {} +AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, + const std::string &password = "") + : AsyncGcsClient(address, port, ClientID::from_random(), false, password) {} AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, bool is_test_client) : AsyncGcsClient(address, port, ClientID::from_random(), is_test_client) {} diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index d89aadd803ea..83781e84127b 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -31,13 +31,14 @@ class RAY_EXPORT AsyncGcsClient { /// \param command_type GCS command type. If CommandType::kChain, chain-replicated /// versions of the tables might be used, if available. 
  AsyncGcsClient(const std::string &address, int port, const ClientID &client_id,
-                 CommandType command_type, bool is_test_client);
+                 CommandType command_type, bool is_test_client,
+                 const std::string &redis_password);
   AsyncGcsClient(const std::string &address, int port, const ClientID &client_id,
-                 bool is_test_client);
+                 bool is_test_client, const std::string &password);
   AsyncGcsClient(const std::string &address, int port, CommandType command_type);
   AsyncGcsClient(const std::string &address, int port, CommandType command_type,
                  bool is_test_client);
-  AsyncGcsClient(const std::string &address, int port);
+  AsyncGcsClient(const std::string &address, int port, const std::string &password);
   AsyncGcsClient(const std::string &address, int port, bool is_test_client);
 
   /// Attach this client to a plasma event loop. Note that only
diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc
index abc06a24a899..1a8111963256 100644
--- a/src/ray/gcs/redis_context.cc
+++ b/src/ray/gcs/redis_context.cc
@@ -135,7 +135,30 @@ RedisContext::~RedisContext() {
   }
 }
 
-Status RedisContext::Connect(const std::string &address, int port, bool sharding) {
+Status AuthenticateRedis(redisContext *context, const std::string &password) {
+  if (password == "") {
+    return Status::OK();
+  }
+  redisReply *reply = reinterpret_cast<redisReply *>(
+      redisCommand(context, "AUTH %s", password.c_str()));
+  REDIS_CHECK_ERROR(context, reply);
+  freeReplyObject(reply);
+  return Status::OK();
+}
+
+Status AuthenticateRedis(redisAsyncContext *context, const std::string &password) {
+  if (password == "") {
+    return Status::OK();
+  }
+  int status = redisAsyncCommand(context, NULL, NULL, "AUTH %s", password.c_str());
+  if (status == REDIS_ERR) {
+    return Status::RedisError(std::string(context->errstr));
+  }
+  return Status::OK();
+}
+
+Status RedisContext::Connect(const std::string &address, int port, bool sharding,
+                             const std::string &password = "") {
   int connection_attempts = 0;
   context_ = redisConnect(address.c_str(), port);
   while (context_ == nullptr || context_->err) {
@@ -155,6 +178,8 @@ Status RedisContext::Connect(const std::string &address, int port, bool sharding
     context_ = redisConnect(address.c_str(), port);
     connection_attempts += 1;
   }
+  RAY_CHECK_OK(AuthenticateRedis(context_, password));
+
   redisReply *reply = reinterpret_cast<redisReply *>(
       redisCommand(context_, "CONFIG SET notify-keyspace-events Kl"));
   REDIS_CHECK_ERROR(context_, reply);
@@ -166,12 +191,16 @@
     RAY_LOG(FATAL) << "Could not establish connection to redis " << address << ":"
                    << port;
   }
+  RAY_CHECK_OK(AuthenticateRedis(async_context_, password));
+
   // Connect to subscribe context
   subscribe_context_ = redisAsyncConnect(address.c_str(), port);
   if (subscribe_context_ == nullptr || subscribe_context_->err) {
     RAY_LOG(FATAL) << "Could not establish subscribe connection to redis " << address
                    << ":" << port;
   }
+  RAY_CHECK_OK(AuthenticateRedis(subscribe_context_, password));
+
   return Status::OK();
 }
diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h
index 67bc8197c302..1fcfd55adab1 100644
--- a/src/ray/gcs/redis_context.h
+++ b/src/ray/gcs/redis_context.h
@@ -51,7 +51,8 @@ class RedisContext {
   RedisContext()
       : context_(nullptr), async_context_(nullptr), subscribe_context_(nullptr) {}
   ~RedisContext();
-  Status Connect(const std::string &address, int port, bool sharding);
+  Status Connect(const std::string &address, int port, bool sharding,
+                 const std::string &password);
   Status
AttachToEventLoop(aeEventLoop *loop); /// Run an operation on some table key. diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 23aa41f25de5..8ad70a928e55 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -19,7 +19,7 @@ int main(int argc, char *argv[]) { ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); - RAY_CHECK(argc == 11); + RAY_CHECK(argc == 11 || argc == 12); const std::string raylet_socket_name = std::string(argv[1]); const std::string store_socket_name = std::string(argv[2]); @@ -31,6 +31,7 @@ int main(int argc, char *argv[]) { const std::string static_resource_list = std::string(argv[8]); const std::string python_worker_command = std::string(argv[9]); const std::string java_worker_command = std::string(argv[10]); + const std::string redis_password = (argc == 12 ? std::string(argv[11]) : ""); // Configuration for the node manager. ray::raylet::NodeManagerConfig node_manager_config; @@ -92,7 +93,8 @@ int main(int argc, char *argv[]) { << "object_chunk_size = " << object_manager_config.object_chunk_size; // initialize mock gcs & object directory - auto gcs_client = std::make_shared(redis_address, redis_port); + auto gcs_client = std::make_shared(redis_address, redis_port, + redis_password); RAY_LOG(DEBUG) << "Initializing GCS client " << gcs_client->client_table().GetLocalClientId(); @@ -100,8 +102,8 @@ int main(int argc, char *argv[]) { boost::asio::io_service main_service; ray::raylet::Raylet server(main_service, raylet_socket_name, node_ip_address, - redis_address, redis_port, node_manager_config, - object_manager_config, gcs_client); + redis_address, redis_port, redis_password, + node_manager_config, object_manager_config, gcs_client); // Destroy the Raylet on a SIGTERM. The pointer to main_service is // guaranteed to be valid since this function will run the event loop diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 05cf79309f2d..da9d5f8ab309 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -15,8 +15,8 @@ namespace raylet { /// the Ray configuration), then the monitor will mark that Raylet as dead in /// the client table, which broadcasts the event to all other Raylets. Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_address, - int redis_port) - : gcs_client_(redis_address, redis_port), + int redis_port, const std::string &redis_password) + : gcs_client_(redis_address, redis_port, redis_password), num_heartbeats_timeout_(RayConfig::instance().num_heartbeats_timeout()), heartbeat_timer_(io_service) { RAY_CHECK_OK(gcs_client_.Attach(io_service)); diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index 1786bc3f1a89..b300bf4cf8a0 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -19,7 +19,7 @@ class Monitor { /// \param redis_address The GCS Redis address to connect to. /// \param redis_port The GCS Redis port to connect to. Monitor(boost::asio::io_service &io_service, const std::string &redis_address, - int redis_port); + int redis_port, const std::string &redis_password); /// Start the monitor. Listen for heartbeats from Raylets and mark Raylets /// that do not send a heartbeat within a given period as dead. 
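Aside from the new password parameter, the Monitor's core logic is unchanged: it tracks the most recent heartbeat per Raylet and declares a Raylet dead once it misses heartbeats for the configured window (num_heartbeats_timeout). A Python sketch of that check, with illustrative names rather than Ray APIs:

    import time

    def find_dead_clients(last_heartbeat, timeout_s, now=None):
        """Return client ids whose most recent heartbeat is older than the
        timeout; the real Monitor marks these dead in the GCS client table."""
        now = time.time() if now is None else now
        return [client_id for client_id, t in last_heartbeat.items()
                if now - t > timeout_s]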
diff --git a/src/ray/raylet/monitor_main.cc b/src/ray/raylet/monitor_main.cc index 218faecd41ea..8cd82175285e 100644 --- a/src/ray/raylet/monitor_main.cc +++ b/src/ray/raylet/monitor_main.cc @@ -8,14 +8,15 @@ int main(int argc, char *argv[]) { ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); - RAY_CHECK(argc == 3); + RAY_CHECK(argc == 3 || argc == 4); const std::string redis_address = std::string(argv[1]); int redis_port = std::stoi(argv[2]); + const std::string redis_password = (argc == 4 ? std::string(argv[3]) : ""); // Initialize the monitor. boost::asio::io_service io_service; - ray::raylet::Monitor monitor(io_service, redis_address, redis_port); + ray::raylet::Monitor monitor(io_service, redis_address, redis_port, redis_password); monitor.Start(); io_service.run(); } diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index 5fe7f774b625..d714b71ac2d6 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -60,7 +60,7 @@ class TestObjectManagerBase : public ::testing::Test { om_config_1.store_socket_name = store_sock_1; om_config_1.push_timeout_ms = 10000; server1.reset(new ray::raylet::Raylet( - main_service, "raylet_1", "0.0.0.0", "127.0.0.1", 6379, + main_service, "raylet_1", "0.0.0.0", "127.0.0.1", 6379, "", GetNodeManagerConfig("raylet_1", store_sock_1), om_config_1, gcs_client_1)); // start second server @@ -70,7 +70,7 @@ class TestObjectManagerBase : public ::testing::Test { om_config_2.store_socket_name = store_sock_2; om_config_2.push_timeout_ms = 10000; server2.reset(new ray::raylet::Raylet( - main_service, "raylet_2", "0.0.0.0", "127.0.0.1", 6379, + main_service, "raylet_2", "0.0.0.0", "127.0.0.1", 6379, "", GetNodeManagerConfig("raylet_2", store_sock_2), om_config_2, gcs_client_2)); // connect to stores. 
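Both executables accept the password as an optional trailing argument (RAY_CHECK(argc == 11 || argc == 12) in main.cc, argc == 3 || argc == 4 in monitor_main.cc), so launch scripts that predate the feature keep working and an omitted argument degrades to the empty string, meaning no authentication. The same backward-compatible idiom in Python, for illustration only:

    import sys

    def parse_monitor_args(argv=None):
        """Parse `redis_address redis_port [redis_password]`, defaulting the
        password to "" when the trailing argument is omitted."""
        argv = sys.argv[1:] if argv is None else argv
        assert len(argv) in (2, 3), "usage: monitor ADDRESS PORT [PASSWORD]"
        return argv[0], int(argv[1]), (argv[2] if len(argv) == 3 else "")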
diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index df30498f4215..11b54b65bc8a 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -13,7 +13,8 @@ namespace raylet { Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_name, const std::string &node_ip_address, const std::string &redis_address, - int redis_port, const NodeManagerConfig &node_manager_config, + int redis_port, const std::string &redis_password, + const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client) : gcs_client_(gcs_client), @@ -33,9 +34,9 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ DoAcceptObjectManager(); DoAcceptNodeManager(); - RAY_CHECK_OK(RegisterGcs(node_ip_address, socket_name_, - object_manager_config.store_socket_name, redis_address, - redis_port, main_service, node_manager_config)); + RAY_CHECK_OK(RegisterGcs( + node_ip_address, socket_name_, object_manager_config.store_socket_name, + redis_address, redis_port, redis_password, main_service, node_manager_config)); RAY_CHECK_OK(RegisterPeriodicTimer(main_service)); } @@ -52,6 +53,7 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const std::string &raylet_socket_name, const std::string &object_store_socket_name, const std::string &redis_address, int redis_port, + const std::string &redis_password, boost::asio::io_service &io_service, const NodeManagerConfig &node_manager_config) { RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index be634616b008..9b424781af17 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -29,6 +29,7 @@ class Raylet { /// \param node_ip_address The IP address of this node. /// \param redis_address The IP address of the redis instance we are connecting to. /// \param redis_port The port of the redis instance we are connecting to. + /// \param redis_password The password of the redis instance we are connecting to. /// \param node_manager_config Configuration to initialize the node manager. /// scheduler with. /// \param object_manager_config Configuration to initialize the object @@ -36,7 +37,8 @@ class Raylet { /// \param gcs_client A client connection to the GCS. 
Raylet(boost::asio::io_service &main_service, const std::string &socket_name, const std::string &node_ip_address, const std::string &redis_address, - int redis_port, const NodeManagerConfig &node_manager_config, + int redis_port, const std::string &redis_password, + const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client); @@ -49,6 +51,7 @@ class Raylet { const std::string &raylet_socket_name, const std::string &object_store_socket_name, const std::string &redis_address, int redis_port, + const std::string &redis_password, boost::asio::io_service &io_service, const NodeManagerConfig &); ray::Status RegisterPeriodicTimer(boost::asio::io_service &io_service); From 3c0803e7e909f2f617aa9bd8f3d266d31cbe3026 Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Wed, 17 Oct 2018 17:44:51 -0700 Subject: [PATCH 033/215] [rllib] use `ray.wait` to get next worker result in async sample optimizer (#2993) --- .../optimizers/async_gradients_optimizer.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/python/ray/rllib/optimizers/async_gradients_optimizer.py b/python/ray/rllib/optimizers/async_gradients_optimizer.py index fc7fdb2488a3..499d2a91f924 100644 --- a/python/ray/rllib/optimizers/async_gradients_optimizer.py +++ b/python/ray/rllib/optimizers/async_gradients_optimizer.py @@ -27,21 +27,26 @@ def _init(self, grads_per_step=100): def step(self): weights = ray.put(self.local_evaluator.get_weights()) - gradient_queue = [] + pending_gradients = {} num_gradients = 0 # Kick off the first wave of async tasks for e in self.remote_evaluators: e.set_weights.remote(weights) - fut = e.compute_gradients.remote(e.sample.remote()) - gradient_queue.append((fut, e)) + future = e.compute_gradients.remote(e.sample.remote()) + pending_gradients[future] = e num_gradients += 1 - # Note: can't use wait: https://github.com/ray-project/ray/issues/1128 - while gradient_queue: + while pending_gradients: with self.wait_timer: - fut, e = gradient_queue.pop(0) - gradient, info = ray.get(fut) + wait_results = ray.wait( + list(pending_gradients.keys()), num_returns=1) + ready_list = wait_results[0] + future = ready_list[0] + + gradient, info = ray.get(future) + e = pending_gradients.pop(future) + if "stats" in info: self.learner_stats = info["stats"] @@ -54,8 +59,9 @@ def step(self): if num_gradients < self.grads_per_step: with self.dispatch_timer: e.set_weights.remote(self.local_evaluator.get_weights()) - fut = e.compute_gradients.remote(e.sample.remote()) - gradient_queue.append((fut, e)) + future = e.compute_gradients.remote(e.sample.remote()) + + pending_gradients[future] = e num_gradients += 1 def stats(self): From 2c52d9dfa043a168a00d0fe2083c73cef4b82228 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 17 Oct 2018 18:00:52 -0700 Subject: [PATCH 034/215] Fix actor handle id creation when actor handle was pickled (#3074) --- python/ray/actor.py | 13 +++++++++++++ test/actor_test.py | 22 ++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/python/ray/actor.py b/python/ray/actor.py index 65ddc266f944..d1f034cc6057 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -501,6 +501,7 @@ def __init__(self, self._ray_actor_method_cpus = actor_method_cpus self._ray_actor_driver_id = actor_driver_id self._ray_previous_actor_handle_id = previous_actor_handle_id + self._ray_previously_generated_actor_handle_id = None def _actor_method_call(self, method_name, @@ -554,10 +555,22 @@ def 
_actor_method_call(self, is_actor_checkpoint_method = (method_name == "__ray_checkpoint__") + # Right now, if the actor handle has been pickled, we create a + # temporary actor handle id for invocations. + # TODO(pcm): This still leads to a lot of actor handles being + # created, there should be a better way to handle pickled + # actor handles. if self._ray_actor_handle_id is None: actor_handle_id = compute_actor_handle_id_non_forked( self._ray_actor_id, self._ray_previous_actor_handle_id, worker.current_task_id) + # Each new task creates a new actor handle id, so we need to + # reset the actor counter to 0 + if (actor_handle_id != + self._ray_previously_generated_actor_handle_id): + self._ray_actor_counter = 0 + self._ray_previously_generated_actor_handle_id = ( + actor_handle_id) else: actor_handle_id = self._ray_actor_handle_id diff --git a/test/actor_test.py b/test/actor_test.py index 49d51835c31d..13aeb78f7fc3 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -1968,6 +1968,28 @@ def method(self): ray.get(new_f.method.remote()) +def test_pickled_actor_handle_call_in_method_twice(ray_start_regular): + @ray.remote + class Actor1(object): + def f(self): + return 1 + + @ray.remote + class Actor2(object): + def __init__(self, constructor): + self.actor = constructor() + + def step(self): + ray.get(self.actor.f.remote()) + + a = Actor1.remote() + + b = Actor2.remote(lambda: a) + + ray.get(b.step.remote()) + ray.get(b.step.remote()) + + def test_register_and_get_named_actors(ray_start_regular): # TODO(heyucongtom): We should test this from another driver. From b82fd157a7480d3098327e15ef559c4f991f9ce5 Mon Sep 17 00:00:00 2001 From: Peter Schafhalter Date: Wed, 17 Oct 2018 22:48:14 -0700 Subject: [PATCH 035/215] Remove Redis protected mode (#3073) Follow-up to #2925 and #2952. Removes the Redis protected mode implementation from Ray which was replaced by Redis port authentication. 
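In practice the two mechanisms differ mainly in how the server is launched: protected mode required writing a generated config file that bound an extra interface, while password authentication is a single --requirepass flag on the redis-server command line, after which clients must AUTH before issuing commands. A minimal sketch of the resulting command construction, mirroring the services.py change in the diff below (the helper name and the trailing --port option are illustrative, not the exact Ray code):

    def build_redis_command(executable, port, password=None):
        """Build the redis-server argv; a non-empty password replaces the
        old protected-mode config file with Redis's built-in auth."""
        command = [executable]
        if password:
            command += ["--requirepass", password]
        command += ["--port", str(port)]
        return command

For example, build_redis_command("redis-server", 6379, "s3cret") yields ['redis-server', '--requirepass', 's3cret', '--port', '6379'].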
--- python/ray/scripts/scripts.py | 1 - python/ray/services.py | 54 ++++----------------------------- python/ray/tempfile_services.py | 7 ----- 3 files changed, 6 insertions(+), 56 deletions(-) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index dfbbee272a27..b28fc4e179e4 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -279,7 +279,6 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, - redis_protected_mode=False, redis_password=redis_password, include_webui=(not no_ui), plasma_directory=plasma_directory, diff --git a/python/ray/services.py b/python/ray/services.py index e572b657f277..a887b274386f 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -26,11 +26,11 @@ from ray.tempfile_services import ( get_ipython_notebook_path, get_logs_dir_path, get_raylet_socket_name, - get_temp_redis_config_path, get_temp_root, new_global_scheduler_log_file, - new_local_scheduler_log_file, new_log_monitor_log_file, - new_monitor_log_file, new_plasma_manager_log_file, - new_plasma_store_log_file, new_raylet_log_file, new_redis_log_file, - new_webui_log_file, new_worker_log_file, set_temp_root) + get_temp_root, new_global_scheduler_log_file, new_local_scheduler_log_file, + new_log_monitor_log_file, new_monitor_log_file, + new_plasma_manager_log_file, new_plasma_store_log_file, + new_raylet_log_file, new_redis_log_file, new_webui_log_file, + new_worker_log_file, set_temp_root) PROCESS_TYPE_MONITOR = "monitor" PROCESS_TYPE_LOG_MONITOR = "log_monitor" @@ -434,7 +434,6 @@ def start_redis(node_ip_address, redirect_output=False, redirect_worker_output=False, cleanup=True, - protected_mode=False, password=None, use_credis=None): """Start the Redis global state store. @@ -497,7 +496,6 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode, password=password) else: assigned_port, _ = _start_redis_instance( @@ -507,7 +505,6 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray module, # as the latter contains an extern declaration that the former @@ -553,7 +550,6 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode, password=password) else: assert num_redis_shards == 1, \ @@ -566,7 +562,6 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode, password=password, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray @@ -593,22 +588,6 @@ def start_redis(node_ip_address, return redis_address, redis_shards -def _make_temp_redis_config(node_ip_address): - """Create a configuration file for Redis. - - Args: - node_ip_address: The IP address of this node. This should not be - 127.0.0.1. - """ - redis_config_name = get_temp_redis_config_path() - with open(redis_config_name, 'w') as f: - # This allows redis clients on the same machine to connect using the - # node's IP address as opposed to just 127.0.0.1. This is only relevant - # when the server is in protected mode. 
- f.write("bind 127.0.0.1 {}".format(node_ip_address)) - return redis_config_name - - def _start_redis_instance(node_ip_address="127.0.0.1", port=None, redis_max_clients=None, @@ -616,7 +595,6 @@ def _start_redis_instance(node_ip_address="127.0.0.1", stdout_file=None, stderr_file=None, cleanup=True, - protected_mode=False, password=None, executable=REDIS_EXECUTABLE, modules=None): @@ -637,10 +615,6 @@ def _start_redis_instance(node_ip_address="127.0.0.1", cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. - protected_mode: True if we should start the Redis server in protected - mode. This will prevent clients on other machines from connecting - and is only used when the Redis servers are started via ray.init() - as opposed to ray start. password (str): Prevents external clients without the password from connecting to Redis if provided. executable (str): Full path tho the redis-server executable. @@ -668,9 +642,6 @@ def _start_redis_instance(node_ip_address="127.0.0.1", else: port = new_port() - if protected_mode: - redis_config_filename = _make_temp_redis_config(node_ip_address) - load_module_args = [] for module in modules: load_module_args += ["--loadmodule", module] @@ -681,8 +652,6 @@ def _start_redis_instance(node_ip_address="127.0.0.1", # Construct the command to start the Redis server. command = [executable] - if protected_mode: - command += [redis_config_filename] if password: command += ["--requirepass", password] command += ( @@ -713,8 +682,7 @@ def _start_redis_instance(node_ip_address="127.0.0.1", # Configure Redis to not run in protected mode so that processes on other # hosts can connect to it. TODO(rkn): Do this in a more secure way. - if not protected_mode: - redis_client.config_set("protected-mode", "no") + redis_client.config_set("protected-mode", "no") # If redis_max_clients is provided, attempt to raise the number of maximum # number of Redis clients. @@ -1378,7 +1346,6 @@ def start_ray_processes(address_info=None, object_store_memory=None, num_redis_shards=1, redis_max_clients=None, - redis_protected_mode=False, redis_password=None, worker_path=None, cleanup=True, @@ -1422,9 +1389,6 @@ def start_ray_processes(address_info=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. - redis_protected_mode: True if we should start Redis in protected mode. - This will prevent clients from other machines from connecting and - is only done when Redis is started via ray.init(). redis_password (str): Prevents external clients without the password from connecting to Redis if provided. worker_path (str): The path of the source code that will be run by the @@ -1512,7 +1476,6 @@ def start_ray_processes(address_info=None, redirect_output=True, redirect_worker_output=redirect_worker_output, cleanup=cleanup, - protected_mode=redis_protected_mode, password=redis_password) address_info["redis_address"] = redis_address time.sleep(0.1) @@ -1820,7 +1783,6 @@ def start_ray_head(address_info=None, resources=None, num_redis_shards=None, redis_max_clients=None, - redis_protected_mode=False, redis_password=None, include_webui=True, plasma_directory=None, @@ -1870,9 +1832,6 @@ def start_ray_head(address_info=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. - redis_protected_mode: True if we should start Redis in protected mode. 
- This will prevent clients from other machines from connecting and - is only done when Redis is started via ray.init(). redis_password (str): Prevents external clients without the password from connecting to Redis if provided. include_webui: True if the UI should be started and false otherwise. @@ -1914,7 +1873,6 @@ def start_ray_head(address_info=None, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, - redis_protected_mode=redis_protected_mode, redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, diff --git a/python/ray/tempfile_services.py b/python/ray/tempfile_services.py index 3b5adfa802db..76ec7c1d7ddb 100644 --- a/python/ray/tempfile_services.py +++ b/python/ray/tempfile_services.py @@ -156,13 +156,6 @@ def get_ipython_notebook_path(port): return new_notebook_directory, webui_url, token -def get_temp_redis_config_path(): - """Get a temp name of the redis config file.""" - redis_config_name = make_inc_temp( - prefix="redis_conf", directory_name=get_temp_root()) - return redis_config_name - - def new_log_files(name, redirect_output): """Generate partially randomized filenames for log files. From 653c5b114a26c91ba59c484179d51d6329fb8fad Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Fri, 19 Oct 2018 01:51:36 +0800 Subject: [PATCH 036/215] [c++] Refine Log Code (#2816) * Support setting logging level from env variable * Remove Env Variable related code * lint --- src/global_scheduler/global_scheduler.cc | 4 +- src/local_scheduler/local_scheduler.cc | 4 +- src/plasma/plasma_manager.cc | 4 +- src/ray/CMakeLists.txt | 4 +- src/ray/raylet/main.cc | 3 +- src/ray/raylet/monitor_main.cc | 4 +- src/ray/util/logging.cc | 74 +++++++++++++++--------- src/ray/util/logging.h | 74 ++++++++++-------------- src/ray/util/logging_test.cc | 4 +- src/ray/util/signal_test.cc | 3 +- 10 files changed, 92 insertions(+), 86 deletions(-) diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc index 069ad6865d17..d964401ae720 100644 --- a/src/global_scheduler/global_scheduler.cc +++ b/src/global_scheduler/global_scheduler.cc @@ -455,8 +455,8 @@ void start_server(const char *node_ip_address, int main(int argc, char *argv[]) { InitShutdownRAII ray_log_shutdown_raii( - ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, - /*log_dir=*/""); + ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0], + ray::RayLogLevel::INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); signal(SIGTERM, signal_handler); /* IP address and port of the primary redis instance. */ diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index 7bef00993ab9..d2c50c3fbb1d 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -1423,8 +1423,8 @@ void start_server( #ifndef LOCAL_SCHEDULER_TEST int main(int argc, char *argv[]) { InitShutdownRAII ray_log_shutdown_raii( - ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, - /*log_dir=*/""); + ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0], + ray::RayLogLevel::INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); signal(SIGTERM, signal_handler); /* Path of the listening socket of the local scheduler. 
*/ diff --git a/src/plasma/plasma_manager.cc b/src/plasma/plasma_manager.cc index 51b18c572ee9..7d10f178c8e8 100644 --- a/src/plasma/plasma_manager.cc +++ b/src/plasma/plasma_manager.cc @@ -1626,8 +1626,8 @@ void signal_handler(int signal) { #ifndef PLASMA_TEST int main(int argc, char *argv[]) { InitShutdownRAII ray_log_shutdown_raii( - ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, - /*log_dir=*/""); + ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0], + ray::RayLogLevel::INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); signal(SIGTERM, signal_handler); /* Socket name of the plasma store this manager is connected to. */ diff --git a/src/ray/CMakeLists.txt b/src/ray/CMakeLists.txt index 526bc20a94d3..1c36c9f7198a 100644 --- a/src/ray/CMakeLists.txt +++ b/src/ray/CMakeLists.txt @@ -70,8 +70,8 @@ set(RAY_LIB_DEPENDENCIES gen_common_fbs) if(RAY_USE_GLOG) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DRAY_USE_GLOG") - set(RAY_LIB_STATIC_LINK_LIBS ${RAY_LIB_STATIC_LINK_LIBS} ${GLOG_STATIC_LIB}) + add_definitions(-DRAY_USE_GLOG) + set(RAY_LIB_STATIC_LINK_LIBS ${RAY_LIB_STATIC_LINK_LIBS} glog) set(RAY_LIB_DEPENDENCIES ${RAY_LIB_DEPENDENCIES} glog) endif() diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 8ad70a928e55..0fe3105d9bc8 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -16,7 +16,8 @@ static std::vector parse_worker_command(std::string worker_command) int main(int argc, char *argv[]) { InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, - ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, + ray::RayLog::ShutDownRayLog, argv[0], + ray::RayLogLevel::INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); RAY_CHECK(argc == 11 || argc == 12); diff --git a/src/ray/raylet/monitor_main.cc b/src/ray/raylet/monitor_main.cc index 8cd82175285e..f997566a5076 100644 --- a/src/ray/raylet/monitor_main.cc +++ b/src/ray/raylet/monitor_main.cc @@ -5,8 +5,8 @@ int main(int argc, char *argv[]) { InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, - ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, - /*log_dir=*/""); + ray::RayLog::ShutDownRayLog, argv[0], + ray::RayLogLevel::INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); RAY_CHECK(argc == 3 || argc == 4); diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc index ee42e7e3ab2f..b10c245e73ec 100644 --- a/src/ray/util/logging.cc +++ b/src/ray/util/logging.cc @@ -1,8 +1,12 @@ +#include "ray/util/logging.h" + +#ifndef _WIN32 +#include +#endif +#include #include #include -#include "ray/util/logging.h" - #ifdef RAY_USE_GLOG #include "glog/logging.h" #endif @@ -13,13 +17,13 @@ namespace ray { // which is independent of any libs. 
class CerrLog { public: - CerrLog(int severity) : severity_(severity), has_logged_(false) {} + CerrLog(RayLogLevel severity) : severity_(severity), has_logged_(false) {} virtual ~CerrLog() { if (has_logged_) { std::cerr << std::endl; } - if (severity_ == RAY_FATAL) { + if (severity_ == RayLogLevel::FATAL) { PrintBackTrace(); std::abort(); } @@ -32,7 +36,7 @@ class CerrLog { template CerrLog &operator<<(const T &t) { - if (severity_ != RAY_DEBUG) { + if (severity_ != RayLogLevel::DEBUG) { has_logged_ = true; std::cerr << t; } @@ -40,7 +44,7 @@ class CerrLog { } protected: - const int severity_; + const RayLogLevel severity_; bool has_logged_; void PrintBackTrace() { @@ -52,27 +56,33 @@ class CerrLog { } }; -int RayLog::severity_threshold_ = RAY_INFO; +#ifdef RAY_USE_GLOG +typedef google::LogMessage LoggingProvider; +#else +typedef ray::CerrLog LoggingProvider; +#endif + +RayLogLevel RayLog::severity_threshold_ = RayLogLevel::INFO; std::string RayLog::app_name_ = ""; #ifdef RAY_USE_GLOG using namespace google; // Glog's severity map. -static int GetMappedSeverity(int severity) { +static int GetMappedSeverity(RayLogLevel severity) { switch (severity) { - case RAY_DEBUG: + case RayLogLevel::DEBUG: return GLOG_INFO; - case RAY_INFO: + case RayLogLevel::INFO: return GLOG_INFO; - case RAY_WARNING: + case RayLogLevel::WARNING: return GLOG_WARNING; - case RAY_ERROR: + case RayLogLevel::ERROR: return GLOG_ERROR; - case RAY_FATAL: + case RayLogLevel::FATAL: return GLOG_FATAL; default: - RAY_LOG(FATAL) << "Unsupported logging level: " << severity; + RAY_LOG(FATAL) << "Unsupported logging level: " << static_cast(severity); // This return won't be hit but compiler needs it. return GLOG_FATAL; } @@ -80,11 +90,11 @@ static int GetMappedSeverity(int severity) { #endif -void RayLog::StartRayLog(const std::string &app_name, int severity_threshold, +void RayLog::StartRayLog(const std::string &app_name, RayLogLevel severity_threshold, const std::string &log_dir) { -#ifdef RAY_USE_GLOG severity_threshold_ = severity_threshold; app_name_ = app_name; +#ifdef RAY_USE_GLOG int mapped_severity_threshold = GetMappedSeverity(severity_threshold_); google::InitGoogleLogging(app_name_.c_str()); google::SetStderrLogging(mapped_severity_threshold); @@ -122,34 +132,44 @@ void RayLog::InstallFailureSignalHandler() { #endif } -bool RayLog::IsLevelEnabled(int log_level) { return log_level >= severity_threshold_; } +bool RayLog::IsLevelEnabled(RayLogLevel log_level) { + return log_level >= severity_threshold_; +} -RayLog::RayLog(const char *file_name, int line_number, int severity) - // glog does not have DEBUG level, we can handle it here. - : is_enabled_(severity >= severity_threshold_) { +RayLog::RayLog(const char *file_name, int line_number, RayLogLevel severity) + // glog does not have DEBUG level, we can handle it using is_enabled_. 
+ : logging_provider_(nullptr), + is_enabled_(severity >= severity_threshold_) { #ifdef RAY_USE_GLOG if (is_enabled_) { - logging_provider_.reset( - new google::LogMessage(file_name, line_number, GetMappedSeverity(severity))); + logging_provider_ = + new google::LogMessage(file_name, line_number, GetMappedSeverity(severity)); } #else - logging_provider_.reset(new CerrLog(severity)); - *logging_provider_ << file_name << ":" << line_number << ": "; + auto logging_provider = new CerrLog(severity); + *logging_provider << file_name << ":" << line_number << ": "; + logging_provider_ = logging_provider; #endif } std::ostream &RayLog::Stream() { + auto logging_provider = reinterpret_cast(logging_provider_); #ifdef RAY_USE_GLOG // Before calling this function, user should check IsEnabled. // When IsEnabled == false, logging_provider_ will be empty. - return logging_provider_->stream(); + return logging_provider->stream(); #else - return logging_provider_->Stream(); + return logging_provider->Stream(); #endif } bool RayLog::IsEnabled() const { return is_enabled_; } -RayLog::~RayLog() { logging_provider_.reset(); } +RayLog::~RayLog() { + if (logging_provider_ != nullptr) { + delete reinterpret_cast(logging_provider_); + logging_provider_ = nullptr; + } +} } // namespace ray diff --git a/src/ray/util/logging.h b/src/ray/util/logging.h index daaa92369e1e..00c934c68a37 100644 --- a/src/ray/util/logging.h +++ b/src/ray/util/logging.h @@ -1,49 +1,26 @@ #ifndef RAY_UTIL_LOGGING_H #define RAY_UTIL_LOGGING_H -#ifndef _WIN32 -#include -#endif - -#include #include -#include - -#include "ray/util/macros.h" +#include -// Forward declaration for the log provider. -#ifdef RAY_USE_GLOG -namespace google { -class LogMessage; -} // namespace google -typedef google::LogMessage LoggingProvider; -#else namespace ray { -class CerrLog; -} // namespace ray -typedef ray::CerrLog LoggingProvider; -#endif -namespace ray { -// Log levels. LOG ignores them, so their values are abitrary. - -#define RAY_DEBUG (-1) -#define RAY_INFO 0 -#define RAY_WARNING 1 -#define RAY_ERROR 2 -#define RAY_FATAL 3 +enum class RayLogLevel { DEBUG = -1, INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 }; #define RAY_LOG_INTERNAL(level) ::ray::RayLog(__FILE__, __LINE__, level) -#define RAY_LOG(level) \ - if (ray::RayLog::IsLevelEnabled(RAY_##level)) RAY_LOG_INTERNAL(RAY_##level) +#define RAY_LOG(level) \ + if (ray::RayLog::IsLevelEnabled(ray::RayLogLevel::level)) \ + RAY_LOG_INTERNAL(ray::RayLogLevel::level) #define RAY_IGNORE_EXPR(expr) ((void)(expr)) -#define RAY_CHECK(condition) \ - (condition) ? RAY_IGNORE_EXPR(0) : ::ray::Voidify() & \ - ::ray::RayLog(__FILE__, __LINE__, RAY_FATAL) \ - << " Check failed: " #condition " " +#define RAY_CHECK(condition) \ + (condition) ? RAY_IGNORE_EXPR(0) \ + : ::ray::Voidify() & \ + ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::FATAL) \ + << " Check failed: " #condition " " #ifdef NDEBUG @@ -67,14 +44,13 @@ class RayLogBase { public: virtual ~RayLogBase(){}; + // By default, this class is a null log because it return false here. 
virtual bool IsEnabled() const { return false; }; template RayLogBase &operator<<(const T &t) { if (IsEnabled()) { Stream() << t; - } else { - RAY_IGNORE_EXPR(t); } return *this; } @@ -85,7 +61,7 @@ class RayLogBase { class RayLog : public RayLogBase { public: - RayLog(const char *file_name, int line_number, int severity); + RayLog(const char *file_name, int line_number, RayLogLevel severity); virtual ~RayLog(); @@ -94,29 +70,37 @@ class RayLog : public RayLogBase { /// \return True if logging is enabled and false otherwise. virtual bool IsEnabled() const; - // The init function of ray log for a program which should be called only once. - // If logDir is empty, the log won't output to file. - static void StartRayLog(const std::string &appName, int severity_threshold = RAY_ERROR, + /// The init function of ray log for a program which should be called only once. + /// + /// \parem appName The app name which starts the log. + /// \param severity_threshold Logging threshold for the program. + /// \param logDir Logging output file name. If empty, the log won't output to file. + static void StartRayLog(const std::string &appName, + RayLogLevel severity_threshold = RayLogLevel::INFO, const std::string &logDir = ""); - // The shutdown function of ray log which should be used with StartRayLog as a pair. + /// The shutdown function of ray log which should be used with StartRayLog as a pair. static void ShutDownRayLog(); /// Return whether or not the log level is enabled in current setting. /// /// \param log_level The input log level to test. /// \return True if input log level is not lower than the threshold. - static bool IsLevelEnabled(int log_level); + static bool IsLevelEnabled(RayLogLevel log_level); - // Install the failure signal handler to output call stack when crash. - // If glog is not installed, this function won't do anything. + /// Install the failure signal handler to output call stack when crash. + /// If glog is not installed, this function won't do anything. static void InstallFailureSignalHandler(); + // Get the log level from environment variable. + static RayLogLevel GetLogLevelFromEnv(); private: - std::unique_ptr logging_provider_; + // Hide the implementation of log provider by void *. + // Otherwise, lib user may define the same macro to use the correct header file. + void *logging_provider_; /// True if log messages should be logged and false if they should be ignored. bool is_enabled_; - static int severity_threshold_; + static RayLogLevel severity_threshold_; // In InitGoogleLogging, it simply keeps the pointer. // We need to make sure the app name passed to InitGoogleLogging exist. static std::string app_name_; diff --git a/src/ray/util/logging_test.cc b/src/ray/util/logging_test.cc index a27a0625033f..85fc5ee06454 100644 --- a/src/ray/util/logging_test.cc +++ b/src/ray/util/logging_test.cc @@ -41,14 +41,14 @@ TEST(PrintLogTest, LogTestWithoutInit) { TEST(PrintLogTest, LogTestWithInit) { // Test empty app name. - RayLog::StartRayLog("", RAY_DEBUG); + RayLog::StartRayLog("", RayLogLevel::DEBUG); PrintLog(); RayLog::ShutDownRayLog(); } // This test will output large amount of logs to stderr, should be disabled in travis. 
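The new RayLogLevel values are ordered so that the threshold check is a single comparison, and RAY_LOG(level) expands to a no-op statement unless IsLevelEnabled passes. The same gating logic in Python, for illustration (the enum values mirror the diff above; this is not a Ray API):

    from enum import IntEnum

    class RayLogLevel(IntEnum):
        DEBUG = -1
        INFO = 0
        WARNING = 1
        ERROR = 2
        FATAL = 3

    def is_level_enabled(level, threshold=RayLogLevel.INFO):
        """Mirror RayLog::IsLevelEnabled: log only at or above the threshold."""
        return level >= threshold

    assert is_level_enabled(RayLogLevel.ERROR)
    assert not is_level_enabled(RayLogLevel.DEBUG)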
TEST(LogPerfTest, PerfTest) { - RayLog::StartRayLog("/fake/path/to/appdire/LogPerfTest", RAY_ERROR, "/tmp/"); + RayLog::StartRayLog("/fake/path/to/appdire/LogPerfTest", RayLogLevel::ERROR, "/tmp/"); int rounds = 100000; int64_t start_time = current_time_ms(); diff --git a/src/ray/util/signal_test.cc b/src/ray/util/signal_test.cc index a408681d8f97..19e0aeec9ef7 100644 --- a/src/ray/util/signal_test.cc +++ b/src/ray/util/signal_test.cc @@ -82,7 +82,8 @@ TEST(SignalTest, SIGILL_Test) { int main(int argc, char **argv) { InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, - ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, + ray::RayLog::ShutDownRayLog, argv[0], + ray::RayLogLevel::INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); ::testing::InitGoogleTest(&argc, argv); From 8fcdafc6ea2b2daff69a8c0dc2e8969111b4fb00 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Thu, 18 Oct 2018 17:58:39 -0700 Subject: [PATCH 037/215] Adding Python3.7 wheels support (#2546) * Adding Python3.7 wheels support * Adding Mac wheels update * fix * numpy version * choose different numpy versions depending on python version * fix --- .travis/test-wheels.sh | 8 +++++--- python/build-wheel-macos.sh | 18 ++++++++++++++---- python/build-wheel-manylinux1.sh | 2 +- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/.travis/test-wheels.sh b/.travis/test-wheels.sh index 1b77209c3ddc..1765135ec9be 100755 --- a/.travis/test-wheels.sh +++ b/.travis/test-wheels.sh @@ -56,7 +56,7 @@ if [[ "$platform" == "linux" ]]; then # Check that the other wheels are present. NUMBER_OF_WHEELS=$(ls -1q $ROOT_DIR/../.whl/*.whl | wc -l) - if [[ "$NUMBER_OF_WHEELS" != "4" ]]; then + if [[ "$NUMBER_OF_WHEELS" != "5" ]]; then echo "Wrong number of wheels found." ls -l $ROOT_DIR/../.whl/ exit 1 @@ -67,12 +67,14 @@ elif [[ "$platform" == "macosx" ]]; then PY_MMS=("2.7" "3.4" "3.5" - "3.6") + "3.6" + "3.7") # This array is just used to find the right wheel. PY_WHEEL_VERSIONS=("27" "34" "35" - "36") + "36" + "37") for ((i=0; i<${#PY_MMS[@]}; ++i)); do PY_MM=${PY_MMS[i]} diff --git a/python/build-wheel-macos.sh b/python/build-wheel-macos.sh index 588362e8099e..30e8b1936376 100755 --- a/python/build-wheel-macos.sh +++ b/python/build-wheel-macos.sh @@ -16,15 +16,24 @@ DOWNLOAD_DIR=python_downloads PY_VERSIONS=("2.7.13" "3.4.4" "3.5.3" - "3.6.1") + "3.6.1" + "3.7.0") PY_INSTS=("python-2.7.13-macosx10.6.pkg" "python-3.4.4-macosx10.6.pkg" "python-3.5.3-macosx10.6.pkg" - "python-3.6.1-macosx10.6.pkg") + "python-3.6.1-macosx10.6.pkg" + "python-3.7.0-macosx10.6.pkg") PY_MMS=("2.7" "3.4" "3.5" - "3.6") + "3.6" + "3.7") +# On python 3.7, a newer version of numpy seems to be necessary. +NUMPY_VERSIONS=("1.10.4" + "1.10.4" + "1.10.4" + "1.10.4" + "1.14.5") mkdir -p $DOWNLOAD_DIR mkdir -p .whl @@ -33,6 +42,7 @@ for ((i=0; i<${#PY_VERSIONS[@]}; ++i)); do PY_VERSION=${PY_VERSIONS[i]} PY_INST=${PY_INSTS[i]} PY_MM=${PY_MMS[i]} + NUMPY_VERSION=${NUMPY_VERSIONS[i]} # The -f flag is passed twice to also run git clean in the arrow subdirectory. # The -d flag removes directories. The -x flag ignores the .gitignore file, @@ -60,7 +70,7 @@ for ((i=0; i<${#PY_VERSIONS[@]}; ++i)); do $PIP_CMD install -q setuptools_scm==2.1.0 # Fix the numpy version because this will be the oldest numpy version we can # support. - $PIP_CMD install -q numpy==1.10.4 cython==0.27.3 + $PIP_CMD install -q numpy==$NUMPY_VERSION cython==0.27.3 # Install wheel to avoid the error "invalid command 'bdist_wheel'". 
$PIP_CMD install -q wheel # Add the correct Python to the path and build the wheel. This is only diff --git a/python/build-wheel-manylinux1.sh b/python/build-wheel-manylinux1.sh index 8fdee4a1a480..db31ff55a4e6 100755 --- a/python/build-wheel-manylinux1.sh +++ b/python/build-wheel-manylinux1.sh @@ -13,7 +13,7 @@ rm -f /usr/bin/python2 ln -s /opt/python/cp27-cp27m/bin/python2 /usr/bin/python2 mkdir .whl -for PYTHON in cp27-cp27mu cp34-cp34m cp35-cp35m cp36-cp36m; do +for PYTHON in cp27-cp27mu cp34-cp34m cp35-cp35m cp36-cp36m cp37-cp37m; do # The -f flag is passed twice to also run git clean in the arrow subdirectory. # The -d flag removes directories. The -x flag ignores the .gitignore file, # and the -e flag ensures that we don't remove the .whl directory. From fa469783d88c511ffa5857d4d58ee5e6e9fefa8d Mon Sep 17 00:00:00 2001 From: Peter Schafhalter Date: Thu, 18 Oct 2018 21:43:03 -0700 Subject: [PATCH 038/215] Fix bug when connecting to password-secured cluster (#3083) --- python/ray/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/worker.py b/python/ray/worker.py index f30b464488ec..7049b1f3a429 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1342,7 +1342,7 @@ def get_address_info_from_redis(redis_address, redis_address, node_ip_address, use_raylet=use_raylet, - redis_password=None) + redis_password=redis_password) except Exception: if counter == num_retries: raise From 9d23fa03c98f2847271bef1551c736abf03d2204 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 18 Oct 2018 21:56:22 -0700 Subject: [PATCH 039/215] [xray] All messages on main asio event loop should be written asynchronously (#3023) * copy over ref code * wip async writes * compiles * fix error handling * add test * amend * fix test * clang fmgt * clang format * wip * yapf * rename format script * test error * clangfmt * add test to list * warn * ref test * fix test * comment * add capture * Update client_connection.cc * wip * fix compile --- .travis.yml | 3 +- .travis/{yapf.sh => format.sh} | 9 + src/ray/common/client_connection.cc | 96 ++++++++- src/ray/common/client_connection.h | 62 +++++- src/ray/object_manager/object_manager.cc | 28 ++- src/ray/object_manager/object_manager.h | 6 +- .../object_manager_client_connection.cc | 2 +- .../object_manager_client_connection.h | 12 ++ src/ray/raylet/CMakeLists.txt | 1 + src/ray/raylet/client_connection_test.cc | 155 ++++++++++++++ src/ray/raylet/node_manager.cc | 192 +++++++++--------- src/ray/raylet/node_manager.h | 11 +- 12 files changed, 446 insertions(+), 131 deletions(-) rename .travis/{yapf.sh => format.sh} (83%) create mode 100644 src/ray/raylet/client_connection_test.cc diff --git a/.travis.yml b/.travis.yml index 8416fa138d8c..bd0fd929b733 100644 --- a/.travis.yml +++ b/.travis.yml @@ -54,7 +54,7 @@ matrix: - cd .. 
# Run Python linting, ignore dict vs {} (C408), others are defaults - flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504 - - .travis/yapf.sh --all + - .travis/format.sh --all - os: linux dist: trusty @@ -185,6 +185,7 @@ install: - ./src/ray/raylet/lineage_cache_test - ./src/ray/raylet/task_dependency_manager_test - ./src/ray/raylet/reconstruction_policy_test + - ./src/ray/raylet/client_connection_test - ./src/ray/util/logging_test --gtest_filter=PrintLogTest* - ./src/ray/util/signal_test diff --git a/.travis/yapf.sh b/.travis/format.sh similarity index 83% rename from .travis/yapf.sh rename to .travis/format.sh index d90aec89531d..ca92d5196d56 100755 --- a/.travis/yapf.sh +++ b/.travis/format.sh @@ -1,4 +1,6 @@ #!/usr/bin/env bash +# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# You are encouraged to run this locally before pushing changes for review. # Cause the script to exit if a single command fails set -eo pipefail @@ -51,6 +53,13 @@ format_changed() { git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \ yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" fi + + if which clang-format >/dev/null; then + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \ + clang-format -i + fi + fi } # Format all files, and print the diff to stdout for travis. diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index eaa479429270..3347d1cfd35b 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -1,5 +1,6 @@ #include "client_connection.h" +#include #include #include "common.h" @@ -18,9 +19,19 @@ ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket, return boost_to_ray_status(error); } +template +std::shared_ptr> ServerConnection::Create( + boost::asio::basic_stream_socket &&socket) { + std::shared_ptr> self(new ServerConnection(std::move(socket))); + return self; +} + template ServerConnection::ServerConnection(boost::asio::basic_stream_socket &&socket) - : socket_(std::move(socket)) {} + : socket_(std::move(socket)), + async_write_max_messages_(1), + async_write_queue_(), + async_write_in_flight_(false) {} template Status ServerConnection::WriteBuffer( @@ -78,11 +89,80 @@ ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length, message_buffers.push_back(boost::asio::buffer(&type, sizeof(type))); message_buffers.push_back(boost::asio::buffer(&length, sizeof(length))); message_buffers.push_back(boost::asio::buffer(message, length)); - // Write the message and then wait for more messages. - // TODO(swang): Does this need to be an async write? 
   return WriteBuffer(message_buffers);
 }
 
+template <class T>
+void ServerConnection<T>::WriteMessageAsync(
+    int64_t type, int64_t length, const uint8_t *message,
+    const std::function<void(const ray::Status &)> &handler) {
+  auto write_buffer = std::unique_ptr<AsyncWriteBuffer>(new AsyncWriteBuffer());
+  write_buffer->write_version = RayConfig::instance().ray_protocol_version();
+  write_buffer->write_type = type;
+  write_buffer->write_length = length;
+  write_buffer->write_message.resize(length);
+  write_buffer->write_message.assign(message, message + length);
+  write_buffer->handler = handler;
+
+  auto size = async_write_queue_.size();
+  auto size_is_power_of_two = (size & (size - 1)) == 0;
+  if (size > 100 && size_is_power_of_two) {
+    RAY_LOG(WARNING) << "ServerConnection has " << size << " buffered async writes";
+  }
+
+  async_write_queue_.push_back(std::move(write_buffer));
+
+  if (!async_write_in_flight_) {
+    DoAsyncWrites();
+  }
+}
+
+template <class T>
+void ServerConnection<T>::DoAsyncWrites() {
+  // Make sure we were not writing to the socket.
+  RAY_CHECK(!async_write_in_flight_);
+  async_write_in_flight_ = true;
+
+  // Do an async write of everything currently in the queue to the socket.
+  std::vector<boost::asio::const_buffer> message_buffers;
+  int num_messages = 0;
+  for (const auto &write_buffer : async_write_queue_) {
+    message_buffers.push_back(boost::asio::buffer(&write_buffer->write_version,
+                                                  sizeof(write_buffer->write_version)));
+    message_buffers.push_back(
+        boost::asio::buffer(&write_buffer->write_type, sizeof(write_buffer->write_type)));
+    message_buffers.push_back(boost::asio::buffer(&write_buffer->write_length,
+                                                  sizeof(write_buffer->write_length)));
+    message_buffers.push_back(boost::asio::buffer(write_buffer->write_message));
+    num_messages++;
+    if (num_messages >= async_write_max_messages_) {
+      break;
+    }
+  }
+  auto this_ptr = this->shared_from_this();
+  boost::asio::async_write(
+      ServerConnection<T>::socket_, message_buffers,
+      [this, this_ptr, num_messages](const boost::system::error_code &error,
+                                     size_t bytes_transferred) {
+        ray::Status status = ray::Status::OK();
+        if (error.value() != boost::system::errc::errc_t::success) {
+          status = boost_to_ray_status(error);
+        }
+        // Call the handlers for the written messages.
+        for (int i = 0; i < num_messages; i++) {
+          auto write_buffer = std::move(async_write_queue_.front());
+          write_buffer->handler(status);
+          async_write_queue_.pop_front();
+        }
+        // We finished writing, so mark that we're no longer doing an async write.
+        async_write_in_flight_ = false;
+        // If there is more to write, try to write the rest.
+        if (!async_write_queue_.empty()) {
+          DoAsyncWrites();
+        }
+      });
+}
+
 template <class T>
 std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
     ClientHandler<T> &client_handler, MessageHandler<T> &message_handler,
@@ -122,8 +202,8 @@ void ClientConnection<T>::ProcessMessages() {
   header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_)));
   boost::asio::async_read(
       ServerConnection<T>::socket_, header,
-      boost::bind(&ClientConnection<T>::ProcessMessageHeader, this->shared_from_this(),
-                  boost::asio::placeholders::error));
+      boost::bind(&ClientConnection<T>::ProcessMessageHeader,
+                  shared_ClientConnection_from_this(), boost::asio::placeholders::error));
 }
 
 template <class T>
@@ -143,8 +223,8 @@ void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &
   // Wait for the message to be read.
boost::asio::async_read( ServerConnection::socket_, boost::asio::buffer(read_message_), - boost::bind(&ClientConnection::ProcessMessage, this->shared_from_this(), - boost::asio::placeholders::error)); + boost::bind(&ClientConnection::ProcessMessage, + shared_ClientConnection_from_this(), boost::asio::placeholders::error)); } template @@ -154,7 +234,7 @@ void ClientConnection::ProcessMessage(const boost::system::error_code &error) } uint64_t start_ms = current_time_ms(); - message_handler_(this->shared_from_this(), read_type_, read_message_.data()); + message_handler_(shared_ClientConnection_from_this(), read_type_, read_message_.data()); uint64_t interval = current_time_ms() - start_ms; if (interval > RayConfig::instance().handler_warning_timeout_ms()) { RAY_LOG(WARNING) << "[" << debug_label_ << "]ProcessMessage with type " << read_type_ diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 20b232c333f0..83c9849d9a8f 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -1,6 +1,7 @@ #ifndef RAY_COMMON_CLIENT_CONNECTION_H #define RAY_COMMON_CLIENT_CONNECTION_H +#include #include #include @@ -26,10 +27,14 @@ ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket, /// A generic type representing a client connection to a server. This typename /// can be used to write messages synchronously to the server. template -class ServerConnection { +class ServerConnection : public std::enable_shared_from_this> { public: - /// Create a connection to the server. - ServerConnection(boost::asio::basic_stream_socket &&socket); + /// Allocate a new server connection. + /// + /// \param socket A reference to the server socket. + /// \return std::shared_ptr. + static std::shared_ptr> Create( + boost::asio::basic_stream_socket &&socket); /// Write a message to the client. /// @@ -39,6 +44,15 @@ class ServerConnection { /// \return Status. ray::Status WriteMessage(int64_t type, int64_t length, const uint8_t *message); + /// Write a message to the client asynchronously. + /// + /// \param type The message type (e.g., a flatbuffer enum). + /// \param length The size in bytes of the message. + /// \param message A pointer to the message buffer. + /// \param handler A callback to run on write completion. + void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message, + const std::function &handler); + /// Write a buffer to this connection. /// /// \param buffer The buffer. @@ -52,9 +66,42 @@ class ServerConnection { void ReadBuffer(const std::vector &buffer, boost::system::error_code &ec); + /// Shuts down socket for this connection. + void Close() { + boost::system::error_code ec; + socket_.close(ec); + } + protected: + /// A private constructor for a server connection. + ServerConnection(boost::asio::basic_stream_socket &&socket); + + /// A message that is queued for writing asynchronously. + struct AsyncWriteBuffer { + int64_t write_version; + int64_t write_type; + uint64_t write_length; + std::vector write_message; + std::function handler; + }; + /// The socket connection to the server. boost::asio::basic_stream_socket socket_; + + /// Max number of messages to write out at once. + const int async_write_max_messages_; + + /// List of pending messages to write. + std::list> async_write_queue_; + + /// Whether we are in the middle of an async write. + bool async_write_in_flight_; + + private: + /// Asynchronously flushes the write queue. While async writes are running, the flag + /// async_write_in_flight_ will be set. 
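
Note that `AsyncWriteBuffer` owns a copy of the payload (`write_message.assign(...)` earlier in this patch): `WriteMessageAsync` returns before the socket write happens, so the caller's buffer (for example a `flatbuffers::FlatBufferBuilder` that goes out of scope) may be gone by completion time. A small sketch of that idiom; `QueuedMessage` and `Enqueue` are hypothetical names:

```cpp
#include <cstdint>
#include <vector>

// Why the queue entry copies the bytes: the asynchronous write may run long
// after the caller's buffer has been destroyed, so the entry must own them.
struct QueuedMessage {
  int64_t type;
  std::vector<uint8_t> payload;  // owned copy, not a view of caller memory
};

QueuedMessage Enqueue(int64_t type, int64_t length, const uint8_t *message) {
  QueuedMessage q;
  q.type = type;
  q.payload.assign(message, message + length);  // same idiom as the patch
  return q;
}

int main() {
  const uint8_t bytes[3] = {1, 2, 3};
  QueuedMessage q = Enqueue(7, 3, bytes);
  return q.payload.size() == 3 ? 0 : 1;
}
```
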
This should only be called when no async writes + /// are currently in flight. + void DoAsyncWrites(); }; template @@ -72,9 +119,10 @@ using MessageHandler = /// writing messages to the client, like in ServerConnection, this typename can /// also be used to process messages asynchronously from client. template -class ClientConnection : public ServerConnection, - public std::enable_shared_from_this> { +class ClientConnection : public ServerConnection { public: + using std::enable_shared_from_this>::shared_from_this; + /// Allocate a new node client connection. /// /// \param new_client_handler A reference to the client handler. @@ -85,6 +133,10 @@ class ClientConnection : public ServerConnection, ClientHandler &new_client_handler, MessageHandler &message_handler, boost::asio::basic_stream_socket &&socket, const std::string &debug_label); + std::shared_ptr> shared_ClientConnection_from_this() { + return std::static_pointer_cast>(shared_from_this()); + } + /// \return The ClientID of the remote client. const ClientID &GetClientID(); diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index e6674fbf15d4..84be1106602b 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -266,35 +266,31 @@ void ObjectManager::PullEstablishConnection(const ObjectID &object_id, } connection_pool_.RegisterSender(ConnectionPool::ConnectionType::MESSAGE, client_id, async_conn); - Status pull_send_status = PullSendRequest(object_id, async_conn); - if (!pull_send_status.ok()) { - CheckIOError(pull_send_status, "Pull"); - } + PullSendRequest(object_id, async_conn); }, []() { RAY_LOG(ERROR) << "Failed to establish connection with remote object manager."; }); } else { - status = PullSendRequest(object_id, conn); - if (!status.ok()) { - CheckIOError(status, "Pull"); - } + PullSendRequest(object_id, conn); } } -ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id, - std::shared_ptr &conn) { +void ObjectManager::PullSendRequest(const ObjectID &object_id, + std::shared_ptr &conn) { flatbuffers::FlatBufferBuilder fbb; auto message = object_manager_protocol::CreatePullRequestMessage( fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary())); fbb.Finish(message); - Status status = conn->WriteMessage( + conn->WriteMessageAsync( static_cast(object_manager_protocol::MessageType::PullRequest), - fbb.GetSize(), fbb.GetBufferPointer()); - if (status.ok()) { - connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn); - } - return status; + fbb.GetSize(), fbb.GetBufferPointer(), [this, conn](ray::Status status) mutable { + if (status.ok()) { + connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn); + } else { + CheckIOError(status, "Pull"); + } + }); } void ObjectManager::HandlePushTaskTimeout(const ObjectID &object_id, diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 11b5d7a6cd8a..ef5b98a03d06 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -245,10 +245,10 @@ class ObjectManager : public ObjectManagerInterface { /// Executes on main_service_ thread. void PullEstablishConnection(const ObjectID &object_id, const ClientID &client_id); - /// Synchronously send a pull request via remote object manager connection. + /// Asynchronously send a pull request via remote object manager connection. /// Executes on main_service_ thread. 
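
With `enable_shared_from_this` now on the base class, `shared_from_this()` yields a `shared_ptr` to `ServerConnection`, so `ClientConnection` adds `shared_ClientConnection_from_this()` to recover the derived type via `static_pointer_cast`. A minimal sketch of that downcast pattern; `Base` and `Derived` are hypothetical stand-ins for `ServerConnection` and `ClientConnection`:

```cpp
#include <memory>

// enable_shared_from_this lives on the base; the derived class recovers a
// correctly-typed shared_ptr with a static downcast, which is safe because
// *this is known to be a Derived.
class Base : public std::enable_shared_from_this<Base> {
 public:
  virtual ~Base() = default;
};

class Derived : public Base {
 public:
  std::shared_ptr<Derived> shared_derived_from_this() {
    // Mirrors shared_ClientConnection_from_this() in the patch.
    return std::static_pointer_cast<Derived>(shared_from_this());
  }
};

int main() {
  auto d = std::make_shared<Derived>();
  std::shared_ptr<Derived> self = d->shared_derived_from_this();
  return self == d ? 0 : 1;
}
```
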
- ray::Status PullSendRequest(const ObjectID &object_id, - std::shared_ptr &conn); + void PullSendRequest(const ObjectID &object_id, + std::shared_ptr &conn); std::shared_ptr CreateSenderConnection( ConnectionPool::ConnectionType type, RemoteConnectionInfo info); diff --git a/src/ray/object_manager/object_manager_client_connection.cc b/src/ray/object_manager/object_manager_client_connection.cc index c612e1703cc3..dadfd72cef10 100644 --- a/src/ray/object_manager/object_manager_client_connection.cc +++ b/src/ray/object_manager/object_manager_client_connection.cc @@ -11,7 +11,7 @@ std::shared_ptr SenderConnection::Create( Status status = TcpConnect(socket, ip, port); if (status.ok()) { std::shared_ptr conn = - std::make_shared(std::move(socket)); + TcpServerConnection::Create(std::move(socket)); return std::make_shared(std::move(conn), client_id); } else { return nullptr; diff --git a/src/ray/object_manager/object_manager_client_connection.h b/src/ray/object_manager/object_manager_client_connection.h index 1c8661b0ddc0..b3a03102a728 100644 --- a/src/ray/object_manager/object_manager_client_connection.h +++ b/src/ray/object_manager/object_manager_client_connection.h @@ -16,6 +16,7 @@ namespace ray { +// TODO(ekl) this class can be replaced with a plain ClientConnection class SenderConnection : public boost::enable_shared_from_this { public: /// Create a connection for sending data to other object managers. @@ -44,6 +45,17 @@ class SenderConnection : public boost::enable_shared_from_this return conn_->WriteMessage(type, length, message); } + /// Write a message to the client asynchronously. + /// + /// \param type The message type (e.g., a flatbuffer enum). + /// \param length The size in bytes of the message. + /// \param message A pointer to the message buffer. + /// \param handler A callback to run on write completion. + void WriteMessageAsync(int64_t type, int64_t length, const uint8_t *message, + const std::function &handler) { + conn_->WriteMessageAsync(type, length, message, handler); + } + /// Write a buffer to this connection. /// /// \param buffer The buffer. 
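
In `PullSendRequest` above, the completion lambda captures the `shared_ptr` `conn` by value, so the connection stays alive until the asynchronous write finishes even if every other reference is dropped first. A stand-alone illustration of that lifetime guarantee, assuming only the standard library; `Conn` and `pending` are hypothetical:

```cpp
#include <functional>
#include <iostream>
#include <memory>

struct Conn {
  ~Conn() { std::cout << "connection destroyed\n"; }
};

int main() {
  std::function<void()> pending;
  {
    auto conn = std::make_shared<Conn>();
    pending = [conn]() {  // value capture: the callback co-owns the connection
      std::cout << "write completed, conn still alive\n";
    };
  }  // the caller's reference is gone, but the callback still holds one
  pending();          // "write completed, conn still alive"
  pending = nullptr;  // releasing the callback finally destroys the connection
  return 0;
}
```
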
diff --git a/src/ray/raylet/CMakeLists.txt b/src/ray/raylet/CMakeLists.txt index 79233965af0c..5b580e4a2331 100644 --- a/src/ray/raylet/CMakeLists.txt +++ b/src/ray/raylet/CMakeLists.txt @@ -32,6 +32,7 @@ ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ray_static ${PLASM ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) +ADD_RAY_TEST(client_connection_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) ADD_RAY_TEST(task_dependency_manager_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) diff --git a/src/ray/raylet/client_connection_test.cc b/src/ray/raylet/client_connection_test.cc new file mode 100644 index 000000000000..a68a6535c1e8 --- /dev/null +++ b/src/ray/raylet/client_connection_test.cc @@ -0,0 +1,155 @@ +#include +#include + +#include +#include +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "ray/common/client_connection.h" + +namespace ray { +namespace raylet { + +class ClientConnectionTest : public ::testing::Test { + public: + ClientConnectionTest() : io_service_(), in_(io_service_), out_(io_service_) { + boost::asio::local::connect_pair(in_, out_); + } + + protected: + boost::asio::io_service io_service_; + boost::asio::local::stream_protocol::socket in_; + boost::asio::local::stream_protocol::socket out_; +}; + +TEST_F(ClientConnectionTest, SimpleSyncWrite) { + const uint8_t arr[5] = {1, 2, 3, 4, 5}; + int num_messages = 0; + + ClientHandler client_handler = + [](LocalClientConnection &client) {}; + + MessageHandler message_handler = + [&arr, &num_messages](std::shared_ptr client, + int64_t message_type, const uint8_t *message) { + ASSERT_TRUE(!std::memcmp(arr, message, 5)); + num_messages += 1; + }; + + auto conn1 = LocalClientConnection::Create(client_handler, message_handler, + std::move(in_), "conn1"); + + auto conn2 = LocalClientConnection::Create(client_handler, message_handler, + std::move(out_), "conn2"); + + RAY_CHECK_OK(conn1->WriteMessage(0, 5, arr)); + RAY_CHECK_OK(conn2->WriteMessage(0, 5, arr)); + conn1->ProcessMessages(); + conn2->ProcessMessages(); + io_service_.run(); + ASSERT_EQ(num_messages, 2); +} + +TEST_F(ClientConnectionTest, SimpleAsyncWrite) { + const uint8_t msg1[5] = {1, 2, 3, 4, 5}; + const uint8_t msg2[5] = {4, 4, 4, 4, 4}; + const uint8_t msg3[5] = {8, 8, 8, 8, 8}; + int num_messages = 0; + + ClientHandler client_handler = + [](LocalClientConnection &client) {}; + + MessageHandler noop_handler = []( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; + + std::shared_ptr reader = NULL; + + MessageHandler message_handler = + [&msg1, &msg2, &msg3, &num_messages, &reader]( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + if (num_messages == 0) { + ASSERT_TRUE(!std::memcmp(msg1, message, 5)); + } else if (num_messages == 1) { + ASSERT_TRUE(!std::memcmp(msg2, message, 5)); + } else { + ASSERT_TRUE(!std::memcmp(msg3, message, 5)); + } + num_messages += 1; + if (num_messages < 3) { + reader->ProcessMessages(); + } + }; + + auto writer = LocalClientConnection::Create(client_handler, noop_handler, + std::move(in_), "writer"); + + reader = 
LocalClientConnection::Create(client_handler, message_handler, std::move(out_), + "reader"); + + std::function callback = [](const ray::Status &status) { + RAY_CHECK_OK(status); + }; + + writer->WriteMessageAsync(0, 5, msg1, callback); + writer->WriteMessageAsync(0, 5, msg2, callback); + writer->WriteMessageAsync(0, 5, msg3, callback); + reader->ProcessMessages(); + io_service_.run(); + ASSERT_EQ(num_messages, 3); +} + +TEST_F(ClientConnectionTest, SimpleAsyncError) { + const uint8_t msg1[5] = {1, 2, 3, 4, 5}; + + ClientHandler client_handler = + [](LocalClientConnection &client) {}; + + MessageHandler noop_handler = []( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; + + auto writer = LocalClientConnection::Create(client_handler, noop_handler, + std::move(in_), "writer"); + + std::function callback = [](const ray::Status &status) { + ASSERT_TRUE(!status.ok()); + }; + + writer->Close(); + writer->WriteMessageAsync(0, 5, msg1, callback); + io_service_.run(); +} + +TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) { + const uint8_t msg1[5] = {1, 2, 3, 4, 5}; + + ClientHandler client_handler = + [](LocalClientConnection &client) {}; + + MessageHandler noop_handler = []( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; + + auto writer = LocalClientConnection::Create(client_handler, noop_handler, + std::move(in_), "writer"); + + std::function callback = + [writer](const ray::Status &status) { + static_cast(writer); + ASSERT_TRUE(status.ok()); + }; + writer->WriteMessageAsync(0, 5, msg1, callback); + io_service_.run(); +} + +} // namespace raylet + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index c4e75dbb8c6a..be1974468439 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -359,7 +359,7 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { } // The client is connected. - auto server_conn = TcpServerConnection(std::move(socket)); + auto server_conn = TcpServerConnection::Create(std::move(socket)); remote_server_connections_.emplace(client_id, std::move(server_conn)); ResourceSet resources_total(client_data.resources_total_label, @@ -1304,56 +1304,59 @@ void NodeManager::AssignTask(Task &task) { auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb), fbb.CreateVector(resource_id_set_flatbuf)); fbb.Finish(message); - auto status = worker->Connection()->WriteMessage( + worker->Connection()->WriteMessageAsync( static_cast(protocol::MessageType::ExecuteTask), fbb.GetSize(), - fbb.GetBufferPointer()); - if (status.ok()) { - // We successfully assigned the task to the worker. - worker->AssignTaskId(spec.TaskId()); - worker->AssignDriverId(spec.DriverId()); - // If the task was an actor task, then record this execution to guarantee - // consistency in the case of reconstruction. - if (spec.IsActorTask()) { - auto actor_entry = actor_registry_.find(spec.ActorId()); - RAY_CHECK(actor_entry != actor_registry_.end()); - auto execution_dependency = actor_entry->second.GetExecutionDependency(); - // The execution dependency is initialized to the actor creation task's - // return value, and is subsequently updated to the assigned tasks' - // return values, so it should never be nil. 
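
The `CallbackWithSharedRefDoesNotLeakConnection` test above guards against a subtle cycle: a completion handler that captures a `shared_ptr` to the connection sits in that connection's own write queue, so the connection transiently owns a reference to itself. The cycle must be broken by dropping the handler once the write completes. A minimal reproduction of the shape of the problem; `Writer` and `Drain()` are hypothetical names:

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

// The "connection" holds handlers; a handler that captures a shared_ptr back
// to the connection creates a temporary self-reference.
struct Writer {
  std::vector<std::function<void()>> queue;
  ~Writer() { std::cout << "writer destroyed\n"; }
  void Drain() {
    for (auto &handler : queue) {
      handler();
    }
    queue.clear();  // drops the captured self-references, breaking the cycle
  }
};

int main() {
  auto writer = std::make_shared<Writer>();
  // Cycle: writer -> queue -> handler -> writer.
  writer->queue.push_back([writer]() { std::cout << "write completed\n"; });
  writer->Drain();  // like the queue draining after the async write finishes
  writer.reset();   // prints "writer destroyed": no leak
  return 0;
}
```
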
- RAY_CHECK(!execution_dependency.is_nil()); - // Update the task's execution dependencies to reflect the actual - // execution order, to support deterministic reconstruction. - // NOTE(swang): The update of an actor task's execution dependencies is - // performed asynchronously. This means that if this node manager dies, - // we may lose updates that are in flight to the task table. We only - // guarantee deterministic reconstruction ordering for tasks whose - // updates are reflected in the task table. - task.SetExecutionDependencies({execution_dependency}); - // Extend the frontier to include the executing task. - actor_entry->second.ExtendFrontier(spec.ActorHandleId(), spec.ActorDummyObject()); - } - // We started running the task, so the task is ready to write to GCS. - if (!lineage_cache_.AddReadyTask(task)) { - RAY_LOG(WARNING) - << "Task " << spec.TaskId() - << " already in lineage cache. This is most likely due to reconstruction."; - } - // Mark the task as running. - // (See design_docs/task_states.rst for the state transition diagram.) - local_queues_.QueueRunningTasks(std::vector({task})); - // Notify the task dependency manager that we no longer need this task's - // object dependencies. - task_dependency_manager_.UnsubscribeDependencies(spec.TaskId()); - } else { - RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; - // We failed to send the task to the worker, so disconnect the worker. - ProcessDisconnectClientMessage(worker->Connection()); - // Queue this task for future assignment. The task will be assigned to a - // worker once one becomes available. - // (See design_docs/task_states.rst for the state transition diagram.) - local_queues_.QueueReadyTasks(std::vector({task})); - DispatchTasks(); - } + fbb.GetBufferPointer(), [this, worker, task](ray::Status status) mutable { + if (status.ok()) { + auto spec = task.GetTaskSpecification(); + // We successfully assigned the task to the worker. + worker->AssignTaskId(spec.TaskId()); + worker->AssignDriverId(spec.DriverId()); + // If the task was an actor task, then record this execution to guarantee + // consistency in the case of reconstruction. + if (spec.IsActorTask()) { + auto actor_entry = actor_registry_.find(spec.ActorId()); + RAY_CHECK(actor_entry != actor_registry_.end()); + auto execution_dependency = actor_entry->second.GetExecutionDependency(); + // The execution dependency is initialized to the actor creation task's + // return value, and is subsequently updated to the assigned tasks' + // return values, so it should never be nil. + RAY_CHECK(!execution_dependency.is_nil()); + // Update the task's execution dependencies to reflect the actual + // execution order, to support deterministic reconstruction. + // NOTE(swang): The update of an actor task's execution dependencies is + // performed asynchronously. This means that if this node manager dies, + // we may lose updates that are in flight to the task table. We only + // guarantee deterministic reconstruction ordering for tasks whose + // updates are reflected in the task table. + task.SetExecutionDependencies({execution_dependency}); + // Extend the frontier to include the executing task. + actor_entry->second.ExtendFrontier(spec.ActorHandleId(), + spec.ActorDummyObject()); + } + // We started running the task, so the task is ready to write to GCS. + if (!lineage_cache_.AddReadyTask(task)) { + RAY_LOG(WARNING) << "Task " << spec.TaskId() << " already in lineage cache. 
" + "This is most likely due to " + "reconstruction."; + } + // Mark the task as running. + // (See design_docs/task_states.rst for the state transition diagram.) + local_queues_.QueueRunningTasks(std::vector({task})); + // Notify the task dependency manager that we no longer need this task's + // object dependencies. + task_dependency_manager_.UnsubscribeDependencies(spec.TaskId()); + } else { + RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; + // We failed to send the task to the worker, so disconnect the worker. + ProcessDisconnectClientMessage(worker->Connection()); + // Queue this task for future assignment. The task will be assigned to a + // worker once one becomes available. + // (See design_docs/task_states.rst for the state transition diagram.) + local_queues_.QueueReadyTasks(std::vector({task})); + DispatchTasks(); + } + }); } void NodeManager::FinishAssignedTask(Worker &worker) { @@ -1522,10 +1525,10 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, const ClientID &node_manager_id) { /// TODO(rkn): Should we check that the node manager is remote and not local? /// TODO(rkn): Should we check if the remote node manager is known to be dead? - const TaskID task_id = task.GetTaskSpecification().TaskId(); - // Attempt to forward the task. - if (!ForwardTask(task, node_manager_id).ok()) { + ForwardTask(task, node_manager_id, [this, task, node_manager_id](ray::Status error) { + const TaskID task_id = task.GetTaskSpecification().TaskId(); + RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager " << node_manager_id; // Mark the failed task as pending to let other raylets know that we still @@ -1564,10 +1567,11 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, ScheduleTasks(cluster_resource_map_); DispatchTasks(); } - } + }); } -ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) { +void NodeManager::ForwardTask(const Task &task, const ClientID &node_id, + const std::function &on_error) { const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); @@ -1593,49 +1597,53 @@ ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) if (it == remote_server_connections_.end()) { // TODO(atumanov): caller must handle failure to ensure tasks are not lost. RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id; - return ray::Status::IOError("NodeManager connection not found"); + on_error(ray::Status::IOError("NodeManager connection not found")); + return; } auto &server_conn = it->second; - auto status = server_conn.WriteMessage( + server_conn->WriteMessageAsync( static_cast(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(), - fbb.GetBufferPointer()); - if (status.ok()) { - // If we were able to forward the task, remove the forwarded task from the - // lineage cache since the receiving node is now responsible for writing - // the task to the GCS. - if (!lineage_cache_.RemoveWaitingTask(task_id)) { - RAY_LOG(WARNING) << "Task " << task_id << " already removed from the lineage " - "cache. This is most likely due to " - "reconstruction."; - } - // Mark as forwarded so that the task and its lineage is not re-forwarded - // in the future to the receiving node. - lineage_cache_.MarkTaskAsForwarded(task_id, node_id); - - // Notify the task dependency manager that we are no longer responsible - // for executing this task. - task_dependency_manager_.TaskCanceled(task_id); - // Preemptively push any local arguments to the receiving node. 
For now, we
-  // only do this with actor tasks, since actor tasks must be executed by a
-  // specific process and therefore have affinity to the receiving node.
-  if (spec.IsActorTask()) {
-    // Iterate through the object's arguments. NOTE(swang): We do not include
-    // the execution dependencies here since those cannot be transferred
-    // between nodes.
-    for (int i = 0; i < spec.NumArgs(); ++i) {
-      int count = spec.ArgIdCount(i);
-      for (int j = 0; j < count; j++) {
-        ObjectID argument_id = spec.ArgId(i, j);
-        // If the argument is local, then push it to the receiving node.
-        if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
-          object_manager_.Push(argument_id, node_id);
+      fbb.GetBufferPointer(),
+      [this, on_error, task_id, node_id, spec](ray::Status status) {
+        if (status.ok()) {
+          // If we were able to forward the task, remove the forwarded task from the
+          // lineage cache since the receiving node is now responsible for writing
+          // the task to the GCS.
+          if (!lineage_cache_.RemoveWaitingTask(task_id)) {
+            RAY_LOG(WARNING) << "Task " << task_id << " already removed from the lineage "
+                                "cache. This is most likely due to "
+                                "reconstruction.";
+          }
+          // Mark as forwarded so that the task and its lineage is not re-forwarded
+          // in the future to the receiving node.
+          lineage_cache_.MarkTaskAsForwarded(task_id, node_id);
+
+          // Notify the task dependency manager that we are no longer responsible
+          // for executing this task.
+          task_dependency_manager_.TaskCanceled(task_id);
+          // Preemptively push any local arguments to the receiving node. For now, we
+          // only do this with actor tasks, since actor tasks must be executed by a
+          // specific process and therefore have affinity to the receiving node.
+          if (spec.IsActorTask()) {
+            // Iterate through the object's arguments. NOTE(swang): We do not include
+            // the execution dependencies here since those cannot be transferred
+            // between nodes.
+            for (int i = 0; i < spec.NumArgs(); ++i) {
+              int count = spec.ArgIdCount(i);
+              for (int j = 0; j < count; j++) {
+                ObjectID argument_id = spec.ArgId(i, j);
+                // If the argument is local, then push it to the receiving node.
+                if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
+                  object_manager_.Push(argument_id, node_id);
+                }
+              }
+            }
           }
+        } else {
+          on_error(status);
         }
-      }
-    }
-  }
-  return status;
+      });
 }
 
 }  // namespace raylet
diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h
index e3d2ca1416ce..2e5d7605f12a 100644
--- a/src/ray/raylet/node_manager.h
+++ b/src/ray/raylet/node_manager.h
@@ -174,10 +174,10 @@ class NodeManager {
   ///
   /// \param task The task to forward.
   /// \param node_id The ID of the node to forward the task to.
-  /// \return A status indicating whether the forward succeeded or not. Note
-  /// that a status of OK is not a reliable indicator that the forward succeeded
-  /// or even that the remote node is still alive.
-  ray::Status ForwardTask(const Task &task, const ClientID &node_id);
+  /// \param on_error Callback to run on non-ok status.
+  void ForwardTask(const Task &task, const ClientID &node_id,
+                   const std::function<void(const ray::Status &)> &on_error);
+
   /// Dispatch locally scheduled tasks. This attempts the transition from "scheduled" to
   /// "running" task state.
   void DispatchTasks();
@@ -352,7 +352,8 @@ class NodeManager {
   /// The lineage cache for the GCS object and task tables.
   LineageCache lineage_cache_;
   std::vector<ClientID> remote_clients_;
-  std::unordered_map<ClientID, TcpServerConnection> remote_server_connections_;
+  std::unordered_map<ClientID, std::shared_ptr<TcpServerConnection>>
+      remote_server_connections_;
   /// A mapping from actor ID to registration information about that actor
   /// (including which node manager owns it).
   std::unordered_map<ActorID, ActorRegistration> actor_registry_;

From b410ee0d291fe2fe838403030163dde3b8fe42bb Mon Sep 17 00:00:00 2001
From: Wang Qing
Date: Fri, 19 Oct 2018 21:22:32 +0800
Subject: [PATCH 040/215] [Java] Support dynamically defining resources when submitting task. (#3070)

## What do these changes do?

Before this PR, resources for a task or actor could only be specified statically, through annotations:
```java
@RayRemote(Resources={ResourceItem("CPU", 10)})
public static void f1() {
  // do sth
}

@RayRemote(Resources={ResourceItem("CPU", 10)})
class Demo {
  // sth
}
```
Unfortunately, there was no way to create another actor or task with different resource requirements.

After this PR, resources can be specified dynamically when submitting a task or creating an actor:
```java
ActorCreationOptions option = new ActorCreationOptions();
option.resources.put("CPU", 4.0);
RayActor<Echo> echo1 = Ray.createActor(Echo::new, option);

option.resources.put("Res-A", 4.0);
RayActor<Echo> echo2 = Ray.createActor(Echo::new, option);

// If we don't specify any resources, they will be `{"cpu":0.0}` by default.
Ray.call(Echo::echo, echo2, 100);
```

## Related issue number
N/A
---
 .../src/main/java/org/ray/api/RayCall.java    | 1528 ++++++++++++++---
 .../org/ray/api/annotation/RayRemote.java     |    6 -
 .../org/ray/api/annotation/ResourceItem.java  |   28 -
 .../ray/api/options/ActorCreationOptions.java |   18 +
 .../org/ray/api/options/BaseTaskOptions.java  |   20 +
 .../java/org/ray/api/options/CallOptions.java |   18 +
 .../java/org/ray/api/runtime/RayRuntime.java  |   10 +-
 .../org/ray/runtime/AbstractRayRuntime.java   |   34 +-
 .../org/ray/runtime/util/ResourceUtil.java    |   50 +-
 .../util/generator/RayCallGenerator.java      |   47 +-
 .../ray/api/test/ResourcesManagementTest.java |   73 +-
 11 files changed, 1433 insertions(+), 399 deletions(-)
 delete mode 100644 java/api/src/main/java/org/ray/api/annotation/ResourceItem.java
 create mode 100644 java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java
 create mode 100644 java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java
 create mode 100644 java/api/src/main/java/org/ray/api/options/CallOptions.java

diff --git a/java/api/src/main/java/org/ray/api/RayCall.java b/java/api/src/main/java/org/ray/api/RayCall.java
index ef40a238c0e2..967830199402 100644
--- a/java/api/src/main/java/org/ray/api/RayCall.java
+++ b/java/api/src/main/java/org/ray/api/RayCall.java
@@ -2,6 +2,7 @@
 package org.ray.api;
 
+import org.ray.api.function.RayFunc;
 import org.ray.api.function.RayFunc0;
 import org.ray.api.function.RayFunc1;
 import org.ray.api.function.RayFunc2;
@@ -9,6 +10,9 @@
 import org.ray.api.function.RayFunc4;
 import org.ray.api.function.RayFunc5;
 import org.ray.api.function.RayFunc6;
+import org.ray.api.options.ActorCreationOptions;
+import org.ray.api.options.BaseTaskOptions;
+import org.ray.api.options.CallOptions;
 
 /**
  * This class provides type-safe interfaces for `Ray.call` and `Ray.createActor`.
@@ -20,511 +24,1019 @@ class RayCall { // ======================================= public static RayObject call(RayFunc0 f) { Object[] args = new Object[]{}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc0 f, CallOptions options) { + Object[] args = new Object[]{}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc1 f, T0 t0) { Object[] args = new Object[]{t0}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc1 f, RayObject t0) { Object[] args = new Object[]{t0}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc1 f, T0 t0, CallOptions options) { + Object[] args = new Object[]{t0}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc1 f, RayObject t0, CallOptions options) { + Object[] args = new Object[]{t0}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc2 f, T0 t0, T1 t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc2 f, T0 t0, RayObject t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc2 f, RayObject t0, T1 t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc2 f, RayObject t0, RayObject t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc2 f, T0 t0, T1 t1, CallOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc2 f, T0 t0, RayObject t1, CallOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc2 f, RayObject t0, T1 t1, CallOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc2 f, RayObject t0, RayObject t1, CallOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc3 f, T0 t0, T1 t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, T0 t0, T1 t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, T0 t0, RayObject t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, T0 t0, RayObject t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, RayObject t0, T1 t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static 
RayObject call(RayFunc3 f, RayObject t0, T1 t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, RayObject t0, RayObject t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc3 f, RayObject t0, RayObject t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc3 f, T0 t0, T1 t1, T2 t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, T0 t0, T1 t1, RayObject t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, T0 t0, RayObject t1, T2 t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, T0 t0, RayObject t1, RayObject t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, RayObject t0, T1 t1, T2 t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, RayObject t0, T1 t1, RayObject t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, RayObject t0, RayObject t1, T2 t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc3 f, RayObject t0, RayObject t1, RayObject t2, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc4 f, T0 t0, T1 t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, T1 t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, T1 t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, T1 t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, 
t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc4 f, T0 t0, T1 t1, T2 t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, T1 t1, T2 t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, T1 t1, RayObject t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, T1 t1, RayObject t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, T2 t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, T2 t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, T0 t0, 
RayObject t1, RayObject t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, T2 t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, T2 t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, 
t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, 
null); } public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, CallOptions options) { 
+ Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, CallOptions options) { + 
Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().call(f, args, options); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, 
args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - 
return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static 
RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5) 
{ Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); } public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, 
RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().call(f, args); + return Ray.internal().call(f, args, null); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return 
Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return 
Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return 
Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, 
CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); + } + public static RayObject call(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, CallOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().call(f, args, options); } // =========================================== // Methods for remote actor method invocation. @@ -786,510 +1298,1018 @@ public static RayObject call(RayFunc6 RayActor createActor(RayFunc0 f) { Object[] args = new Object[]{}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc0 f, ActorCreationOptions options) { + Object[] args = new Object[]{}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc1 f, T0 t0) { Object[] args = new Object[]{t0}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc1 f, RayObject t0) { Object[] args = new Object[]{t0}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc1 f, T0 t0, ActorCreationOptions options) { + Object[] args = new Object[]{t0}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc1 f, RayObject t0, ActorCreationOptions options) { + Object[] args = new Object[]{t0}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc2 f, T0 t0, T1 t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc2 f, T0 t0, RayObject t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc2 f, RayObject t0, T1 t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc2 f, RayObject t0, RayObject t1) { Object[] args = new Object[]{t0, t1}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc2 f, T0 t0, T1 t1, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc2 f, T0 t0, RayObject t1, ActorCreationOptions options) { + Object[] args = new 
Object[]{t0, t1}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc2 f, RayObject t0, T1 t1, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc2 f, RayObject t0, RayObject t1, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc3 f, T0 t0, T1 t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, T0 t0, T1 t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, T0 t0, RayObject t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, T0 t0, RayObject t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, RayObject t0, T1 t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, RayObject t0, T1 t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, RayObject t0, RayObject t1, T2 t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc3 f, RayObject t0, RayObject t1, RayObject t2) { Object[] args = new Object[]{t0, t1, t2}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc3 f, T0 t0, T1 t1, T2 t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, T0 t0, T1 t1, RayObject t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, T0 t0, RayObject t1, T2 t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, T0 t0, RayObject t1, RayObject t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, RayObject t0, T1 t1, T2 t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, RayObject t0, T1 t1, RayObject t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public 
static RayActor createActor(RayFunc3 f, RayObject t0, RayObject t1, T2 t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc3 f, RayObject t0, RayObject t1, RayObject t2, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, 
RayObject t1, T2 t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, T3 t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3) { Object[] args = new Object[]{t0, t1, t2, t3}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, T2 t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, T2 t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, RayObject t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, T1 t1, RayObject t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, T2 t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, T2 t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, T2 t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, T2 t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, T3 t3, ActorCreationOptions options) { + Object[] args = 
new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc4 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, 
RayObject t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject 
t1, T2 t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4) { Object[] args = new Object[]{t0, t1, t2, t3, t4}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] 
args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 
f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc5 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4}; + return Ray.internal().createActor(f, args, options); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject 
t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, 
args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, 
RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return 
Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, 
args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); } public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5) { Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; - return Ray.internal().createActor(f, args); + return Ray.internal().createActor(f, args, null); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, T2 t2, RayObject t3, 
RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 
t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, T0 t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, 
options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, T1 t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + 
Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, T2 t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, T3 t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, T4 t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, 
options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, T5 t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); + } + public static RayActor createActor(RayFunc6 f, RayObject t0, RayObject t1, RayObject t2, RayObject t3, RayObject t4, RayObject t5, ActorCreationOptions options) { + Object[] args = new Object[]{t0, t1, t2, t3, t4, t5}; + return Ray.internal().createActor(f, args, options); } } diff --git a/java/api/src/main/java/org/ray/api/annotation/RayRemote.java b/java/api/src/main/java/org/ray/api/annotation/RayRemote.java index a47e0768f0fb..197ee663f58a 100644 --- a/java/api/src/main/java/org/ray/api/annotation/RayRemote.java +++ b/java/api/src/main/java/org/ray/api/annotation/RayRemote.java @@ -15,10 +15,4 @@ @Target({ElementType.METHOD, ElementType.TYPE}) public @interface RayRemote { - /** - * Defines the quantity of various custom resources to reserve - * for this task or for the lifetime of the actor. - * @return an array of custom resource items. - */ - ResourceItem[] resources() default {}; } diff --git a/java/api/src/main/java/org/ray/api/annotation/ResourceItem.java b/java/api/src/main/java/org/ray/api/annotation/ResourceItem.java deleted file mode 100644 index f4895eba6164..000000000000 --- a/java/api/src/main/java/org/ray/api/annotation/ResourceItem.java +++ /dev/null @@ -1,28 +0,0 @@ -package org.ray.api.annotation; - - -import java.lang.annotation.Documented; -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -/** - * Represents a custom resource, including its name and quantity. - */ -@Documented -@Retention(RetentionPolicy.RUNTIME) -@Target(ElementType.ANNOTATION_TYPE) -public @interface ResourceItem { - - /** - * Name of this resource, must not be null or empty. - */ - String name(); - - /** - * Quantity of this resource. - */ - double value() default 0; - -} diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java new file mode 100644 index 000000000000..20db30944e51 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -0,0 +1,18 @@ +package org.ray.api.options; + +import java.util.Map; + +/** + * The options for creating an actor. + */ +public class ActorCreationOptions extends BaseTaskOptions { + + public ActorCreationOptions() { + super(); + } + + public ActorCreationOptions(Map<String, Double> resources) { + super(resources); + } + +} diff --git a/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java b/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java new file mode 100644 index 000000000000..65494d532a68 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/options/BaseTaskOptions.java @@ -0,0 +1,20 @@ +package org.ray.api.options; + +import java.util.HashMap; +import java.util.Map; + +/** + * The options class for RayCall or ActorCreation.
+ */ +public abstract class BaseTaskOptions { + public Map<String, Double> resources; + + public BaseTaskOptions() { + resources = new HashMap<>(); + } + + public BaseTaskOptions(Map<String, Double> resources) { + this.resources = resources; + } + +} diff --git a/java/api/src/main/java/org/ray/api/options/CallOptions.java b/java/api/src/main/java/org/ray/api/options/CallOptions.java new file mode 100644 index 000000000000..84adfc122e04 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/options/CallOptions.java @@ -0,0 +1,18 @@ +package org.ray.api.options; + +import java.util.Map; + +/** + * The options for RayCall. + */ +public class CallOptions extends BaseTaskOptions { + + public CallOptions() { + super(); + } + + public CallOptions(Map<String, Double> resources) { + super(resources); + } + +} diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index d609d4de593d..7c12c3543c04 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -6,6 +6,9 @@ import org.ray.api.WaitResult; import org.ray.api.function.RayFunc; import org.ray.api.id.UniqueId; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.BaseTaskOptions; +import org.ray.api.options.CallOptions; /** * Base interface of a Ray runtime. @@ -65,9 +68,10 @@ public interface RayRuntime { * * @param func The remote function to run. * @param args The arguments of the remote function. + * @param options The options for this call. * @return The result object. */ - RayObject call(RayFunc func, Object[] args); + RayObject call(RayFunc func, Object[] args, CallOptions options); /** * Invoke a remote function on an actor. @@ -85,7 +89,9 @@ * @param actorFactoryFunc A remote function whose return value is the actor object. * @param args The arguments for the remote function. * @param <T> The type of the actor object. + * @param options The options for creating an actor. * @return A handle to the actor.
*/ - RayActor createActor(RayFunc actorFactoryFunc, Object[] args); + RayActor createActor(RayFunc actorFactoryFunc, Object[] args, + ActorCreationOptions options); } diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 330dbe365f15..022b61a33591 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -12,6 +12,9 @@ import org.ray.api.WaitResult; import org.ray.api.function.RayFunc; import org.ray.api.id.UniqueId; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.BaseTaskOptions; +import org.ray.api.options.CallOptions; import org.ray.api.runtime.RayRuntime; import org.ray.runtime.config.RayConfig; import org.ray.runtime.functionmanager.FunctionManager; @@ -186,8 +189,8 @@ public WaitResult wait(List> waitList, int numReturns, int t } @Override - public RayObject call(RayFunc func, Object[] args) { - TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false); + public RayObject call(RayFunc func, Object[] args, CallOptions options) { + TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false, options); rayletClient.submitTask(spec); return new RayObjectImpl(spec.returnIds[0]); } @@ -198,7 +201,7 @@ public RayObject call(RayFunc func, RayActor actor, Object[] args) { throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName()); } RayActorImpl actorImpl = (RayActorImpl)actor; - TaskSpec spec = createTaskSpec(func, actorImpl, args, false); + TaskSpec spec = createTaskSpec(func, actorImpl, args, false, null); spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor()); actorImpl.setTaskCursor(spec.returnIds[1]); rayletClient.submitTask(spec); @@ -207,8 +210,10 @@ public RayObject call(RayFunc func, RayActor actor, Object[] args) { @Override @SuppressWarnings("unchecked") - public RayActor createActor(RayFunc actorFactoryFunc, Object[] args) { - TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL, args, true); + public RayActor createActor(RayFunc actorFactoryFunc, + Object[] args, ActorCreationOptions options) { + TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL, + args, true, options); RayActorImpl actor = new RayActorImpl(spec.returnIds[0]); actor.increaseTaskCounter(); actor.setTaskCursor(spec.returnIds[0]); @@ -236,11 +241,10 @@ private UniqueId[] genReturnIds(UniqueId taskId, int numReturns) { * @return A TaskSpec object. */ private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, - boolean isActorCreationTask) { + boolean isActorCreationTask, BaseTaskOptions taskOptions) { final TaskSpec current = workerContext.getCurrentTask(); UniqueId taskId = rayletClient.generateTaskId(current.driverId, - current.taskId, - workerContext.nextCallIndex()); + current.taskId, workerContext.nextCallIndex()); int numReturns = actor.getId().isNil() ? 
1 : 2; UniqueId[] returnIds = genReturnIds(taskId, numReturns); @@ -249,6 +253,18 @@ private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, actorCreationId = returnIds[0]; } + Map resources; + if (null == taskOptions) { + resources = new HashMap<>(); + } else { + resources = new HashMap<>(taskOptions.resources); + } + + if (!resources.containsKey(ResourceUtil.CPU_LITERAL) + && !resources.containsKey(ResourceUtil.CPU_LITERAL.toLowerCase())) { + resources.put(ResourceUtil.CPU_LITERAL, 0.0); + } + RayFunction rayFunction = functionManager.getFunction(current.driverId, func); return new TaskSpec( current.driverId, @@ -261,7 +277,7 @@ private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args, actor.increaseTaskCounter(), ArgumentsBuilder.wrap(args), returnIds, - ResourceUtil.getResourcesMapFromArray(rayFunction.getRayRemoteAnnotation()), + resources, rayFunction.getFunctionDescriptor() ); } diff --git a/java/runtime/src/main/java/org/ray/runtime/util/ResourceUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/ResourceUtil.java index 98cc43631242..4863ca5d13c1 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/ResourceUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/ResourceUtil.java @@ -2,59 +2,11 @@ import java.util.HashMap; import java.util.Map; -import org.ray.api.annotation.RayRemote; -import org.ray.api.annotation.ResourceItem; public class ResourceUtil { public static final String CPU_LITERAL = "CPU"; public static final String GPU_LITERAL = "GPU"; - /** - * Convert the array that contains resource items to a map. - * - * @param remoteAnnotation The RayRemote annotation that contains the resource items. - * @return The map whose key represents the resource name - * and the value represents the resource quantity. - */ - public static Map getResourcesMapFromArray(RayRemote remoteAnnotation) { - Map resourceMap = new HashMap<>(); - if (remoteAnnotation != null) { - for (ResourceItem item : remoteAnnotation.resources()) { - if (!item.name().isEmpty()) { - resourceMap.put(item.name(), item.value()); - } - } - } - if (!resourceMap.containsKey(CPU_LITERAL)) { - resourceMap.put(CPU_LITERAL, 0.0); - } - return resourceMap; - } - - /** - * Convert the resources map to a format string. - * - * @param resources The resource map to be Converted. - * @return The format resources string, like "{CPU:4, GPU:0}". - */ - public static String getResourcesFromatStringFromMap(Map resources) { - if (resources == null) { - return "{}"; - } - StringBuilder builder = new StringBuilder(); - builder.append("{"); - int count = 1; - for (Map.Entry entry : resources.entrySet()) { - builder.append(entry.getKey()).append(":").append(entry.getValue()); - count++; - if (count != resources.size()) { - builder.append(", "); - } - } - builder.append("}"); - return builder.toString(); - } - /** * Convert resources map to a string that is used * for the command line argument of starting raylet. 
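Illustratively, the options-based calls introduced by this patch are consumed as follows. A minimal sketch, not part of the diff: it assumes a running Ray cluster, uses Guava's ImmutableMap as the updated tests further below do, and ``EchoExample`` is a hypothetical class; the invocation matches one of the generated ``call`` overloads above.

.. code-block:: java

    import com.google.common.collect.ImmutableMap;
    import org.ray.api.Ray;
    import org.ray.api.RayObject;
    import org.ray.api.annotation.RayRemote;
    import org.ray.api.options.CallOptions;

    public class EchoExample {

      @RayRemote
      public static Integer echo(Integer number) {
        return number;
      }

      public static void main(String[] args) {
        Ray.init();
        // Reserve 2 CPUs for this single invocation. Per the createTaskSpec
        // change above, a "CPU" entry of 0.0 is filled in automatically when
        // the caller's resource map omits it.
        CallOptions options = new CallOptions(ImmutableMap.of("CPU", 2.0));
        RayObject<Integer> result = Ray.call(EchoExample::echo, 100, options);
        System.out.println(result.get());  // prints 100
      }
    }

The same pattern applies to actor creation via ``ActorCreationOptions``, as exercised by the updated ``ResourcesManagementTest`` below.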
@@ -99,7 +51,7 @@ public static Map getResourcesMapFromString(String resources) String[] resourcePair = trimItem.split(":"); if (resourcePair.length != 2) { - throw new IllegalArgumentException("Format of static resurces configure is invalid."); + throw new IllegalArgumentException("Format of static resources configure is invalid."); } final String resourceName = resourcePair[0].trim(); diff --git a/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java b/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java index 10ffc3488f28..82fdf6b7f99e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/generator/RayCallGenerator.java @@ -21,7 +21,17 @@ private String build() { newLine(""); newLine("package org.ray.api;"); newLine(""); - newLine("import org.ray.api.function.*;"); + newLine("import org.ray.api.function.RayFunc;"); + newLine("import org.ray.api.function.RayFunc0;"); + newLine("import org.ray.api.function.RayFunc1;"); + newLine("import org.ray.api.function.RayFunc2;"); + newLine("import org.ray.api.function.RayFunc3;"); + newLine("import org.ray.api.function.RayFunc4;"); + newLine("import org.ray.api.function.RayFunc5;"); + newLine("import org.ray.api.function.RayFunc6;"); + newLine("import org.ray.api.options.ActorCreationOptions;"); + newLine("import org.ray.api.options.BaseTaskOptions;"); + newLine("import org.ray.api.options.CallOptions;"); newLine(""); newLine("/**"); @@ -33,19 +43,21 @@ private String build() { newLine(1, "// Methods for remote function invocation."); newLine(1, "// ======================================="); for (int i = 0; i <= MAX_PARAMETERS; i++) { - buildCalls(i, false, false); + buildCalls(i, false, false, false); + buildCalls(i, false, false, true); } newLine(1, "// ==========================================="); newLine(1, "// Methods for remote actor method invocation."); newLine(1, "// ==========================================="); for (int i = 0; i <= MAX_PARAMETERS - 1; i++) { - buildCalls(i, true, false); + buildCalls(i, true, false, false); } newLine(1, "// ==========================="); newLine(1, "// Methods for actor creation."); newLine(1, "// ==========================="); for (int i = 0; i <= MAX_PARAMETERS; i++) { - buildCalls(i, false, true); + buildCalls(i, false, true, false); + buildCalls(i, false, true, true); } newLine("}"); return sb.toString(); @@ -57,7 +69,8 @@ private String build() { * @param forActor build actor api when true, otherwise build task api. * @param forActorCreation build `Ray.createActor` when true, otherwise build `Ray.call`. */ - private void buildCalls(int numParameters, boolean forActor, boolean forActorCreation) { + private void buildCalls(int numParameters, boolean forActor, + boolean forActorCreation, boolean hasOptionsParam) { String genericTypes = ""; String argList = ""; for (int i = 0; i < numParameters; i++) { @@ -82,18 +95,36 @@ private void buildCalls(int numParameters, boolean forActor, boolean forActorCre paramPrefix += ", "; } + String optionsParam; + if (hasOptionsParam) { + optionsParam = forActorCreation ? ", ActorCreationOptions options" : ", CallOptions options"; + } else { + optionsParam = ""; + } + + String optionsArg; + if (forActor) { + optionsArg = ""; + } else { + if (hasOptionsParam) { + optionsArg = ", options"; + } else { + optionsArg = ", null"; + } + } + String returnType = !forActorCreation ? 
"RayObject" : "RayActor"; String funcName = !forActorCreation ? "call" : "createActor"; String funcArgs = !forActor ? "f, args" : "f, actor, args"; for (String param : generateParameters(0, numParameters)) { // method signature newLine(1, String.format( - "public static <%s> %s %s(%s) {", - genericTypes, returnType, funcName, paramPrefix + param + "public static <%s> %s %s(%s%s) {", + genericTypes, returnType, funcName, paramPrefix + param, optionsParam )); // method body newLine(2, String.format("Object[] args = new Object[]{%s};", argList)); - newLine(2, String.format("return Ray.internal().%s(%s);", funcName, funcArgs)); + newLine(2, String.format("return Ray.internal().%s(%s%s);", funcName, funcArgs, optionsArg)); newLine(1, "}"); } } diff --git a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java index 69d0f57a7570..e185a5f19a89 100644 --- a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java +++ b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java @@ -1,6 +1,8 @@ package org.ray.api.test; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import jdk.nashorn.internal.ir.annotations.Immutable; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -9,7 +11,8 @@ import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.annotation.RayRemote; -import org.ray.api.annotation.ResourceItem; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.CallOptions; /** * Resources Management Test. @@ -17,36 +20,13 @@ @RunWith(MyRunner.class) public class ResourcesManagementTest { - @RayRemote(resources = {@ResourceItem(name = "CPU", value = 4), - @ResourceItem(name = "GPU", value = 0)}) - public static Integer echo1(Integer number) { + @RayRemote + public static Integer echo(Integer number) { return number; } - @RayRemote(resources = {@ResourceItem(name = "CPU", value = 4), - @ResourceItem(name = "GPU", value = 2)}) - public static Integer echo2(Integer number) { - return number; - } - - @RayRemote(resources = {@ResourceItem(name = "CPU", value = 2), - @ResourceItem(name = "GPU", value = 0)}) - public static class Echo1 { - public Integer echo(Integer number) { - return number; - } - } - - @RayRemote(resources = {@ResourceItem(name = "CPU", value = 8), - @ResourceItem(name = "GPU", value = 0)}) - public static class Echo2 { - public Integer echo(Integer number) { - return number; - } - } - - @RayRemote(resources = {@ResourceItem(name = "RES-A", value = 4)}) - public static class Echo3 { + @RayRemote + public static class Echo { public Integer echo(Integer number) { return number; } @@ -54,12 +34,18 @@ public Integer echo(Integer number) { @Test public void testMethods() { + CallOptions callOptions1 = new CallOptions(ImmutableMap.of("CPU", 4.0, "GPU", 0.0)); + // This is a case that can satisfy required resources. - RayObject result1 = Ray.call(ResourcesManagementTest::echo1, 100); + // The static resources for test are "CPU:4,RES-A:4". + RayObject result1 = Ray.call(ResourcesManagementTest::echo, 100, callOptions1); Assert.assertEquals(100, (int) result1.get()); + CallOptions callOptions2 = new CallOptions(ImmutableMap.of("CPU", 4.0, "GPU", 2.0)); + // This is a case that can't satisfy required resources. - final RayObject result2 = Ray.call(ResourcesManagementTest::echo2, 200); + // The static resources for test are "CPU:4,RES-A:4". 
+ final RayObject result2 = Ray.call(ResourcesManagementTest::echo, 200, callOptions2); WaitResult waitResult = Ray.wait(ImmutableList.of(result2), 1, 1000); Assert.assertEquals(0, waitResult.getReady().size()); @@ -68,28 +54,29 @@ public void testMethods() { @Test public void testActors() { + + ActorCreationOptions actorCreationOptions1 = + new ActorCreationOptions(ImmutableMap.of("CPU", 2.0, "GPU", 0.0)); + // This is a case that can satisfy required resources. - RayActor echo1 = Ray.createActor(Echo1::new); - final RayObject result1 = Ray.call(Echo1::echo, echo1, 100); + // The static resources for test are "CPU:4,RES-A:4". + RayActor echo1 = Ray.createActor(Echo::new, actorCreationOptions1); + final RayObject result1 = Ray.call(Echo::echo, echo1, 100); Assert.assertEquals(100, (int) result1.get()); // This is a case that can't satisfy required resources. - RayActor echo2 = Ray.createActor(Echo2::new); - final RayObject result2 = Ray.call(Echo2::echo, echo2, 100); + // The static resources for test are "CPU:4,RES-A:4". + ActorCreationOptions actorCreationOptions2 = + new ActorCreationOptions(ImmutableMap.of("CPU", 8.0, "GPU", 0.0)); + + RayActor echo2 = + Ray.createActor(Echo::new, actorCreationOptions2); + final RayObject result2 = Ray.call(Echo::echo, echo2, 100); WaitResult waitResult = Ray.wait(ImmutableList.of(result2), 1, 1000); Assert.assertEquals(0, waitResult.getReady().size()); Assert.assertEquals(1, waitResult.getUnready().size()); } - @Test - public void testActorAndMemberMethods() { - // Note(qwang): This case depends on the following line. - // https://github.com/ray-project/ray/blob/master/java/test/src/main/java/org/ray/api/test/TestListener.java#L13 - // If we change the static resources configuration item, this case may not pass. - // Then we should change this case too. - RayActor echo3 = Ray.createActor(Echo3::new); - Assert.assertEquals(100, (int) Ray.call(Echo3::echo, echo3, 100).get()); - } } From 9a5c273db7168b3a12b93203e48ac1e0285b552f Mon Sep 17 00:00:00 2001 From: bibabolynn <1018527906@qq.com> Date: Fri, 19 Oct 2018 21:43:42 +0800 Subject: [PATCH 041/215] [java] fix check exception type (#3093) ## What do these changes do? 
remove TaskExecutionException, use RayException instead ## Related issue number --- .../org/ray/runtime/AbstractRayRuntime.java | 6 +++--- .../runtime/objectstore/ObjectStoreProxy.java | 18 +++++++++--------- .../util/exception/TaskExecutionException.java | 15 --------------- 3 files changed, 12 insertions(+), 27 deletions(-) delete mode 100644 java/runtime/src/main/java/org/ray/runtime/util/exception/TaskExecutionException.java diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 022b61a33591..e80581c1f489 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -10,6 +10,7 @@ import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.function.RayFunc; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; @@ -26,7 +27,6 @@ import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.ResourceUtil; import org.ray.runtime.util.UniqueIdUtil; -import org.ray.runtime.util.exception.TaskExecutionException; import org.ray.runtime.util.logger.RayLog; /** @@ -80,7 +80,7 @@ public void put(UniqueId objectId, T obj) { } @Override - public T get(UniqueId objectId) throws TaskExecutionException { + public T get(UniqueId objectId) throws RayException { List ret = get(ImmutableList.of(objectId)); return ret.get(0); } @@ -149,7 +149,7 @@ public List get(List objectIds) { } return finalRet; - } catch (TaskExecutionException e) { + } catch (RayException e) { RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get with Exception", e); throw e; diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 3a33d862e4b1..5f8221ff6f02 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -4,11 +4,11 @@ import java.util.List; import org.apache.arrow.plasma.ObjectStoreLink; import org.apache.commons.lang3.tuple.Pair; +import org.ray.api.exception.RayException; import org.ray.api.id.UniqueId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.util.Serializer; import org.ray.runtime.util.UniqueIdUtil; -import org.ray.runtime.util.exception.TaskExecutionException; /** * Object store proxy, which handles serialization and deserialization, and utilize a {@code @@ -27,18 +27,18 @@ public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) { } public Pair get(UniqueId objectId, boolean isMetadata) - throws TaskExecutionException { + throws RayException { return get(objectId, GET_TIMEOUT_MS, isMetadata); } public Pair get(UniqueId id, int timeoutMs, boolean isMetadata) - throws TaskExecutionException { + throws RayException { byte[] obj = store.get(id.getBytes(), timeoutMs, isMetadata); if (obj != null) { T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader()); store.release(id.getBytes()); - if (t instanceof TaskExecutionException) { - throw (TaskExecutionException) t; + if (t instanceof RayException) { + throw (RayException) t; } return Pair.of(t, GetStatus.SUCCESS); } else { @@ -47,12 +47,12 @@ public Pair get(UniqueId id, int timeoutMs, boolean isMetadata } public List> 
get(List objectIds, boolean isMetadata) - throws TaskExecutionException { + throws RayException { return get(objectIds, GET_TIMEOUT_MS, isMetadata); } public List> get(List ids, int timeoutMs, boolean isMetadata) - throws TaskExecutionException { + throws RayException { List objs = store.get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata); List> ret = new ArrayList<>(); for (int i = 0; i < objs.size(); i++) { @@ -60,8 +60,8 @@ public List> get(List ids, int timeoutMs, boole if (obj != null) { T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader()); store.release(ids.get(i).getBytes()); - if (t instanceof TaskExecutionException) { - throw (TaskExecutionException) t; + if (t instanceof RayException) { + throw (RayException) t; } ret.add(Pair.of(t, GetStatus.SUCCESS)); } else { diff --git a/java/runtime/src/main/java/org/ray/runtime/util/exception/TaskExecutionException.java b/java/runtime/src/main/java/org/ray/runtime/util/exception/TaskExecutionException.java deleted file mode 100644 index 99bc0912e1d0..000000000000 --- a/java/runtime/src/main/java/org/ray/runtime/util/exception/TaskExecutionException.java +++ /dev/null @@ -1,15 +0,0 @@ -package org.ray.runtime.util.exception; - -/** - * An exception which is thrown when a ray task encounters an error when executing. - */ -public class TaskExecutionException extends RuntimeException { - - public TaskExecutionException(Throwable cause) { - super(cause); - } - - public TaskExecutionException(String message, Throwable cause) { - super(message, cause); - } -} From 9a2b5333ef8880efc7cafbac7fcead22222d6dd2 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Fri, 19 Oct 2018 12:15:22 -0700 Subject: [PATCH 042/215] Add links for latest Python 3.7 wheels to documentation. (#3091) --- doc/source/installation.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/source/installation.rst b/doc/source/installation.rst index ebee27f9f028..0a5f435f4260 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -23,6 +23,7 @@ features but may be subject to more bugs. To install these wheels, run the follo =================== =================== Linux MacOS =================== =================== +`Linux Python 3.7`_ `MacOS Python 3.7`_ `Linux Python 3.6`_ `MacOS Python 3.6`_ `Linux Python 3.5`_ `MacOS Python 3.5`_ `Linux Python 3.4`_ `MacOS Python 3.4`_ @@ -30,10 +31,12 @@ features but may be subject to more bugs. To install these wheels, run the follo =================== =================== +.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp37-cp37m-manylinux1_x86_64.whl .. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp36-cp36m-manylinux1_x86_64.whl .. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp35-cp35m-manylinux1_x86_64.whl .. _`Linux Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp34-cp34m-manylinux1_x86_64.whl .. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp27-cp27mu-manylinux1_x86_64.whl +.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp37-cp37m-macosx_10_6_intel.whl .. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp36-cp36m-macosx_10_6_intel.whl .. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp35-cp35m-macosx_10_6_intel.whl .. 
_`MacOS Python 3.4`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.5.3-cp34-cp34m-macosx_10_6_intel.whl From 59901a88a0122bc2baeadf5056ec9a3b32b44fae Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 20 Oct 2018 15:21:22 -0700 Subject: [PATCH 043/215] [rllib] Native support for Dict and Tuple spaces; fix Tuple action spaces; add prev a, r to LSTM (#3051) --- doc/source/rllib-concepts.rst | 2 +- doc/source/rllib-models.rst | 49 +++- .../rllib/agents/a3c/a3c_tf_policy_graph.py | 13 +- python/ray/rllib/agents/agent.py | 5 +- python/ray/rllib/agents/ars/ars.py | 6 +- python/ray/rllib/agents/ars/policies.py | 7 +- .../rllib/agents/ddpg/ddpg_policy_graph.py | 29 +- .../ray/rllib/agents/dqn/dqn_policy_graph.py | 17 +- python/ray/rllib/agents/es/es.py | 7 +- python/ray/rllib/agents/es/policies.py | 10 +- .../agents/impala/vtrace_policy_graph.py | 19 +- python/ray/rllib/agents/pg/pg_policy_graph.py | 14 +- python/ray/rllib/agents/ppo/ppo.py | 2 +- .../ray/rllib/agents/ppo/ppo_policy_graph.py | 26 +- python/ray/rllib/env/async_vector_env.py | 21 +- python/ray/rllib/env/serving_env.py | 14 +- python/ray/rllib/env/vector_env.py | 2 + python/ray/rllib/evaluation/episode.py | 56 +++- .../ray/rllib/evaluation/policy_evaluator.py | 24 +- python/ray/rllib/evaluation/policy_graph.py | 8 + python/ray/rllib/evaluation/sampler.py | 42 ++- .../ray/rllib/evaluation/tf_policy_graph.py | 17 +- .../rllib/evaluation/torch_policy_graph.py | 2 + python/ray/rllib/examples/carla/models.py | 1 + python/ray/rllib/examples/cartpole_lstm.py | 14 +- python/ray/rllib/models/action_dist.py | 10 +- python/ray/rllib/models/catalog.py | 87 ++++-- python/ray/rllib/models/fcnet.py | 6 + python/ray/rllib/models/lstm.py | 23 +- python/ray/rllib/models/model.py | 118 ++++++++- python/ray/rllib/models/preprocessors.py | 93 +++++-- python/ray/rllib/models/visionnet.py | 3 +- python/ray/rllib/test/test_catalog.py | 24 +- python/ray/rllib/test/test_multi_agent_env.py | 4 + python/ray/rllib/test/test_nested_spaces.py | 249 ++++++++++++++++++ .../ray/rllib/test/test_policy_evaluator.py | 22 +- .../ray/rllib/test/test_supported_spaces.py | 53 ++-- python/ray/rllib/utils/tf_run_builder.py | 3 +- test/jenkins_tests/run_multi_node_tests.sh | 6 + 39 files changed, 922 insertions(+), 186 deletions(-) create mode 100644 python/ray/rllib/test/test_nested_spaces.py diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index f752279cb58d..68c160c912b0 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -17,7 +17,7 @@ Policy Evaluation Given an environment and policy graph, policy evaluation produces `batches `__ of experiences. This is your classic "environment interaction loop". Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. RLlib provides a `PolicyEvaluator `__ class that manages all of this, and this class is used in most RLlib algorithms. -You can also use policy evaluation standalone to produce batches of experiences. This can be done by calling ``ev.sample()`` on an evaluator instance, or ``ev.sample.remote()`` in parallel on evaluator instances created as Ray actors (see ``PolicyEvalutor.as_remote()``). +You can also use policy evaluation standalone to produce batches of experiences. 
This can be done by calling ``ev.sample()`` on an evaluator instance, or ``ev.sample.remote()`` in parallel on evaluator instances created as Ray actors (see ``PolicyEvaluator.as_remote()``). Policy Optimization ------------------- diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 6efc4abb8e7a..a2a9233ef3f5 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -15,7 +15,7 @@ RLlib picks default models based on a simple heuristic: a `vision network `__. More generally, RLlib supports the use of recurrent models for its policy gradient algorithms (A3C, PPO, PG, IMPALA), and RNN support is built into its policy evaluation utilities. -For preprocessors, RLlib tries to pick one of its built-in preprocessor based on the environment's observation space. Discrete observations are one-hot encoded, Atari observations downscaled, and Tuple observations flattened (there isn't native tuple support yet, but you can reshape the flattened observation in a custom model). Note that for Atari, RLlib defaults to using the `DeepMind preprocessors `__, which are also used by the OpenAI baselines library. +For preprocessors, RLlib tries to pick one of its built-in preprocessor based on the environment's observation space. Discrete observations are one-hot encoded, Atari observations downscaled, and Tuple and Dict observations flattened (these are unflattened and accessible via the ``input_dict`` parameter in custom models). Note that for Atari, RLlib defaults to using the `DeepMind preprocessors `__, which are also used by the OpenAI baselines library. Built-in Model Parameters ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -30,7 +30,7 @@ The following is a list of the built-in model hyperparameters: Custom Models ------------- -Custom models should subclass the common RLlib `model class `__ and override the ``_build_layers`` method. This method takes in a tensor input (observation), and returns a feature layer and float vector of the specified output size. The model can then be registered and used in place of a built-in model: +Custom models should subclass the common RLlib `model class `__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``), and returns a feature layer and float vector of the specified output size. The model can then be registered and used in place of a built-in model: .. code-block:: python @@ -39,9 +39,38 @@ Custom models should subclass the common RLlib `model class >> print(input_dict) + {'prev_actions': , + 'prev_rewards': , + 'obs': OrderedDict([ + ('sensors', OrderedDict([ + ('front_cam', [ + , + ]), + ('position', ), + ('velocity', )]))])} + """ + + layer1 = slim.fully_connected(input_dict["obs"], 64, ...) + layer2 = slim.fully_connected(layer1, 64, ...) ... return layerN, layerN_minus_1 @@ -55,12 +84,12 @@ Custom models should subclass the common RLlib `model class `__ and associated `training scripts `__. The ``CarlaModel`` class defined there operates over a composite (Tuple) observation space including both images and scalar measurements. +For a full example of a custom model in code, see the `Carla RLlib model `__ and associated `training scripts `__. You can also reference the `unit tests `__ for Tuple and Dict spaces, which show how to access nested observation fields. 
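To make the nested-access pattern above concrete, here is a minimal sketch of a registered custom model over a Dict observation space. The model name and the ``position``/``task`` keys are hypothetical; the ``slim.fully_connected`` usage loosely mirrors the spy models in ``test_nested_spaces.py`` later in this patch.

.. code-block:: python

    import tensorflow as tf
    import tensorflow.contrib.slim as slim

    from ray.rllib.models import ModelCatalog
    from ray.rllib.models.model import Model


    class MyDictModel(Model):
        def _build_layers_v2(self, input_dict, num_outputs, options):
            # Nested sub-observations arrive as ordinary tensors, already
            # unpacked from the flattened preprocessor output.
            position = input_dict["obs"]["position"]  # a Box component
            task = input_dict["obs"]["task"]  # a one-hot encoded Discrete
            features = tf.concat([position, task], axis=1)
            last_layer = slim.fully_connected(features, 64)
            output = slim.fully_connected(
                last_layer, num_outputs, activation_fn=None)
            return output, last_layer


    ModelCatalog.register_custom_model("my_dict_model", MyDictModel)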
Custom Preprocessors -------------------- -Similarly, custom preprocessors should subclass the RLlib `preprocessor class `__ and be registered in the model catalog: +Similarly, custom preprocessors should subclass the RLlib `preprocessor class `__ and be registered in the model catalog. Note that you can alternatively use `gym wrapper classes `__ around your environment instead of preprocessors. .. code-block:: python @@ -69,8 +98,8 @@ Similarly, custom preprocessors should subclass the RLlib `preprocessor class rllib.AsyncVectorEnv rllib.ServingEnv => rllib.AsyncVectorEnv + Attributes: + action_space (gym.Space): Action space. This must be defined for + single-agent envs. Multi-agent envs can set this to None. + observation_space (gym.Space): Observation space. This must be defined + for single-agent envs. Multi-agent envs can set this to None. + Examples: >>> env = MyAsyncVectorEnv() >>> obs, rewards, dones, infos, off_policy_actions = env.poll() @@ -142,8 +148,14 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID): class _ServingEnvToAsync(AsyncVectorEnv): """Internal adapter of ServingEnv to AsyncVectorEnv.""" - def __init__(self, serving_env): + def __init__(self, serving_env, preprocessor=None): self.serving_env = serving_env + self.prep = preprocessor + self.action_space = serving_env.action_space + if preprocessor: + self.observation_space = preprocessor.observation_space + else: + self.observation_space = serving_env.observation_space serving_env.start() def poll(self): @@ -168,7 +180,10 @@ def _poll(self): if episode.cur_done: del self.serving_env._episodes[eid] if data: - all_obs[eid] = data["obs"] + if self.prep: + all_obs[eid] = self.prep.transform(data["obs"]) + else: + all_obs[eid] = data["obs"] all_rewards[eid] = data["reward"] all_dones[eid] = data["done"] all_infos[eid] = data["info"] @@ -196,6 +211,8 @@ class _VectorEnvToAsync(AsyncVectorEnv): def __init__(self, vector_env): self.vector_env = vector_env + self.action_space = vector_env.action_space + self.observation_space = vector_env.observation_space self.num_envs = vector_env.num_envs self.new_obs = self.vector_env.vector_reset() self.cur_rewards = [None for _ in range(self.num_envs)] diff --git a/python/ray/rllib/env/serving_env.py b/python/ray/rllib/env/serving_env.py index 0c1e3ec0dbfe..528cae266d50 100644 --- a/python/ray/rllib/env/serving_env.py +++ b/python/ray/rllib/env/serving_env.py @@ -25,6 +25,10 @@ class ServingEnv(threading.Thread): This env is thread-safe, but individual episodes must be executed serially. + Attributes: + action_space (gym.Space): Action space. + observation_space (gym.Space): Observation space. + Examples: >>> register_env("my_env", lambda config: YourServingEnv(config)) >>> agent = DQNAgent(env="my_env") @@ -57,10 +61,12 @@ def run(self): """Override this to implement the run loop. Your loop should continuously: - 1. Call self.start_episode() - 2. Call self.get_action() or self.log_action() - 3. Call self.log_returns() - 4. Call self.end_episode() + 1. Call self.start_episode(episode_id) + 2. Call self.get_action(episode_id, obs) + -or- + self.log_action(episode_id, obs, action) + 3. Call self.log_returns(episode_id, reward) + 4. Call self.end_episode(episode_id, obs) 5. Wait if nothing to do. Multiple episodes may be started at the same time. 
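As a rough illustration of the numbered loop above, a minimal ``run()`` might look like the sketch below, loosely modeled on the ``SimpleServing`` test helper referenced later in this patch. It assumes ``start_episode()`` returns the generated episode id and that the constructor takes the action and observation spaces; the wrapped gym env stands in for whatever external process actually produces observations and consumes actions.

.. code-block:: python

    class MyServingEnv(ServingEnv):
        def __init__(self, env):
            ServingEnv.__init__(self, env.action_space, env.observation_space)
            self.env = env

        def run(self):
            while True:
                # One full episode per iteration of the outer loop.
                eid = self.start_episode()
                obs = self.env.reset()
                done = False
                while not done:
                    action = self.get_action(eid, obs)
                    obs, reward, done, info = self.env.step(action)
                    self.log_returns(eid, reward)
                self.end_episode(eid, obs)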
diff --git a/python/ray/rllib/env/vector_env.py b/python/ray/rllib/env/vector_env.py index 7fb5b1605543..8d2289cf4144 100644 --- a/python/ray/rllib/env/vector_env.py +++ b/python/ray/rllib/env/vector_env.py @@ -69,6 +69,8 @@ def __init__(self, make_env, existing_envs, num_envs): self.num_envs = num_envs while len(self.envs) < self.num_envs: self.envs.append(self.make_env(len(self.envs))) + self.action_space = self.envs[0].action_space + self.observation_space = self.envs[0].observation_space def vector_reset(self): return [e.reset() for e in self.envs] diff --git a/python/ray/rllib/evaluation/episode.py b/python/ray/rllib/evaluation/episode.py index fc99d79fbb04..ebd6ea784203 100644 --- a/python/ray/rllib/evaluation/episode.py +++ b/python/ray/rllib/evaluation/episode.py @@ -54,6 +54,8 @@ def __init__(self, policies, policy_mapping_fn, batch_builder_factory, self._agent_to_last_obs = {} self._agent_to_last_action = {} self._agent_to_last_pi_info = {} + self._agent_to_prev_action = {} + self._agent_reward_history = defaultdict(list) def policy_for(self, agent_id): """Returns the policy graph for the specified agent. @@ -72,19 +74,33 @@ def last_observation_for(self, agent_id): return self._agent_to_last_obs.get(agent_id) def last_action_for(self, agent_id): - """Returns the last action for the specified agent.""" - - action = self._agent_to_last_action[agent_id] - # Concatenate tuple actions - if isinstance(action, list): - expanded = [] - for a in action: - if len(a.shape) == 1: - expanded.append(np.expand_dims(a, 1)) - else: - expanded.append(a) - action = np.concatenate(expanded, axis=1).flatten() - return action + """Returns the last action for the specified agent, or zeros.""" + + if agent_id in self._agent_to_last_action: + return _flatten_action(self._agent_to_last_action[agent_id]) + else: + policy = self._policies[self.policy_for(agent_id)] + flat = _flatten_action(policy.action_space.sample()) + return np.zeros_like(flat) + + def prev_action_for(self, agent_id): + """Returns the previous action for the specified agent.""" + + if agent_id in self._agent_to_prev_action: + return _flatten_action(self._agent_to_prev_action[agent_id]) + else: + # We're at t=0, so return all zeros. + return np.zeros_like(self.last_action_for(agent_id)) + + def prev_reward_for(self, agent_id): + """Returns the previous reward for the specified agent.""" + + history = self._agent_reward_history[agent_id] + if len(history) >= 2: + return history[-2] + else: + # We're at t=0, so there is no previous reward, just return zero. 
+ return 0.0 def rnn_state_for(self, agent_id): """Returns the last RNN state for the specified agent.""" @@ -105,6 +121,7 @@ def _add_agent_rewards(self, reward_dict): self.agent_rewards[agent_id, self.policy_for(agent_id)] += reward self.total_reward += reward + self._agent_reward_history[agent_id].append(reward) def _set_rnn_state(self, agent_id, rnn_state): self._agent_to_rnn_state[agent_id] = rnn_state @@ -117,3 +134,16 @@ def _set_last_action(self, agent_id, action): def _set_last_pi_info(self, agent_id, pi_info): self._agent_to_last_pi_info[agent_id] = pi_info + + +def _flatten_action(action): + # Concatenate tuple actions + if isinstance(action, list) or isinstance(action, tuple): + expanded = [] + for a in action: + if not hasattr(a, "shape") or len(a.shape) == 0: + expanded.append(np.expand_dims(a, 1)) + else: + expanded.append(a) + action = np.concatenate(expanded, axis=0).flatten() + return action diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index 548b65806bad..db88eb759df7 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -11,8 +11,6 @@ from ray.rllib.env.async_vector_env import AsyncVectorEnv from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari from ray.rllib.env.env_context import EnvContext -from ray.rllib.env.serving_env import ServingEnv -from ray.rllib.env.vector_env import VectorEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.evaluation.interface import EvaluatorInterface from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \ @@ -179,11 +177,15 @@ def __init__(self, self.compress_observations = compress_observations self.env = env_creator(env_context) - if isinstance(self.env, VectorEnv) or \ - isinstance(self.env, ServingEnv) or \ - isinstance(self.env, MultiAgentEnv) or \ + if isinstance(self.env, MultiAgentEnv) or \ isinstance(self.env, AsyncVectorEnv): + if model_config.get("custom_preprocessor"): + raise ValueError( + "Custom preprocessors are not supported for env types " + "MultiAgentEnv and AsyncVectorEnv. Please preprocess " + "observations in your env directly.") + def wrap(env): return env # we can't auto-wrap these env types elif is_atari(self.env) and \ @@ -294,6 +296,18 @@ def _build_policy_map(self, policy_dict, policy_config): merged_conf = policy_config.copy() merged_conf.update(conf) with tf.variable_scope(name): + if isinstance(obs_space, gym.spaces.Dict): + raise ValueError( + "Found raw Dict space as input to policy graph. " + "Please preprocess your environment observations " + "with DictFlatteningPreprocessor and set the " + "obs space to `preprocessor.observation_space`.") + elif isinstance(obs_space, gym.spaces.Tuple): + raise ValueError( + "Found raw Tuple space as input to policy graph. 
" + "Please preprocess your environment observations " + "with TupleFlatteningPreprocessor and set the " + "obs space to `preprocessor.observation_space`.") policy_map[name] = cls(obs_space, act_space, merged_conf) return policy_map diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index 925fa70aa154..b2d154f488d7 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -40,6 +40,8 @@ class you pass into PolicyEvaluator will be constructed with def compute_actions(self, obs_batch, state_batches, + prev_action_batch=None, + prev_reward_batch=None, is_training=False, episodes=None): """Compute actions for the current policy. @@ -47,6 +49,8 @@ def compute_actions(self, Arguments: obs_batch (np.ndarray): batch of observations state_batches (list): list of RNN state input batches, if any + prev_action_batch (np.ndarray): batch of previous action values + prev_reward_batch (np.ndarray): batch of previous rewards is_training (bool): whether we are training the policy episodes (list): MultiAgentEpisode for each obs in obs_batch. This provides access to all of the internal episode state, @@ -65,6 +69,8 @@ def compute_actions(self, def compute_single_action(self, obs, state, + prev_action_batch=None, + prev_reward_batch=None, is_training=False, episode=None): """Unbatched version of compute_actions. @@ -72,6 +78,8 @@ def compute_single_action(self, Arguments: obs (obj): single observation state_batches (list): list of RNN state inputs, if any + prev_action_batch (np.ndarray): batch of previous action values + prev_reward_batch (np.ndarray): batch of previous rewards is_training (bool): whether we are training the policy episode (MultiAgentEpisode): this provides access to all of the internal episode state, which may be useful for model-based or diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 64999d3638ee..2f90dcbabcc4 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -3,22 +3,25 @@ from __future__ import print_function from collections import defaultdict, namedtuple +import numpy as np import six.moves.queue as queue import threading -from ray.rllib.evaluation.episode import MultiAgentEpisode +from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder, \ MultiAgentBatch from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.env.async_vector_env import AsyncVectorEnv from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv +from ray.rllib.models.action_dist import TupleActions from ray.rllib.utils.tf_run_builder import TFRunBuilder RolloutMetrics = namedtuple( "RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards"]) -PolicyEvalData = namedtuple("PolicyEvalData", - ["env_id", "agent_id", "obs", "rnn_state"]) +PolicyEvalData = namedtuple( + "PolicyEvalData", + ["env_id", "agent_id", "obs", "rnn_state", "prev_action", "prev_reward"]) class SyncSampler(object): @@ -281,7 +284,9 @@ def new_episode(): if not agent_done: to_eval[policy_id].append( PolicyEvalData(env_id, agent_id, filtered_obs, - episode.rnn_state_for(agent_id))) + episode.rnn_state_for(agent_id), + episode.last_action_for(agent_id), + rewards[env_id][agent_id] or 0.0)) last_observation = episode.last_observation_for(agent_id) episode._set_last_observation(agent_id, filtered_obs) @@ -297,6 
+302,8 @@ def new_episode(): obs=last_observation, actions=episode.last_action_for(agent_id), rewards=rewards[env_id][agent_id], + prev_actions=episode.prev_action_for(agent_id), + prev_rewards=episode.prev_reward_for(agent_id), dones=agent_done, infos=infos[env_id][agent_id], new_obs=filtered_obs, @@ -326,12 +333,17 @@ def new_episode(): episode = active_episodes[env_id] for agent_id, raw_obs in resetted_obs.items(): policy_id = episode.policy_for(agent_id) + policy = _get_or_raise(policies, policy_id) filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs) episode._set_last_observation(agent_id, filtered_obs) to_eval[policy_id].append( - PolicyEvalData(env_id, agent_id, filtered_obs, - episode.rnn_state_for(agent_id))) + PolicyEvalData( + env_id, agent_id, filtered_obs, + episode.rnn_state_for(agent_id), + np.zeros_like( + _flatten_action( + policy.action_space.sample())), 0.0)) # Batch eval policy actions if possible if tf_sess: @@ -350,11 +362,15 @@ def new_episode(): pending_fetches[policy_id] = policy.build_compute_actions( builder, [t.obs for t in eval_data], rnn_in, + prev_action_batch=[t.prev_action for t in eval_data], + prev_reward_batch=[t.prev_reward for t in eval_data], is_training=True) else: eval_results[policy_id] = policy.compute_actions( [t.obs for t in eval_data], rnn_in, + prev_action_batch=[t.prev_action for t in eval_data], + prev_reward_batch=[t.prev_reward for t in eval_data], is_training=True, episodes=[active_episodes[t.env_id] for t in eval_data]) if builder: @@ -374,6 +390,7 @@ def new_episode(): for f_i, column in enumerate(rnn_out_cols): pi_info_cols["state_out_{}".format(f_i)] = column # Save output rows + actions = _unbatch_tuple_actions(actions) for i, action in enumerate(actions): env_id = eval_data[i].env_id agent_id = eval_data[i].agent_id @@ -413,6 +430,19 @@ def _fetch_atari_metrics(async_vector_env): return atari_out +def _unbatch_tuple_actions(action_batch): + # convert list of batches -> batch of lists + if isinstance(action_batch, TupleActions): + out = [] + for j in range(len(action_batch.batches[0])): + out.append([ + action_batch.batches[i][j] + for i in range(len(action_batch.batches)) + ]) + return out + return action_batch + + def _to_column_format(rnn_state_rows): num_cols = len(rnn_state_rows[0]) return [[row[i] for row in rnn_state_rows] for i in range(num_cols)] diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index 09a84981ee83..a7b34c2ce0eb 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -46,6 +46,8 @@ def __init__(self, loss_inputs, state_inputs=None, state_outputs=None, + prev_action_input=None, + prev_reward_input=None, seq_lens=None, max_seq_len=20): """Initialize the policy graph. @@ -65,6 +67,8 @@ def __init__(self, and has shape [BATCH_SIZE, data...]. state_inputs (list): list of RNN state input Tensors. state_outputs (list): list of RNN state output Tensors. + prev_action_input (Tensor): placeholder for previous actions + prev_reward_input (Tensor): placeholder for previous rewards seq_lens (Tensor): placeholder for RNN sequence lengths, of shape [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See models/lstm.py for more information. 
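To see what a model does with these new inputs downstream, here is a self-contained toy (plain TF, not RLlib's own code) mirroring the feature concatenation that the LSTM model performs in ``lstm.py`` further below; the shapes are purely illustrative.

.. code-block:: python

    import numpy as np
    import tensorflow as tf

    obs = tf.placeholder(tf.float32, [None, 4], name="obs")
    prev_a = tf.placeholder(tf.float32, [None, 2], name="prev_action")
    prev_r = tf.placeholder(tf.float32, [None], name="prev_reward")

    # Mirror of the lstm.py feature concat: flatten the previous action
    # and append the scalar previous reward as one extra feature column.
    features = tf.concat(
        [obs, prev_a, tf.reshape(prev_r, [-1, 1])], axis=1)

    with tf.Session() as sess:
        out = sess.run(
            features,
            feed_dict={
                obs: np.zeros((3, 4), dtype=np.float32),
                prev_a: np.zeros((3, 2), dtype=np.float32),
                prev_r: np.zeros((3, ), dtype=np.float32),
            })
        print(out.shape)  # (3, 7)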
@@ -75,6 +79,8 @@ def __init__(self, self.action_space = action_space self._sess = sess self._obs_input = obs_input + self._prev_action_input = prev_action_input + self._prev_reward_input = prev_reward_input self._sampler = action_sampler self._loss = loss self._loss_inputs = loss_inputs @@ -112,6 +118,8 @@ def build_compute_actions(self, builder, obs_batch, state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, is_training=False, episodes=None): state_batches = state_batches or [] @@ -121,6 +129,10 @@ def build_compute_actions(self, builder.add_feed_dict({self._obs_input: obs_batch}) if state_batches: builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) + if self._prev_action_input is not None and prev_action_batch: + builder.add_feed_dict({self._prev_action_input: prev_action_batch}) + if self._prev_reward_input is not None and prev_reward_batch: + builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) builder.add_feed_dict({self._is_training: is_training}) builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) fetches = builder.add_fetches([self._sampler] + self._state_outputs + @@ -130,11 +142,14 @@ def build_compute_actions(self, def compute_actions(self, obs_batch, state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, is_training=False, episodes=None): builder = TFRunBuilder(self._sess, "compute_actions") fetches = self.build_compute_actions(builder, obs_batch, state_batches, - is_training) + prev_action_batch, + prev_reward_batch, is_training) return builder.get(fetches) def _get_loss_inputs_dict(self, batch): diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index 741357f3aa8d..cb990c36f8bf 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -70,6 +70,8 @@ def optimizer(self): def compute_actions(self, obs_batch, state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, is_training=False, episodes=None): if state_batches: diff --git a/python/ray/rllib/examples/carla/models.py b/python/ray/rllib/examples/carla/models.py index fd20cd0c000c..3f8cc0c5ba47 100644 --- a/python/ray/rllib/examples/carla/models.py +++ b/python/ray/rllib/examples/carla/models.py @@ -20,6 +20,7 @@ class CarlaModel(Model): further fully connected layers. """ + # TODO(ekl): use build_layers_v2 for native dict space support def _build_layers(self, inputs, num_outputs, options): # Parse options image_shape = options["custom_options"]["image_shape"] diff --git a/python/ray/rllib/examples/cartpole_lstm.py b/python/ray/rllib/examples/cartpole_lstm.py index 67fd35d28dcf..ddc89c47e3b3 100644 --- a/python/ray/rllib/examples/cartpole_lstm.py +++ b/python/ray/rllib/examples/cartpole_lstm.py @@ -14,6 +14,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--stop", type=int, default=200) +parser.add_argument("--use-prev-action-reward", action="store_true") parser.add_argument("--run", type=str, default="PPO") @@ -183,10 +184,13 @@ def close(self): "stop": { "episode_reward_mean": args.stop }, - "config": dict(configs[args.run], **{ - "model": { - "use_lstm": True, - }, - }), + "config": dict( + configs[args.run], **{ + "model": { + "use_lstm": True, + "lstm_use_prev_action_reward": args. 
+ use_prev_action_reward, + }, + }), } }) diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index b0cfe4141af1..91b8d2fce21d 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -2,9 +2,11 @@ from __future__ import division from __future__ import print_function +from collections import namedtuple +import distutils.version + import tensorflow as tf import numpy as np -import distutils.version use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion("1.5.0")) @@ -225,4 +227,8 @@ def entropy(self): def sample(self): """Draw a sample from the action distribution.""" - return [[s.sample() for s in self.child_distributions]] + + return TupleActions([s.sample() for s in self.child_distributions]) + + +TupleActions = namedtuple("TupleActions", ["batches"]) diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index d2038f55f888..1c55cb79e1e7 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -10,6 +10,9 @@ from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \ _global_registry +from ray.rllib.env.async_vector_env import _ServingEnvToAsync +from ray.rllib.env.serving_env import ServingEnv +from ray.rllib.env.vector_env import VectorEnv from ray.rllib.models.action_dist import ( Categorical, Deterministic, DiagGaussian, MultiActionDistribution, squash_to_range) @@ -41,6 +44,8 @@ "max_seq_len": 20, # Size of the LSTM cell "lstm_cell_size": 256, + # Whether to feed a_{t-1}, r_{t-1} to LSTM + "lstm_use_prev_action_reward": False, # == Atari == # Whether to enable framestack for Atari envs @@ -133,10 +138,6 @@ def get_action_placeholder(action_space): action_placeholder (Tensor): A placeholder for the actions """ - # TODO(ekl) are list spaces valid? - if isinstance(action_space, list): - action_space = gym.spaces.Tuple(action_space) - if isinstance(action_space, gym.spaces.Box): return tf.placeholder( tf.float32, shape=(None, action_space.shape[0]), name="action") @@ -160,11 +161,18 @@ def get_action_placeholder(action_space): " not supported".format(action_space)) @staticmethod - def get_model(inputs, num_outputs, options, state_in=None, seq_lens=None): + def get_model(input_dict, + obs_space, + num_outputs, + options, + state_in=None, + seq_lens=None): """Returns a suitable model conforming to given input and output specs. Args: - inputs (Tensor): The input tensor to the model. + input_dict (dict): Dict of input tensors to the model, including + the observation under the "obs" key. + obs_space (Space): Observation space of the target gym env. num_outputs (int): The size of the output vector of the model. options (dict): Optional args to pass to the model constructor. state_in (list): Optional RNN state in tensors. @@ -174,34 +182,40 @@ def get_model(inputs, num_outputs, options, state_in=None, seq_lens=None): model (Model): Neural network model. 
""" + assert isinstance(input_dict, dict) options = options or MODEL_DEFAULTS - model = ModelCatalog._get_model(inputs, num_outputs, options, state_in, - seq_lens) + model = ModelCatalog._get_model(input_dict, obs_space, num_outputs, + options, state_in, seq_lens) if options.get("use_lstm"): - model = LSTM(model.last_layer, num_outputs, options, state_in, + copy = dict(input_dict) + copy["obs"] = model.last_layer + model = LSTM(copy, obs_space, num_outputs, options, state_in, seq_lens) return model @staticmethod - def _get_model(inputs, num_outputs, options, state_in, seq_lens): + def _get_model(input_dict, obs_space, num_outputs, options, state_in, + seq_lens): if options.get("custom_model"): model = options["custom_model"] print("Using custom model {}".format(model)) return _global_registry.get(RLLIB_MODEL, model)( - inputs, + input_dict, + obs_space, num_outputs, options, state_in=state_in, seq_lens=seq_lens) - obs_rank = len(inputs.shape) - 1 + obs_rank = len(input_dict["obs"].shape) - 1 if obs_rank > 1: - return VisionNetwork(inputs, num_outputs, options) + return VisionNetwork(input_dict, obs_space, num_outputs, options) - return FullyConnectedNetwork(inputs, num_outputs, options) + return FullyConnectedNetwork(input_dict, obs_space, num_outputs, + options) @staticmethod def get_torch_model(input_shape, num_outputs, options=None): @@ -243,7 +257,7 @@ def get_preprocessor(env, options=None): """Returns a suitable processor for the given environment. Args: - env (gym.Env): The gym environment to preprocess. + env (gym.Env|VectorEnv|ServingEnv): The environment to wrap. options (dict): Options to pass to the preprocessor. Returns: @@ -269,16 +283,23 @@ def get_preprocessor_as_wrapper(env, options=None): """Returns a preprocessor as a gym observation wrapper. Args: - env (gym.Env): The gym environment to wrap. + env (gym.Env|VectorEnv|ServingEnv): The environment to wrap. options (dict): Options to pass to the preprocessor. Returns: - wrapper (gym.ObservationWrapper): Preprocessor in wrapper form. 
+ env (RLlib env): Wrapped environment """ options = options or MODEL_DEFAULTS preprocessor = ModelCatalog.get_preprocessor(env, options) - return _RLlibPreprocessorWrapper(env, preprocessor) + if isinstance(env, gym.Env): + return _RLlibPreprocessorWrapper(env, preprocessor) + elif isinstance(env, VectorEnv): + return _RLlibVectorPreprocessorWrapper(env, preprocessor) + elif isinstance(env, ServingEnv): + return _ServingEnvToAsync(env, preprocessor) + else: + raise ValueError("Don't know how to wrap {}".format(env)) @staticmethod def register_custom_preprocessor(preprocessor_name, preprocessor_class): @@ -314,10 +335,32 @@ class _RLlibPreprocessorWrapper(gym.ObservationWrapper): def __init__(self, env, preprocessor): super(_RLlibPreprocessorWrapper, self).__init__(env) self.preprocessor = preprocessor - - from gym.spaces.box import Box - self.observation_space = Box( - -1.0, 1.0, preprocessor.shape, dtype=np.float32) + self.observation_space = preprocessor.observation_space def observation(self, observation): return self.preprocessor.transform(observation) + + +class _RLlibVectorPreprocessorWrapper(VectorEnv): + """Preprocessing wrapper for vector envs.""" + + def __init__(self, env, preprocessor): + self.env = env + self.prep = preprocessor + self.action_space = env.action_space + self.observation_space = preprocessor.observation_space + self.num_envs = env.num_envs + + def vector_reset(self): + return [self.prep.transform(obs) for obs in self.env.vector_reset()] + + def reset_at(self, index): + return self.prep.transform(self.env.reset_at(index)) + + def vector_step(self, actions): + obs, rewards, dones, infos = self.env.vector_step(actions) + obs = [self.prep.transform(o) for o in obs] + return obs, rewards, dones, infos + + def get_unwrapped(self): + return self.env.get_unwrapped() diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index e703fb0a080d..5a759fd59ef8 100644 --- a/python/ray/rllib/models/fcnet.py +++ b/python/ray/rllib/models/fcnet.py @@ -13,6 +13,12 @@ class FullyConnectedNetwork(Model): """Generic fully connected network.""" def _build_layers(self, inputs, num_outputs, options): + """Process the flattened inputs. + + Note that dict inputs will be flattened into a vector. To define a + model that processes the components separately, use _build_layers_v2(). + """ + hiddens = options.get("fcnet_hiddens") activation = get_activation_fn(options.get("fcnet_activation")) diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index b8dea3ede95c..5f3bdc8b7d72 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -9,6 +9,10 @@ postprocessing, we dynamically pad the experience batches so that this reshaping is possible. +Note that this padding strategy only works out if we assume zero inputs don't +meaningfully affect the loss function. This happens to be true for all the +current algorithms: https://github.com/ray-project/ray/issues/2992 + See the add_time_dimension() and chop_into_sequences() functions below for more info. """ @@ -134,9 +138,24 @@ class LSTM(Model): self.seq_lens. See add_time_dimension() for more information. 
""" - def _build_layers(self, inputs, num_outputs, options): + def _build_layers_v2(self, input_dict, num_outputs, options): cell_size = options.get("lstm_cell_size") - last_layer = add_time_dimension(inputs, self.seq_lens) + if options.get("lstm_use_prev_action_reward"): + action_dim = int( + np.product( + input_dict["prev_actions"].get_shape().as_list()[1:])) + features = tf.concat( + [ + input_dict["obs"], + tf.reshape( + tf.cast(input_dict["prev_actions"], tf.float32), + [-1, action_dim]), + tf.reshape(input_dict["prev_rewards"], [-1, 1]), + ], + axis=1) + else: + features = input_dict["obs"] + last_layer = add_time_dimension(features, self.seq_lens) # Setup the LSTM cell lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True) diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index 168f29c74625..fb8ea668799b 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -2,8 +2,13 @@ from __future__ import division from __future__ import print_function +from collections import OrderedDict + +import gym import tensorflow as tf +from ray.rllib.models.preprocessors import get_preprocessor + class Model(object): """Defines an abstract network model for use with RLlib. @@ -16,12 +21,12 @@ class Model(object): needs to further post-processing (e.g. Actor and Critic networks in A3C). Attributes: - inputs (Tensor): The input placeholder for this model, of shape - [BATCH_SIZE, ...]. + input_dict (dict): Dictionary of input tensors, including "obs", + "prev_action", "prev_reward". outputs (Tensor): The output vector of this model, of shape [BATCH_SIZE, num_outputs]. - last_layer (Tensor): The network layer right before the model output, - of shape [BATCH_SIZE, N]. + last_layer (Tensor): The feature layer right before the model output, + of shape [BATCH_SIZE, f]. state_init (list): List of initial recurrent state tensors (if any). state_in (list): List of input recurrent state tensors (if any). state_out (list): List of output recurrent state tensors (if any). 
@@ -38,12 +43,13 @@ class Model(object): """ def __init__(self, - inputs, + input_dict, + obs_space, num_outputs, options, state_in=None, seq_lens=None): - self.inputs = inputs + assert isinstance(input_dict, dict), input_dict # Default attribute values for the non-RNN case self.state_init = [] @@ -58,8 +64,26 @@ def __init__(self, if options.get("free_log_std"): assert num_outputs % 2 == 0 num_outputs = num_outputs // 2 - self.outputs, self.last_layer = self._build_layers( - inputs, num_outputs, options) + try: + self.outputs, self.last_layer = self._build_layers_v2( + _restore_original_dimensions(input_dict, obs_space), + num_outputs, options) + except NotImplementedError: + self.outputs, self.last_layer = self._build_layers( + input_dict["obs"], num_outputs, options) + + # Validate the output shape + try: + out = tf.convert_to_tensor(self.outputs) + shape = out.shape.as_list() + except Exception: + raise ValueError("Output is not a tensor: {}".format(self.outputs)) + else: + if len(shape) != 2 or shape[1] != num_outputs: + raise ValueError( + "Expected output shape of [None, {}], got {}".format( + num_outputs, shape)) + if options.get("free_log_std", False): log_std = tf.get_variable( name="log_std", @@ -68,6 +92,80 @@ def __init__(self, self.outputs = tf.concat( [self.outputs, 0.0 * self.outputs + log_std], 1) - def _build_layers(self): - """Builds and returns the output and last layer of the network.""" + def _build_layers(self, inputs, num_outputs, options): + """Builds and returns the output and last layer of the network. + + Deprecated: use _build_layers_v2 instead, which has better support + for dict and tuple spaces. + """ raise NotImplementedError + + def _build_layers_v2(self, input_dict, num_outputs, options): + """Define the layers of a custom model. + + Arguments: + input_dict (dict): Dictionary of input tensors, including "obs", + "prev_action", "prev_reward". + num_outputs (int): Output tensor must be of size + [BATCH_SIZE, num_outputs]. + options (dict): Model options. + + Returns: + (outputs, feature_layer): Tensors of size [BATCH_SIZE, num_outputs] + and [BATCH_SIZE, desired_feature_size]. 
+ + When using dict or tuple observation spaces, you can access + the nested sub-observation batches here as well: + + Examples: + >>> print(input_dict) + {'prev_actions': , + 'prev_rewards': , + 'obs': OrderedDict([ + ('sensors', OrderedDict([ + ('front_cam', [ + , + ]), + ('position', ), + ('velocity', )]))])} + """ + raise NotImplementedError + + +def _restore_original_dimensions(input_dict, obs_space): + if hasattr(obs_space, "original_space"): + return dict( + input_dict, + obs=_unpack_obs(input_dict["obs"], obs_space.original_space)) + return input_dict + + +def _unpack_obs(obs, space): + if (isinstance(space, gym.spaces.Dict) + or isinstance(space, gym.spaces.Tuple)): + prep = get_preprocessor(space)(space) + if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]: + raise ValueError( + "Expected flattened obs shape of [None, {}], got {}".format( + prep.shape[0], obs.shape)) + assert len(prep.preprocessors) == len(space.spaces), \ + (len(prep.preprocessors) == len(space.spaces)) + offset = 0 + if isinstance(space, gym.spaces.Tuple): + u = [] + for p, v in zip(prep.preprocessors, space.spaces): + obs_slice = obs[:, offset:offset + p.size] + offset += p.size + u.append( + _unpack_obs( + tf.reshape(obs_slice, [-1] + list(p.shape)), v)) + else: + u = OrderedDict() + for p, (k, v) in zip(prep.preprocessors, space.spaces.items()): + obs_slice = obs[:, offset:offset + p.size] + offset += p.size + u[k] = _unpack_obs( + tf.reshape(obs_slice, [-1] + list(p.shape)), v) + return u + else: + return obs diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py index cd72d1922dcb..b7e084062158 100644 --- a/python/ray/rllib/models/preprocessors.py +++ b/python/ray/rllib/models/preprocessors.py @@ -16,19 +16,34 @@ class Preprocessor(object): shape (obj): Shape of the preprocessed output. """ - def __init__(self, obs_space, options): + def __init__(self, obs_space, options=None): legacy_patch_shapes(obs_space) self._obs_space = obs_space - self._options = options - self._init() + self._options = options or {} + self.shape = self._init_shape(obs_space, options) - def _init(self): - pass + def _init_shape(self, obs_space, options): + """Returns the shape after preprocessing.""" + raise NotImplementedError def transform(self, observation): """Returns the preprocessed observation.""" raise NotImplementedError + @property + def size(self): + return int(np.product(self.shape)) + + @property + def observation_space(self): + obs_space = gym.spaces.Box(-1.0, 1.0, self.shape, dtype=np.float32) + # Stash the unwrapped space so that we can unwrap dict and tuple spaces + # automatically in model.py + if (isinstance(self, TupleFlatteningPreprocessor) + or isinstance(self, DictFlatteningPreprocessor)): + obs_space.original_space = self._obs_space + return obs_space + class GenericPixelPreprocessor(Preprocessor): """Generic image preprocessor. @@ -37,19 +52,20 @@ class GenericPixelPreprocessor(Preprocessor): instead for deepmind-style Atari preprocessing. 
""" - def _init(self): - self._grayscale = self._options.get("grayscale") - self._zero_mean = self._options.get("zero_mean") - self._dim = self._options.get("dim") - self._channel_major = self._options.get("channel_major") + def _init_shape(self, obs_space, options): + self._grayscale = options.get("grayscale") + self._zero_mean = options.get("zero_mean") + self._dim = options.get("dim") + self._channel_major = options.get("channel_major") if self._grayscale: - self.shape = (self._dim, self._dim, 1) + shape = (self._dim, self._dim, 1) else: - self.shape = (self._dim, self._dim, 3) + shape = (self._dim, self._dim, 3) # channel_major requires (# in-channels, row dim, col dim) if self._channel_major: - self.shape = self.shape[-1:] + self.shape[:-1] + shape = shape[-1:] + shape[:-1] + return shape def transform(self, observation): """Downsamples images from (210, 160, 3) by the configured factor.""" @@ -75,16 +91,16 @@ def transform(self, observation): class AtariRamPreprocessor(Preprocessor): - def _init(self): - self.shape = (128, ) + def _init_shape(self, obs_space, options): + return (128, ) def transform(self, observation): return (observation - 128) / 128 class OneHotPreprocessor(Preprocessor): - def _init(self): - self.shape = (self._obs_space.n, ) + def _init_shape(self, obs_space, options): + return (self._obs_space.n, ) def transform(self, observation): arr = np.zeros(self._obs_space.n) @@ -93,8 +109,8 @@ def transform(self, observation): class NoPreprocessor(Preprocessor): - def _init(self): - self.shape = self._obs_space.shape + def _init_shape(self, obs_space, options): + return self._obs_space.shape def transform(self, observation): return observation @@ -103,11 +119,10 @@ def transform(self, observation): class TupleFlatteningPreprocessor(Preprocessor): """Preprocesses each tuple element, then flattens it all into a vector. - If desired, the vector output can be unpacked via tf.reshape() within a - custom model to handle each component separately. + RLlib models will unpack the flattened output before _build_layers_v2(). """ - def _init(self): + def _init_shape(self, obs_space, options): assert isinstance(self._obs_space, gym.spaces.Tuple) size = 0 self.preprocessors = [] @@ -116,17 +131,43 @@ def _init(self): print("Creating sub-preprocessor for", space) preprocessor = get_preprocessor(space)(space, self._options) self.preprocessors.append(preprocessor) - size += np.product(preprocessor.shape) - self.shape = (size, ) + size += preprocessor.size + return (size, ) def transform(self, observation): assert len(observation) == len(self.preprocessors), observation return np.concatenate([ - np.reshape(p.transform(o), [np.product(p.shape)]) + np.reshape(p.transform(o), [p.size]) for (o, p) in zip(observation, self.preprocessors) ]) +class DictFlatteningPreprocessor(Preprocessor): + """Preprocesses each dict value, then flattens it all into a vector. + + RLlib models will unpack the flattened output before _build_layers_v2(). 
+ """ + + def _init_shape(self, obs_space, options): + assert isinstance(self._obs_space, gym.spaces.Dict) + size = 0 + self.preprocessors = [] + for space in self._obs_space.spaces.values(): + print("Creating sub-preprocessor for", space) + preprocessor = get_preprocessor(space)(space, self._options) + self.preprocessors.append(preprocessor) + size += preprocessor.size + return (size, ) + + def transform(self, observation): + assert len(observation) == len(self.preprocessors), \ + (len(observation), len(self.preprocessors)) + return np.concatenate([ + np.reshape(p.transform(o), [p.size]) + for (o, p) in zip(observation.values(), self.preprocessors) + ]) + + def get_preprocessor(space): """Returns an appropriate preprocessor class for the given space.""" @@ -141,6 +182,8 @@ def get_preprocessor(space): preprocessor = AtariRamPreprocessor elif isinstance(space, gym.spaces.Tuple): preprocessor = TupleFlatteningPreprocessor + elif isinstance(space, gym.spaces.Dict): + preprocessor = DictFlatteningPreprocessor else: preprocessor = NoPreprocessor diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index 902addb6a31c..4105af7dd367 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -12,7 +12,8 @@ class VisionNetwork(Model): """Generic vision network.""" - def _build_layers(self, inputs, num_outputs, options): + def _build_layers_v2(self, input_dict, num_outputs, options): + inputs = input_dict["obs"] filters = options.get("conv_filters") if not filters: filters = get_filter_config(options) diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py index 62468e123bca..852a02fc4d1e 100644 --- a/python/ray/rllib/test/test_catalog.py +++ b/python/ray/rllib/test/test_catalog.py @@ -15,16 +15,18 @@ class CustomPreprocessor(Preprocessor): - pass + def _init_shape(self, obs_space, options): + return None class CustomPreprocessor2(Preprocessor): - pass + def _init_shape(self, obs_space, options): + return None class CustomModel(Model): def _build_layers(self, *args): - return None, None + return tf.constant([[0] * 5]), None class ModelCatalogTest(unittest.TestCase): @@ -69,20 +71,24 @@ def testDefaultModels(self): ray.init() with tf.variable_scope("test1"): - p1 = ModelCatalog.get_model( - np.zeros((10, 3), dtype=np.float32), 5, {}) + p1 = ModelCatalog.get_model({ + "obs": np.zeros((10, 3), dtype=np.float32) + }, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {}) self.assertEqual(type(p1), FullyConnectedNetwork) with tf.variable_scope("test2"): - p2 = ModelCatalog.get_model( - np.zeros((10, 84, 84, 3), dtype=np.float32), 5, {}) + p2 = ModelCatalog.get_model({ + "obs": np.zeros((10, 84, 84, 3), dtype=np.float32) + }, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {}) self.assertEqual(type(p2), VisionNetwork) def testCustomModel(self): ray.init() ModelCatalog.register_custom_model("foo", CustomModel) - p1 = ModelCatalog.get_model( - tf.constant([1, 2, 3]), 5, {"custom_model": "foo"}) + p1 = ModelCatalog.get_model({ + "obs": tf.constant([1, 2, 3]) + }, Box(0, 1, shape=(3, ), dtype=np.float32), 5, + {"custom_model": "foo"}) self.assertEqual(str(type(p1)), str(CustomModel)) diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index 493b338cf8e7..31f3103f5c16 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -314,6 +314,8 @@ class StatefulPolicyGraph(PolicyGraph): def 
compute_actions(self, obs_batch, state_batches, + prev_action_batch=None, + prev_reward_batch=None, is_training=False, episodes=None): return [0] * len(obs_batch), [[h] * len(obs_batch)], {} @@ -337,6 +339,8 @@ class ModelBasedPolicyGraph(PGPolicyGraph): def compute_actions(self, obs_batch, state_batches, + prev_action_batch=None, + prev_reward_batch=None, is_training=False, episodes=None): # Pretend we did a model-based rollout and want to return diff --git a/python/ray/rllib/test/test_nested_spaces.py b/python/ray/rllib/test/test_nested_spaces.py new file mode 100644 index 000000000000..f7f5f5981972 --- /dev/null +++ b/python/ray/rllib/test/test_nested_spaces.py @@ -0,0 +1,249 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pickle + +from gym import spaces +from gym.envs.registration import EnvSpec +import gym +import tensorflow.contrib.slim as slim +import tensorflow as tf +import unittest + +import ray +from ray.rllib.agents.pg import PGAgent +from ray.rllib.env.async_vector_env import AsyncVectorEnv +from ray.rllib.env.vector_env import VectorEnv +from ray.rllib.models import ModelCatalog +from ray.rllib.models.model import Model +from ray.rllib.test.test_serving_env import SimpleServing +from ray.tune.registry import register_env + +DICT_SPACE = spaces.Dict({ + "sensors": spaces.Dict({ + "position": spaces.Box(low=-100, high=100, shape=(3, )), + "velocity": spaces.Box(low=-1, high=1, shape=(3, )), + "front_cam": spaces.Tuple( + (spaces.Box(low=0, high=1, shape=(10, 10, 3)), + spaces.Box(low=0, high=1, shape=(10, 10, 3)))), + "rear_cam": spaces.Box(low=0, high=1, shape=(10, 10, 3)), + }), + "inner_state": spaces.Dict({ + "charge": spaces.Discrete(100), + "job_status": spaces.Dict({ + "task": spaces.Discrete(5), + "progress": spaces.Box(low=0, high=100, shape=()), + }) + }) +}) + +DICT_SAMPLES = [DICT_SPACE.sample() for _ in range(10)] + +TUPLE_SPACE = spaces.Tuple([ + spaces.Box(low=-100, high=100, shape=(3, )), + spaces.Tuple((spaces.Box(low=0, high=1, shape=(10, 10, 3)), + spaces.Box(low=0, high=1, shape=(10, 10, 3)))), + spaces.Discrete(5), +]) + +TUPLE_SAMPLES = [TUPLE_SPACE.sample() for _ in range(10)] + + +def one_hot(i, n): + out = [0.0] * n + out[i] = 1.0 + return out + + +class NestedDictEnv(gym.Env): + def __init__(self): + self.action_space = spaces.Discrete(2) + self.observation_space = DICT_SPACE + self._spec = EnvSpec("NestedDictEnv-v0") + self.steps = 0 + + def reset(self): + self.steps = 0 + return DICT_SAMPLES[0] + + def step(self, action): + self.steps += 1 + return DICT_SAMPLES[self.steps], 1, self.steps >= 5, {} + + +class NestedTupleEnv(gym.Env): + def __init__(self): + self.action_space = spaces.Discrete(2) + self.observation_space = TUPLE_SPACE + self._spec = EnvSpec("NestedTupleEnv-v0") + self.steps = 0 + + def reset(self): + self.steps = 0 + return TUPLE_SAMPLES[0] + + def step(self, action): + self.steps += 1 + return TUPLE_SAMPLES[self.steps], 1, self.steps >= 5, {} + + +class InvalidModel(Model): + def _build_layers_v2(self, input_dict, num_outputs, options): + return "not", "valid" + + +class DictSpyModel(Model): + capture_index = 0 + + def _build_layers_v2(self, input_dict, num_outputs, options): + def spy(pos, front_cam, task): + # TF runs this function in an isolated context, so we have to use + # redis to communicate back to our suite + ray.experimental.internal_kv._internal_kv_put( + "d_spy_in_{}".format(DictSpyModel.capture_index), + pickle.dumps((pos, front_cam, task))) + 
DictSpyModel.capture_index += 1 + return 0 + + spy_fn = tf.py_func( + spy, [ + input_dict["obs"]["sensors"]["position"], + input_dict["obs"]["sensors"]["front_cam"][0], + input_dict["obs"]["inner_state"]["job_status"]["task"] + ], + tf.int64, + stateful=True) + + with tf.control_dependencies([spy_fn]): + output = slim.fully_connected( + input_dict["obs"]["sensors"]["position"], num_outputs) + return output, output + + +class TupleSpyModel(Model): + capture_index = 0 + + def _build_layers_v2(self, input_dict, num_outputs, options): + def spy(pos, cam, task): + # TF runs this function in an isolated context, so we have to use + # redis to communicate back to our suite + ray.experimental.internal_kv._internal_kv_put( + "t_spy_in_{}".format(TupleSpyModel.capture_index), + pickle.dumps((pos, cam, task))) + TupleSpyModel.capture_index += 1 + return 0 + + spy_fn = tf.py_func( + spy, [ + input_dict["obs"][0], + input_dict["obs"][1][0], + input_dict["obs"][2], + ], + tf.int64, + stateful=True) + + with tf.control_dependencies([spy_fn]): + output = slim.fully_connected(input_dict["obs"][0], num_outputs) + return output, output + + +class NestedSpacesTest(unittest.TestCase): + def testInvalidModel(self): + ModelCatalog.register_custom_model("invalid", InvalidModel) + self.assertRaises(ValueError, lambda: PGAgent( + env="CartPole-v0", config={ + "model": { + "custom_model": "invalid", + }, + })) + + def doTestNestedDict(self, make_env): + ModelCatalog.register_custom_model("composite", DictSpyModel) + register_env("nested", make_env) + pg = PGAgent( + env="nested", + config={ + "num_workers": 0, + "sample_batch_size": 5, + "model": { + "custom_model": "composite", + }, + }) + pg.train() + + # Check that the model sees the correct reconstructed observations + for i in range(4): + seen = pickle.loads( + ray.experimental.internal_kv._internal_kv_get( + "d_spy_in_{}".format(i))) + pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist() + cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist() + task_i = one_hot( + DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5) + self.assertEqual(seen[0][0].tolist(), pos_i) + self.assertEqual(seen[1][0].tolist(), cam_i) + self.assertEqual(seen[2][0].tolist(), task_i) + + def doTestNestedTuple(self, make_env): + ModelCatalog.register_custom_model("composite2", TupleSpyModel) + register_env("nested2", make_env) + pg = PGAgent( + env="nested2", + config={ + "num_workers": 0, + "sample_batch_size": 5, + "model": { + "custom_model": "composite2", + }, + }) + pg.train() + + # Check that the model sees the correct reconstructed observations + for i in range(4): + seen = pickle.loads( + ray.experimental.internal_kv._internal_kv_get( + "t_spy_in_{}".format(i))) + pos_i = TUPLE_SAMPLES[i][0].tolist() + cam_i = TUPLE_SAMPLES[i][1][0].tolist() + task_i = one_hot(TUPLE_SAMPLES[i][2], 5) + self.assertEqual(seen[0][0].tolist(), pos_i) + self.assertEqual(seen[1][0].tolist(), cam_i) + self.assertEqual(seen[2][0].tolist(), task_i) + + def testNestedDictGym(self): + self.doTestNestedDict(lambda _: NestedDictEnv()) + + def testNestedDictVector(self): + self.doTestNestedDict( + lambda _: VectorEnv.wrap(lambda i: NestedDictEnv())) + + def testNestedDictServing(self): + self.doTestNestedDict(lambda _: SimpleServing(NestedDictEnv())) + + def testNestedDictAsync(self): + self.assertRaisesRegexp( + ValueError, "Found raw Dict space.*", + lambda: self.doTestNestedDict( + lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv()))) + + def testNestedTupleGym(self): + 
self.doTestNestedTuple(lambda _: NestedTupleEnv()) + + def testNestedTupleVector(self): + self.doTestNestedTuple( + lambda _: VectorEnv.wrap(lambda i: NestedTupleEnv())) + + def testNestedTupleServing(self): + self.doTestNestedTuple(lambda _: SimpleServing(NestedTupleEnv())) + + def testNestedTupleAsync(self): + self.assertRaisesRegexp( + ValueError, "Found raw Tuple space.*", + lambda: self.doTestNestedTuple( + lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv()))) + + +if __name__ == "__main__": + ray.init(num_cpus=5) + unittest.main(verbosity=2) diff --git a/python/ray/rllib/test/test_policy_evaluator.py b/python/ray/rllib/test/test_policy_evaluator.py index c4c2baf6e18e..b906136038c4 100644 --- a/python/ray/rllib/test/test_policy_evaluator.py +++ b/python/ray/rllib/test/test_policy_evaluator.py @@ -3,6 +3,7 @@ from __future__ import print_function import gym +import numpy as np import time import unittest @@ -21,6 +22,8 @@ class MockPolicyGraph(PolicyGraph): def compute_actions(self, obs_batch, state_batches, + prev_action_batch=None, + prev_reward_batch=None, is_training=False, episodes=None): return [0] * len(obs_batch), [], {} @@ -33,6 +36,8 @@ class BadPolicyGraph(PolicyGraph): def compute_actions(self, obs_batch, state_batches, + prev_action_batch=None, + prev_reward_batch=None, is_training=False, episodes=None): raise Exception("intentional error") @@ -107,8 +112,23 @@ def testBasic(self): env_creator=lambda _: gym.make("CartPole-v0"), policy_graph=MockPolicyGraph) batch = ev.sample() - for key in ["obs", "actions", "rewards", "dones", "advantages"]: + for key in [ + "obs", "actions", "rewards", "dones", "advantages", + "prev_rewards", "prev_actions" + ]: self.assertIn(key, batch) + + def to_prev(vec): + out = np.zeros_like(vec) + for i, v in enumerate(vec): + if i + 1 < len(out) and not batch["dones"][i]: + out[i + 1] = v + return out.tolist() + + self.assertEqual(batch["prev_rewards"].tolist(), + to_prev(batch["rewards"])) + self.assertEqual(batch["prev_actions"].tolist(), + to_prev(batch["actions"])) self.assertGreater(batch["advantages"][0], 1) def testGlobalVarsUpdate(self): diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index 20ef872ae86e..4f1aee0120e9 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -2,7 +2,7 @@ import traceback import gym -from gym.spaces import Box, Discrete, Tuple +from gym.spaces import Box, Discrete, Tuple, Dict from gym.envs.registration import EnvSpec import numpy as np import sys @@ -14,33 +14,28 @@ ACTION_SPACES_TO_TEST = { "discrete": Discrete(5), - "vector": Box(0.0, 1.0, (5, ), dtype=np.float32), - "simple_tuple": Tuple([ - Box(0.0, 1.0, (5, ), dtype=np.float32), - Box(0.0, 1.0, (5, ), dtype=np.float32) - ]), - "mixed_tuple": Tuple( + "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32), + "tuple": Tuple( [Discrete(2), Discrete(3), - Box(0.0, 1.0, (5, ), dtype=np.float32)]), + Box(-1.0, 1.0, (5, ), dtype=np.float32)]), } OBSERVATION_SPACES_TO_TEST = { "discrete": Discrete(5), - "vector": Box(0.0, 1.0, (5, ), dtype=np.float32), - "image": Box(0.0, 1.0, (84, 84, 1), dtype=np.float32), - "atari": Box(0.0, 1.0, (210, 160, 3), dtype=np.float32), - "atari_ram": Box(0.0, 1.0, (128, ), dtype=np.float32), - "simple_tuple": Tuple([ - Box(0.0, 1.0, (5, ), dtype=np.float32), - Box(0.0, 1.0, (5, ), dtype=np.float32) - ]), - "mixed_tuple": Tuple( - [Discrete(10), Box(0.0, 1.0, (5, ), dtype=np.float32)]), + "vector": Box(-1.0, 1.0, 
(5, ), dtype=np.float32), + "image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32), + "atari": Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32), + "tuple": Tuple([Discrete(10), + Box(-1.0, 1.0, (5, ), dtype=np.float32)]), + "dict": Dict({ + "task": Discrete(10), + "position": Box(-1.0, 1.0, (5, ), dtype=np.float32), + }), } -def make_stub_env(action_space, obs_space): +def make_stub_env(action_space, obs_space, check_action_bounds): class StubEnv(gym.Env): def __init__(self): self.action_space = action_space @@ -52,16 +47,23 @@ def reset(self): return sample def step(self, action): + if check_action_bounds and not self.action_space.contains(action): + raise ValueError("Illegal action for {}: {}".format( + self.action_space, action)) + if (isinstance(self.action_space, Tuple) + and len(action) != len(self.action_space.spaces)): + raise ValueError("Illegal action for {}: {}".format( + self.action_space, action)) return self.observation_space.sample(), 1, True, {} return StubEnv -def check_support(alg, config, stats): +def check_support(alg, config, stats, check_bounds=False): for a_name, action_space in ACTION_SPACES_TO_TEST.items(): for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items(): print("=== Testing", alg, action_space, obs_space, "===") - stub_env = make_stub_env(action_space, obs_space) + stub_env = make_stub_env(action_space, obs_space, check_bounds) register_env("stub_env", lambda c: stub_env()) stat = "ok" a = None @@ -105,8 +107,13 @@ def testAll(self): "num_sgd_iter": 1, "train_batch_size": 10, "sample_batch_size": 10, - "sgd_minibatch_size": 1 - }, stats) + "sgd_minibatch_size": 1, + "model": { + "squash_to_range": True + }, + }, + stats, + check_bounds=True) check_support( "ES", { "num_workers": 1, diff --git a/python/ray/rllib/utils/tf_run_builder.py b/python/ray/rllib/utils/tf_run_builder.py index 030642ae5b6a..2ea3ba7b8042 100644 --- a/python/ray/rllib/utils/tf_run_builder.py +++ b/python/ray/rllib/utils/tf_run_builder.py @@ -26,7 +26,8 @@ def __init__(self, session, debug_name): def add_feed_dict(self, feed_dict): assert not self._executed for k in feed_dict: - assert k not in self.feed_dict + if k in self.feed_dict: + raise ValueError("Key added twice: {}".format(k)) self.feed_dict.update(feed_dict) def add_fetches(self, fetches): diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 0a6db03b90cd..821fefa7bdfa 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -246,6 +246,9 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_policy_evaluator.py +docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/test/test_nested_spaces.py + docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_serving_env.py @@ -314,6 +317,9 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/cartpole_lstm.py --run=IMPALA --stop=100 +docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/cartpole_lstm.py --stop=200 --use-prev-action-reward + docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python 
/ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2


From a4db5bbaea75612bed3201242b48ecce301b1a15 Mon Sep 17 00:00:00 2001
From: Wang Qing
Date: Sun, 21 Oct 2018 11:12:20 +0800
Subject: [PATCH 044/215] Fill driver id into actor notification when
 finishing assigned task. (#3080)

## What do these changes do?
Fill the driver id into the actor notification when finishing an assigned
task. It also improves the code.
---
 src/ray/raylet/node_manager.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index be1974468439..0db61e7ddc2d 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -1370,6 +1370,7 @@ void NodeManager::FinishAssignedTask(Worker &worker) {
     // If this was an actor creation task, then convert the worker to an actor.
     auto actor_id = task.GetTaskSpecification().ActorCreationId();
     worker.AssignActorId(actor_id);
+    const auto driver_id = task.GetTaskSpecification().DriverId();

     // Publish the actor creation event to all other nodes so that methods for
     // the actor will be forwarded directly to this node.
@@ -1377,11 +1378,10 @@ void NodeManager::FinishAssignedTask(Worker &worker) {
     actor_notification->actor_id = actor_id.binary();
     actor_notification->actor_creation_dummy_object_id =
         task.GetTaskSpecification().ActorDummyObject().binary();
-    // TODO(swang): The driver ID.
-    actor_notification->driver_id = JobID::nil().binary();
+    actor_notification->driver_id = driver_id.binary();
     actor_notification->node_manager_id =
         gcs_client_->client_table().GetLocalClientId().binary();
-    auto driver_id = task.GetTaskSpecification().DriverId();
+
     RAY_LOG(DEBUG) << "Publishing actor creation: " << actor_id
                    << " driver_id: " << driver_id;
     RAY_CHECK_OK(gcs_client_->actor_table().Append(JobID::nil(), actor_id,

From 40c4148d4f065802f7888017ef26c2213b45341b Mon Sep 17 00:00:00 2001
From: Richard Liaw
Date: Sat, 20 Oct 2018 22:56:29 -0700
Subject: [PATCH 045/215] Cluster Utilities for Fault Tolerance Tests (#3008)

---
 .travis.yml                      |   2 +
 python/ray/test/cluster_utils.py | 201 +++++++++++++++++++++++++++++++
 python/ray/worker.py             |   4 -
 test/multi_node_test_2.py        |  72 +++++++++++
 4 files changed, 275 insertions(+), 4 deletions(-)
 create mode 100644 python/ray/test/cluster_utils.py
 create mode 100644 test/multi_node_test_2.py

diff --git a/.travis.yml b/.travis.yml
index bd0fd929b733..70f548f83bf8 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -146,6 +146,7 @@ matrix:
       - python -m pytest -v test/stress_tests.py
       - pytest test/component_failures_test.py
       - python test/multi_node_test.py
+      - python -m pytest -v test/multi_node_test_2.py
       - python -m pytest -v test/recursion_test.py
       - pytest test/monitor_test.py
       - python -m pytest -v test/cython_test.py
@@ -223,6 +224,7 @@ script:
   - python -m pytest -v test/stress_tests.py
   - python -m pytest -v test/component_failures_test.py
   - python test/multi_node_test.py
+  - python -m pytest -v test/multi_node_test_2.py
   - python -m pytest -v test/recursion_test.py
   - python -m pytest -v test/monitor_test.py
   - python -m pytest -v test/cython_test.py
diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py
new file mode 100644
index 000000000000..7e7a82d67c6e
--- /dev/null
+++ b/python/ray/test/cluster_utils.py
@@ -0,0 +1,201 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import time
+
+import ray
+import ray.services as services
+
+logger = logging.getLogger(__name__)
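For orientation before the class body: a minimal sketch of how this utility is meant to be driven, mirroring the usage in test/multi_node_test_2.py added later in this patch (illustrative only, and assuming RAY_USE_XRAY=1 as in those tests):

    from ray.test.cluster_utils import Cluster

    cluster = Cluster(initialize_head=True, connect=True)  # also calls ray.init
    worker = cluster.add_node()   # starts a raylet with one CPU by default
    cluster.wait_for_nodes()      # poll global state until both nodes register
    cluster.remove_node(worker)   # kill every process belonging to that node
    cluster.shutdown()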
+
+
+class Cluster(object):
+    def __init__(self,
+                 initialize_head=False,
+                 connect=False,
+                 head_node_args=None):
+        """Initializes the cluster.
+
+        Args:
+            initialize_head (bool): Automatically start a Ray cluster
+                by initializing the head node. Defaults to False.
+            connect (bool): If `initialize_head=True` and `connect=True`,
+                ray.init will be called with the redis address of this cluster
+                passed in.
+            head_node_args (kwargs): Arguments to be passed into
+                `start_ray_head` via `self.add_node`.
+        """
+        self.head_node = None
+        self.worker_nodes = {}
+        self.redis_address = None
+        if not initialize_head and connect:
+            raise RuntimeError("Cannot connect to uninitialized cluster.")
+
+        if initialize_head:
+            head_node_args = head_node_args or {}
+            self.add_node(**head_node_args)
+            if connect:
+                ray.init(redis_address=self.redis_address)
+
+    def add_node(self, **override_kwargs):
+        """Adds a node to the local Ray Cluster.
+
+        All nodes are by default started with the following settings:
+            cleanup=True,
+            use_raylet=True,
+            resources={"CPU": 1},
+            object_store_memory=100 * (2**20)  # 100 MB
+
+        Args:
+            override_kwargs: Keyword arguments used in `start_ray_head`
+                and `start_ray_node`. Overrides defaults.
+
+        Returns:
+            Node object of the added Ray node.
+        """
+        node_kwargs = dict(
+            cleanup=True,
+            use_raylet=True,
+            resources={"CPU": 1},
+            object_store_memory=100 * (2**20)  # 100 MB
+        )
+        node_kwargs.update(override_kwargs)
+
+        if self.head_node is None:
+            address_info = services.start_ray_head(
+                node_ip_address=services.get_node_ip_address(),
+                include_webui=False,
+                **node_kwargs)
+            self.redis_address = address_info["redis_address"]
+            # TODO(rliaw): Find a more stable way than modifying global state.
+            process_dict_copy = services.all_processes.copy()
+            for key in services.all_processes:
+                services.all_processes[key] = []
+            node = Node(process_dict_copy)
+            self.head_node = node
+        else:
+            address_info = services.start_ray_node(
+                services.get_node_ip_address(), self.redis_address,
+                **node_kwargs)
+            # TODO(rliaw): Find a more stable way than modifying global state.
+            process_dict_copy = services.all_processes.copy()
+            for key in services.all_processes:
+                services.all_processes[key] = []
+            node = Node(process_dict_copy)
+            self.worker_nodes[node] = address_info
+        logger.info("Starting Node with raylet socket {}".format(
+            address_info["raylet_socket_names"]))
+
+        return node
+
+    def remove_node(self, node):
+        """Kills all processes associated with worker node.
+
+        Args:
+            node (Node): Worker node of which all associated processes
+                will be removed.
+        """
+        if self.head_node == node:
+            self.head_node.kill_all_processes()
+            self.head_node = None
+            # TODO(rliaw): Do we need to kill all worker processes?
+        else:
+            node.kill_all_processes()
+            self.worker_nodes.pop(node)
+
+        assert not node.any_processes_alive(), (
+            "There are zombie processes left over after killing.")
+
+    def wait_for_nodes(self, retries=20):
+        """Waits for all nodes to be registered with global state.
+
+        Args:
+            retries (int): Number of times to retry checking client table.
+        """
+        for i in range(retries):
+            if not ray.is_initialized() or not self._check_registered_nodes():
+                time.sleep(0.3)
+            else:
+                break
+
+    def _check_registered_nodes(self):
+        registered = len([
+            client for client in ray.global_state.client_table()
+            if client["IsInsertion"]
+        ])
+        expected = len(self.list_all_nodes())
+        if registered == expected:
+            logger.info("All nodes registered as expected.")
+        else:
+            logger.info("Currently registering {} but expecting {}".format(
+                registered, expected))
+        return registered == expected
+
+    def list_all_nodes(self):
+        """Lists all nodes.
+
+        TODO(rliaw): What is the desired behavior if a head node
+        dies before worker nodes die?
+
+        Returns:
+            List of all nodes, including the head node.
+        """
+        nodes = list(self.worker_nodes)
+        if self.head_node:
+            nodes = [self.head_node] + nodes
+        return nodes
+
+    def shutdown(self):
+        # We create a list here as a copy because `remove_node`
+        # modifies `self.worker_nodes`.
+        all_nodes = list(self.worker_nodes)
+        for node in all_nodes:
+            self.remove_node(node)
+        self.remove_node(self.head_node)
+
+
+class Node(object):
+    """Abstraction for a Ray node."""
+
+    def __init__(self, process_dict):
+        # TODO(rliaw): Is there a unique identifier for a node?
+        self.process_dict = process_dict
+
+    def kill_plasma_store(self):
+        self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].kill()
+        self.process_dict[services.PROCESS_TYPE_PLASMA_STORE][0].wait()
+
+    def kill_raylet(self):
+        self.process_dict[services.PROCESS_TYPE_RAYLET][0].kill()
+        self.process_dict[services.PROCESS_TYPE_RAYLET][0].wait()
+
+    def kill_log_monitor(self):
+        self.process_dict["log_monitor"][0].kill()
+        self.process_dict["log_monitor"][0].wait()
+
+    def kill_all_processes(self):
+        for process_name, process_list in self.process_dict.items():
+            logger.info("Killing all {}(s)".format(process_name))
+            for process in process_list:
+                process.kill()
+
+        for process_name, process_list in self.process_dict.items():
+            logger.info("Waiting for all {}(s)".format(process_name))
+            for process in process_list:
+                process.wait()
+
+    def live_processes(self):
+        return [(p_name, proc) for p_name, p_list in self.process_dict.items()
+                for proc in p_list if proc.poll() is None]
+
+    def dead_processes(self):
+        return [(p_name, proc) for p_name, p_list in self.process_dict.items()
+                for proc in p_list if proc.poll() is not None]
+
+    def any_processes_alive(self):
+        return any(self.live_processes())
+
+    def all_processes_alive(self):
+        return not any(self.dead_processes())
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 7049b1f3a429..e19c433753b7 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -1837,10 +1837,6 @@ def shutdown(worker=global_worker):
         worker.plasma_client.disconnect()

     if worker.mode == SCRIPT_MODE:
-        # If this is a driver, push the finish time to Redis and clean up any
-        # other services that were started with the driver.
- worker.redis_client.hmset(b"Drivers:" + worker.worker_id, - {"end_time": time.time()}) services.cleanup() else: # If this is not a driver, make sure there are no orphan processes, diff --git a/test/multi_node_test_2.py b/test/multi_node_test_2.py new file mode 100644 index 000000000000..bb86bb2a7f53 --- /dev/null +++ b/test/multi_node_test_2.py @@ -0,0 +1,72 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import os +import pytest + +import ray +import ray.services as services +from ray.test.cluster_utils import Cluster + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def start_connected_cluster(): + # Start the Ray processes. + g = Cluster(initialize_head=True, connect=True) + yield g + # The code after the yield will run as teardown code. + ray.shutdown() + g.shutdown() + + +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +def test_cluster(): + """Basic test for adding and removing nodes in cluster.""" + g = Cluster(initialize_head=False) + node = g.add_node() + node2 = g.add_node() + assert node.all_processes_alive() + assert node2.all_processes_alive() + g.remove_node(node2) + g.remove_node(node) + assert not any(node.any_processes_alive() for node in g.list_all_nodes()) + + +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +def test_wait_for_nodes(start_connected_cluster): + """Unit test for `Cluster.wait_for_nodes`. + + Adds 4 workers, waits, then removes 4 workers, waits, + then adds 1 worker, waits, and removes 1 worker, waits. + """ + cluster = start_connected_cluster + workers = [cluster.add_node() for i in range(4)] + cluster.wait_for_nodes() + [cluster.remove_node(w) for w in workers] + cluster.wait_for_nodes() + worker2 = cluster.add_node() + cluster.wait_for_nodes() + cluster.remove_node(worker2) + cluster.wait_for_nodes() + + +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +def test_worker_plasma_store_failure(start_connected_cluster): + cluster = start_connected_cluster + worker = cluster.add_node() + cluster.wait_for_nodes() + # Log monitor doesn't die for some reason + worker.kill_log_monitor() + worker.kill_plasma_store() + worker.process_dict[services.PROCESS_TYPE_RAYLET][0].wait() + assert not worker.any_processes_alive(), worker.live_processes() From 221d1663c155c451e6a45ac26e3bddd8a082c6f2 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 21 Oct 2018 23:43:57 -0700 Subject: [PATCH 046/215] [rllib] switch to python logger (#3098) * logg * set rllib logger * comment * info * rlib * comment * add format * fix lint * add file info * update * add ts * lint * better docs * fix value error * soft log level --- python/ray/rllib/__init__.py | 14 ++ python/ray/rllib/agents/a3c/a3c.py | 3 +- python/ray/rllib/agents/agent.py | 55 +++-- python/ray/rllib/agents/ars/ars.py | 27 +-- python/ray/rllib/agents/ddpg/ddpg.py | 3 +- python/ray/rllib/agents/dqn/apex.py | 3 +- python/ray/rllib/agents/dqn/dqn.py | 3 +- python/ray/rllib/agents/es/es.py | 30 +-- python/ray/rllib/agents/es/tabular_logger.py | 229 ------------------ python/ray/rllib/agents/impala/impala.py | 3 +- python/ray/rllib/agents/pg/pg.py | 3 +- python/ray/rllib/agents/ppo/ppo.py | 9 +- .../ray/rllib/evaluation/policy_evaluator.py | 8 +- python/ray/rllib/evaluation/sampler.py | 5 +- python/ray/rllib/models/catalog.py | 12 +- 
python/ray/rllib/models/preprocessors.py | 8 +- python/ray/rllib/models/pytorch/fcnet.py | 6 +- .../optimizers/async_samples_optimizer.py | 7 +- .../rllib/optimizers/multi_gpu_optimizer.py | 9 +- .../ray/rllib/optimizers/policy_optimizer.py | 4 + .../optimizers/sync_samples_optimizer.py | 5 +- python/ray/rllib/utils/actors.py | 6 +- python/ray/rllib/utils/compression.py | 9 +- python/ray/rllib/utils/policy_client.py | 10 +- python/ray/rllib/utils/tf_run_builder.py | 9 +- 25 files changed, 160 insertions(+), 320 deletions(-) delete mode 100644 python/ray/rllib/agents/es/tabular_logger.py diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index db9f52687126..b3155f2dc237 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -2,6 +2,8 @@ from __future__ import division from __future__ import print_function +import logging + # Note: do not introduce unnecessary library dependencies here, e.g. gym. # This file is imported from the tune module in order to register RLlib agents. from ray.tune.registry import register_trainable @@ -16,6 +18,17 @@ from ray.rllib.evaluation.sample_batch import SampleBatch +def _setup_logger(): + logger = logging.getLogger("ray.rllib") + handler = logging.StreamHandler() + handler.setFormatter( + logging.Formatter( + "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s" + )) + logger.addHandler(handler) + logger.propagate = False + + def _register_all(): for key in [ @@ -27,6 +40,7 @@ def _register_all(): register_trainable(key, get_agent_class(key)) +_setup_logger() _register_all() __all__ = [ diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index 55f179adecd5..5788f338138c 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -10,6 +10,7 @@ from ray.rllib.utils import merge_dicts from ray.tune.trial import Resources +# yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # Size of rollout batch @@ -36,8 +37,8 @@ # sample_batch_size by up to 5x due to async buffering of batches. "sample_async": True, }) - # __sphinx_doc_end__ +# yapf: enable class A3CAgent(Agent): diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 17d7198518e0..c082c1d28290 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -4,6 +4,7 @@ import copy import os +import logging import pickle import tempfile from datetime import datetime @@ -19,12 +20,38 @@ from ray.tune.logger import UnifiedLogger from ray.tune.result import DEFAULT_RESULTS_DIR +# yapf: disable # __sphinx_doc_begin__ COMMON_CONFIG = { + # === Debugging === + # Whether to write episode stats and videos to the agent log dir + "monitor": False, + # Set the RLlib log level for the agent process and its remote evaluators + "log_level": "INFO", + + # === Policy === + # Arguments to pass to model. See models/catalog.py for a full list of the + # available model options. + "model": MODEL_DEFAULTS, + # Arguments to pass to the policy optimizer. These vary by optimizer. + "optimizer": {}, + + # === Environment === # Discount factor of the MDP "gamma": 0.99, # Number of steps after which the episode is forced to terminate "horizon": None, + # Arguments to pass to the env creator + "env_config": {}, + # Environment name can also be passed via config + "env": None, + # Whether to clip rewards prior to experience postprocessing. Setting to + # None means clip for Atari only. 
+ "clip_rewards": None, + # Whether to use rllib or deepmind preprocessors by default + "preprocessor_pref": "deepmind", + + # === Execution === # Number of environments to evaluate vectorwise per worker. "num_envs_per_worker": 1, # Number of actors used for parallelism @@ -42,20 +69,6 @@ "observation_filter": "NoFilter", # Whether to synchronize the statistics of remote filters. "synchronize_filters": True, - # Whether to clip rewards prior to experience postprocessing. Setting to - # None means clip for Atari only. - "clip_rewards": None, - # Whether to use rllib or deepmind preprocessors - "preprocessor_pref": "deepmind", - # Arguments to pass to the env creator - "env_config": {}, - # Environment name can also be passed via config - "env": None, - # Arguments to pass to model. See models/catalog.py for a full list of the - # available model options. - "model": MODEL_DEFAULTS, - # Arguments to pass to the policy optimizer. These vary by optimizer. - "optimizer": {}, # Configure TF for single-process operation by default "tf_session_args": { # note: parallelism_threads is set to auto for the local evaluator @@ -72,8 +85,6 @@ }, # Whether to LZ4 compress observations "compress_observations": False, - # Whether to write episode stats and videos to the agent log dir - "monitor": False, # Allocate a fraction of a GPU instead of one (e.g., 0.3 GPUs) "gpu_fraction": 1, @@ -88,8 +99,8 @@ "policies_to_train": None, }, } - # __sphinx_doc_end__ +# yapf: enable def with_common_config(extra_config): @@ -170,7 +181,8 @@ def session_creator(): model_config=config["model"], policy_config=config, worker_index=worker_index, - monitor_path=self.logdir if config["monitor"] else None) + monitor_path=self.logdir if config["monitor"] else None, + log_level=config["log_level"]) @classmethod def resource_help(cls, config): @@ -197,13 +209,12 @@ def __init__(self, config=None, env=None, logger_creator=None): # Agents allow env ids to be passed directly to the constructor. 
self._env_id = env or config.get("env") - if not self._env_id: - raise ValueError("Must specify env (str) when creating agent") # Create a default logger creator if no logger_creator is specified if logger_creator is None: timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") - logdir_prefix = '_'.join([self._agent_name, self._env_id, timestr]) + logdir_prefix = "{}_{}_{}".format( + [self._agent_name, self._env_id, timestr]) def default_logger_creator(config): """Creates a Unified logger with a default logdir prefix @@ -256,6 +267,8 @@ def _setup(self, config): self._allow_unknown_configs, self._allow_unknown_subkeys) self.config = merged_config + if self.config.get("log_level"): + logging.getLogger("ray.rllib").setLevel(self.config["log_level"]) # TODO(ekl) setting the graph is unnecessary for PyTorch agents with tf.Graph().as_default(): diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index 0c9af4ddaed9..67e87057fce0 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -7,6 +7,7 @@ from __future__ import print_function from collections import namedtuple +import logging import numpy as np import time @@ -16,14 +17,16 @@ from ray.rllib.agents.ars import optimizers from ray.rllib.agents.ars import policies -from ray.rllib.agents.es import tabular_logger as tlogger from ray.rllib.agents.ars import utils +logger = logging.getLogger(__name__) + Result = namedtuple("Result", [ "noise_indices", "noisy_returns", "sign_noisy_returns", "noisy_lengths", "eval_returns", "eval_lengths" ]) +# yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ "noise_stdev": 0.02, # std deviation of parameter noise @@ -38,6 +41,7 @@ "offset": 0, }) # __sphinx_doc_end__ +# yapf: enable @ray.remote @@ -163,12 +167,12 @@ def _init(self): self.report_length = self.config["report_length"] # Create the shared noise table. - print("Creating shared noise table.") + logger.info("Creating shared noise table.") noise_id = create_shared_noise.remote(self.config["noise_size"]) self.noise = SharedNoiseTable(ray.get(noise_id)) # Create the actors. - print("Creating actors.") + logger.info("Creating actors.") self.workers = [ Worker.remote(self.config, self.env_creator, noise_id) for _ in range(self.config["num_workers"]) @@ -182,8 +186,9 @@ def _collect_results(self, theta_id, min_episodes): num_episodes, num_timesteps = 0, 0 results = [] while num_episodes < min_episodes: - print("Collected {} episodes {} timesteps so far this iter".format( - num_episodes, num_timesteps)) + logger.info( + "Collected {} episodes {} timesteps so far this iter".format( + num_episodes, num_timesteps)) rollout_ids = [ worker.do_rollouts.remote(theta_id) for worker in self.workers ] @@ -263,7 +268,6 @@ def _train(self): g /= np.std(noisy_returns) assert (g.shape == (self.policy.num_params, ) and g.dtype == np.float32) - print('the number of policy params is, ', self.policy.num_params) # Compute the new weights theta. theta, update_ratio = self.optimizer.update(-g) # Set the new weights in the local copy of the policy. 
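With the tabular logger gone, the per-iteration statistics move into the info dict that _train assembles in the hunk below, so they surface through the normal training results instead of a side channel. A minimal sketch of reading them (assuming the ARSAgent class defined in this file and the usual result-dict layout):

    import ray
    from ray.rllib.agents.ars import ARSAgent

    ray.init()
    agent = ARSAgent(env="CartPole-v0", config={"num_workers": 2})
    result = agent.train()
    # Formerly printed via tlogger.record_tabular; now plain dict entries.
    print(result["info"]["weights_norm"], result["info"]["update_ratio"])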
@@ -272,18 +276,9 @@ def _train(self): if len(all_eval_returns) > 0: self.reward_list.append(eval_returns.mean()) - tlogger.record_tabular("NoisyEpRewMean", noisy_returns.mean()) - tlogger.record_tabular("NoisyEpRewStd", noisy_returns.std()) - tlogger.record_tabular("NoisyEpLenMean", noisy_lengths.mean()) - - tlogger.record_tabular("WeightsNorm", float(np.square(theta).sum())) - tlogger.record_tabular("WeightsStd", float(np.std(theta))) - tlogger.record_tabular("Grad2Norm", float(np.sqrt(np.square(g).sum()))) - tlogger.record_tabular("UpdateRatio", float(update_ratio)) - tlogger.dump_tabular() - info = { "weights_norm": np.square(theta).sum(), + "weights_std": np.std(theta), "grad_norm": np.square(g).sum(), "update_ratio": update_ratio, "episodes_this_iter": noisy_lengths.size, diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index c35fdaa71d17..ed58718b4ee4 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -13,6 +13,7 @@ "train_batch_size", "learning_starts" ] +# yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # === Model === @@ -108,8 +109,8 @@ # Prevent iterations from going lower than this time span "min_iter_time_s": 1, }) - # __sphinx_doc_end__ +# yapf: enable class DDPGAgent(DQNAgent): diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index 052d0fd3e957..ac8ec4490e60 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -6,6 +6,7 @@ from ray.rllib.utils import merge_dicts from ray.tune.trial import Resources +# yapf: disable # __sphinx_doc_begin__ APEX_DEFAULT_CONFIG = merge_dicts( DQN_CONFIG, # see also the options in dqn.py, which are also supported @@ -31,8 +32,8 @@ "min_iter_time_s": 30, }, ) - # __sphinx_doc_end__ +# yapf: enable class ApexAgent(DQNAgent): diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index f86b286ce5a1..3120407772df 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -20,6 +20,7 @@ "learning_starts" ] +# yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # === Model === @@ -116,8 +117,8 @@ # Prevent iterations from going lower than this time span "min_iter_time_s": 1, }) - # __sphinx_doc_end__ +# yapf: enable class DQNAgent(Agent): diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index 2fe63c71a2b1..ed2ed18692ca 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -6,6 +6,7 @@ from __future__ import print_function from collections import namedtuple +import logging import numpy as np import time @@ -15,15 +16,17 @@ from ray.rllib.agents.es import optimizers from ray.rllib.agents.es import policies -from ray.rllib.agents.es import tabular_logger as tlogger from ray.rllib.agents.es import utils from ray.rllib.utils import merge_dicts +logger = logging.getLogger(__name__) + Result = namedtuple("Result", [ "noise_indices", "noisy_returns", "sign_noisy_returns", "noisy_lengths", "eval_returns", "eval_lengths" ]) +# yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ "l2_coeff": 0.005, @@ -39,6 +42,7 @@ "report_length": 10, }) # __sphinx_doc_end__ +# yapf: enable @ray.remote @@ -169,12 +173,12 @@ def _init(self): self.report_length = self.config["report_length"] # Create the shared noise table. 
- print("Creating shared noise table.") + logger.info("Creating shared noise table.") noise_id = create_shared_noise.remote(self.config["noise_size"]) self.noise = SharedNoiseTable(ray.get(noise_id)) # Create the actors. - print("Creating actors.") + logger.info("Creating actors.") self.workers = [ Worker.remote(self.config, policy_params, self.env_creator, noise_id) for _ in range(self.config["num_workers"]) @@ -188,8 +192,9 @@ def _collect_results(self, theta_id, min_episodes, min_timesteps): num_episodes, num_timesteps = 0, 0 results = [] while num_episodes < min_episodes or num_timesteps < min_timesteps: - print("Collected {} episodes {} timesteps so far this iter".format( - num_episodes, num_timesteps)) + logger.info( + "Collected {} episodes {} timesteps so far this iter".format( + num_episodes, num_timesteps)) rollout_ids = [ worker.do_rollouts.remote(theta_id) for worker in self.workers ] @@ -269,21 +274,6 @@ def _train(self): if len(all_eval_returns) > 0: self.reward_list.append(np.mean(eval_returns)) - tlogger.record_tabular("EvalEpRewStd", eval_returns.std()) - tlogger.record_tabular("EvalEpLenMean", eval_lengths.mean()) - - tlogger.record_tabular("EpRewMean", noisy_returns.mean()) - tlogger.record_tabular("EpRewStd", noisy_returns.std()) - tlogger.record_tabular("EpLenMean", noisy_lengths.mean()) - - tlogger.record_tabular("Norm", float(np.square(theta).sum())) - tlogger.record_tabular("GradNorm", float(np.square(g).sum())) - tlogger.record_tabular("UpdateRatio", float(update_ratio)) - - tlogger.record_tabular("EpisodesThisIter", noisy_lengths.size) - tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far) - tlogger.dump_tabular() - info = { "weights_norm": np.square(theta).sum(), "grad_norm": np.square(g).sum(), diff --git a/python/ray/rllib/agents/es/tabular_logger.py b/python/ray/rllib/agents/es/tabular_logger.py deleted file mode 100644 index 1463e59e0704..000000000000 --- a/python/ray/rllib/agents/es/tabular_logger.py +++ /dev/null @@ -1,229 +0,0 @@ -# Code in this file is copied and adapted from -# https://github.com/openai/evolution-strategies-starter. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import OrderedDict -import os -import sys -import time - -import tensorflow as tf -from tensorflow.core.util import event_pb2 -from tensorflow.python import pywrap_tensorflow -from tensorflow.python.util import compat - -DEBUG = 10 -INFO = 20 -WARN = 30 -ERROR = 40 - -DISABLED = 50 - - -class TbWriter(object): - """Based on SummaryWriter, but changed to allow for a different prefix.""" - - def __init__(self, dir, prefix): - self.dir = dir - # Start at 1, because EvWriter automatically generates an object with - # step = 0. - self.step = 1 - self.evwriter = pywrap_tensorflow.EventsWriter( - compat.as_bytes(os.path.join(dir, prefix))) - - def write_values(self, key2val): - summary = tf.Summary(value=[ - tf.Summary.Value(tag=k, simple_value=float(v)) - for (k, v) in key2val.items() - ]) - event = event_pb2.Event(wall_time=time.time(), summary=summary) - event.step = self.step - self.evwriter.WriteEvent(event) - self.evwriter.Flush() - self.step += 1 - - def close(self): - self.evwriter.Close() - - -# API - - -def start(dir): - if _Logger.CURRENT is not _Logger.DEFAULT: - sys.stderr.write("WARNING: You asked to start logging (dir=%s), but " - "you never stopped the previous logger (dir=%s)." 
- "\n" % (dir, _Logger.CURRENT.dir)) - _Logger.CURRENT = _Logger(dir=dir) - - -def stop(): - if _Logger.CURRENT is _Logger.DEFAULT: - sys.stderr.write("WARNING: You asked to stop logging, but you never " - "started any previous logger." - "\n" % (dir, _Logger.CURRENT.dir)) - return - _Logger.CURRENT.close() - _Logger.CURRENT = _Logger.DEFAULT - - -def record_tabular(key, val): - """Log a value of some diagnostic. - - Call this once for each diagnostic quantity, each iteration. - """ - _Logger.CURRENT.record_tabular(key, val) - - -def dump_tabular(): - """Write all of the diagnostics from the current iteration.""" - _Logger.CURRENT.dump_tabular() - - -def log(*args, **kwargs): - """Write the sequence of args, with no separators. - - This is written to the console and output files (if you've configured an - output file). - """ - level = kwargs['level'] if 'level' in kwargs else INFO - _Logger.CURRENT.log(*args, level=level) - - -def debug(*args): - log(*args, level=DEBUG) - - -def info(*args): - log(*args, level=INFO) - - -def warn(*args): - log(*args, level=WARN) - - -def error(*args): - log(*args, level=ERROR) - - -def set_level(level): - """ - Set logging threshold on current logger. - """ - _Logger.CURRENT.set_level(level) - - -def get_dir(): - """ - Get directory that log files are being written to. - will be None if there is no output directory (i.e., if you didn't call - start) - """ - return _Logger.CURRENT.get_dir() - - -def get_expt_dir(): - sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n") - return get_dir() - - -# Backend - - -class _Logger(object): - # A logger with no output files. (See right below class definition) so that - # you can still log to the terminal without setting up any output files. - DEFAULT = None - # Current logger being used by the free functions above. - CURRENT = None - - def __init__(self, dir=None): - self.name2val = OrderedDict() # Values this iteration. - self.level = INFO - self.dir = dir - self.text_outputs = [sys.stdout] - if dir is not None: - os.makedirs(dir, exist_ok=True) - self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w")) - self.tbwriter = TbWriter(dir=dir, prefix="events") - else: - self.tbwriter = None - - # Logging API, forwarded - - def record_tabular(self, key, val): - self.name2val[key] = val - - def dump_tabular(self): - # Create strings for printing. - key2str = OrderedDict() - for (key, val) in self.name2val.items(): - if hasattr(val, "__float__"): - valstr = "%-8.3g" % val - else: - valstr = val - key2str[self._truncate(key)] = self._truncate(valstr) - keywidth = max(map(len, key2str.keys())) - valwidth = max(map(len, key2str.values())) - # Write to all text outputs - self._write_text("-" * (keywidth + valwidth + 7), "\n") - for (key, val) in key2str.items(): - self._write_text("| ", key, " " * (keywidth - len(key)), " | ", - val, " " * (valwidth - len(val)), " |\n") - self._write_text("-" * (keywidth + valwidth + 7), "\n") - for f in self.text_outputs: - try: - f.flush() - except OSError: - sys.stderr.write('Warning! 
OSError when flushing.\n') - # Write to tensorboard - if self.tbwriter is not None: - self.tbwriter.write_values(self.name2val) - self.name2val.clear() - - def log(self, *args, **kwargs): - level = kwargs['level'] if 'level' in kwargs else INFO - if self.level <= level: - self._do_log(*args) - - # Configuration - - def set_level(self, level): - self.level = level - - def get_dir(self): - return self.dir - - def close(self): - for f in self.text_outputs[1:]: - f.close() - if self.tbwriter: - self.tbwriter.close() - - # Misc - - def _do_log(self, *args): - self._write_text(*args + ('\n', )) - for f in self.text_outputs: - try: - f.flush() - except OSError: - print('Warning! OSError when flushing.') - - def _write_text(self, *strings): - for f in self.text_outputs: - for string in strings: - f.write(string) - - def _truncate(self, s): - if len(s) > 33: - return s[:30] + "..." - else: - return s - - -_Logger.DEFAULT = _Logger() -_Logger.CURRENT = _Logger.DEFAULT diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index a303643f55ee..fe53a5fb559e 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -23,6 +23,7 @@ "max_sample_requests_in_flight_per_worker", ] +# yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # V-trace params (see vtrace.py). @@ -65,8 +66,8 @@ "vf_loss_coeff": 0.5, "entropy_coeff": -0.01, }) - # __sphinx_doc_end__ +# yapf: enable class ImpalaAgent(Agent): diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index edc24ca1b05b..8ef5170bbb72 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -8,6 +8,7 @@ from ray.rllib.utils import merge_dicts from ray.tune.trial import Resources +# yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # No remote workers by default @@ -15,8 +16,8 @@ # Learning rate "lr": 0.0004, }) - # __sphinx_doc_end__ +# yapf: enable class PGAgent(Agent): diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 0527bcddda7e..480bf66e447f 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -2,12 +2,17 @@ from __future__ import division from __future__ import print_function +import logging + from ray.rllib.agents import Agent, with_common_config from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.utils import merge_dicts from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer from ray.tune.trial import Resources +logger = logging.getLogger(__name__) + +# yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ # If true, use the Generalized Advantage Estimator (GAE) @@ -55,8 +60,8 @@ # Use the sync samples optimizer instead of the multi-gpu one "simple_optimizer": False, }) - # __sphinx_doc_end__ +# yapf: enable class PPOAgent(Agent): @@ -111,7 +116,7 @@ def _validate_config(self): if waste_ratio > 1.5: raise ValueError(msg) else: - print("Warning: " + msg) + logger.warn(msg) if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]: raise ValueError( "Minibatch size {} must be <= train batch size {}.".format( diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index db88eb759df7..0578c34173b6 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -3,6 +3,7 @@ from __future__ import 
print_function import gym +import logging import pickle import tensorflow as tf @@ -99,7 +100,8 @@ def __init__(self, model_config=None, policy_config=None, worker_index=0, - monitor_path=None): + monitor_path=None, + log_level=None): """Initialize a policy evaluator. Arguments: @@ -158,8 +160,12 @@ def __init__(self, through EnvContext so that envs can be configured per worker. monitor_path (str): Write out episode stats and videos to this directory if specified. + log_level (str): Set the root log level on creation. """ + if log_level: + logging.getLogger("ray.rllib").setLevel(log_level) + env_context = EnvContext(env_config or {}, worker_index) policy_config = policy_config or {} self.policy_config = policy_config diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 2f90dcbabcc4..85d5386b1248 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -3,6 +3,7 @@ from __future__ import print_function from collections import defaultdict, namedtuple +import logging import numpy as np import six.moves.queue as queue import threading @@ -16,6 +17,8 @@ from ray.rllib.models.action_dist import TupleActions from ray.rllib.utils.tf_run_builder import TFRunBuilder +logger = logging.getLogger(__name__) + RolloutMetrics = namedtuple( "RolloutMetrics", ["episode_length", "episode_reward", "agent_rewards"]) @@ -221,7 +224,7 @@ def _env_runner(async_vector_env, horizon = ( async_vector_env.get_unwrapped()[0].spec.max_episode_steps) except Exception: - print("*** WARNING ***: no episode horizon specified, assuming inf") + logger.warn("no episode horizon specified, assuming inf") if not horizon: horizon = float("inf") diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 1c55cb79e1e7..4c0a20f77405 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -3,6 +3,7 @@ from __future__ import print_function import gym +import logging import numpy as np import tensorflow as tf from functools import partial @@ -21,6 +22,9 @@ from ray.rllib.models.visionnet import VisionNetwork from ray.rllib.models.lstm import LSTM +logger = logging.getLogger(__name__) + +# yapf: disable # __sphinx_doc_begin__ MODEL_DEFAULTS = { # === Built-in options === @@ -67,8 +71,8 @@ # Extra options to pass to the custom classes "custom_options": {}, } - # __sphinx_doc_end__ +# yapf: enable class ModelCatalog(object): @@ -200,7 +204,7 @@ def _get_model(input_dict, obs_space, num_outputs, options, state_in, seq_lens): if options.get("custom_model"): model = options["custom_model"] - print("Using custom model {}".format(model)) + logger.info("Using custom model {}".format(model)) return _global_registry.get(RLLIB_MODEL, model)( input_dict, obs_space, @@ -238,7 +242,7 @@ def get_torch_model(input_shape, num_outputs, options=None): options = options or MODEL_DEFAULTS if options.get("custom_model"): model = options["custom_model"] - print("Using custom torch model {}".format(model)) + logger.info("Using custom torch model {}".format(model)) return _global_registry.get(RLLIB_MODEL, model)( input_shape, num_outputs, options) @@ -271,7 +275,7 @@ def get_preprocessor(env, options=None): if options.get("custom_preprocessor"): preprocessor = options["custom_preprocessor"] - print("Using custom preprocessor {}".format(preprocessor)) + logger.info("Using custom preprocessor {}".format(preprocessor)) return _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)( env.observation_space, options) 
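Taken together, these hunks thread a single verbosity knob from the agent config down to every remote evaluator. A minimal sketch of how a user would now raise RLlib's log level (assuming the PPOAgent class touched elsewhere in this patch):

    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    # "log_level" lands in COMMON_CONFIG; Agent._setup applies it locally and
    # each PolicyEvaluator applies it in its own process via the new argument.
    agent = PPOAgent(env="CartPole-v0", config={"log_level": "DEBUG"})
    agent.train()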
diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py index b7e084062158..8144b5706d17 100644 --- a/python/ray/rllib/models/preprocessors.py +++ b/python/ray/rllib/models/preprocessors.py @@ -1,13 +1,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + import cv2 +import logging import numpy as np import gym ATARI_OBS_SHAPE = (210, 160, 3) ATARI_RAM_OBS_SHAPE = (128, ) +logger = logging.getLogger(__name__) + class Preprocessor(object): """Defines an abstract observation preprocessor function. @@ -128,7 +132,7 @@ def _init_shape(self, obs_space, options): self.preprocessors = [] for i in range(len(self._obs_space.spaces)): space = self._obs_space.spaces[i] - print("Creating sub-preprocessor for", space) + logger.info("Creating sub-preprocessor for {}".format(space)) preprocessor = get_preprocessor(space)(space, self._options) self.preprocessors.append(preprocessor) size += preprocessor.size @@ -153,7 +157,7 @@ def _init_shape(self, obs_space, options): size = 0 self.preprocessors = [] for space in self._obs_space.spaces.values(): - print("Creating sub-preprocessor for", space) + logger.info("Creating sub-preprocessor for {}".format(space)) preprocessor = get_preprocessor(space)(space, self._options) self.preprocessors.append(preprocessor) size += preprocessor.size diff --git a/python/ray/rllib/models/pytorch/fcnet.py b/python/ray/rllib/models/pytorch/fcnet.py index e8f50da2fb34..f69cb7ca21d4 100644 --- a/python/ray/rllib/models/pytorch/fcnet.py +++ b/python/ray/rllib/models/pytorch/fcnet.py @@ -2,10 +2,14 @@ from __future__ import division from __future__ import print_function +import logging + from ray.rllib.models.pytorch.model import Model, SlimFC from ray.rllib.models.pytorch.misc import normc_initializer import torch.nn as nn +logger = logging.getLogger(__name__) + class FullyConnectedNetwork(Model): """TODO(rliaw): Logits, Value should both be contained here""" @@ -19,7 +23,7 @@ def _build_layers(self, inputs, num_outputs, options): activation = nn.Tanh elif fcnet_activation == "relu": activation = nn.ReLU - print("Constructing fcnet {} {}".format(hiddens, activation)) + logger.info("Constructing fcnet {} {}".format(hiddens, activation)) layers = [] last_layer_size = inputs diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index 69f5e849b542..5ad6bd809b70 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function +import logging import numpy as np import random import time @@ -20,6 +21,8 @@ from ray.rllib.utils.timer import TimerStat from ray.rllib.utils.window_stat import WindowStat +logger = logging.getLogger(__name__) + LEARNER_QUEUE_MAX_SIZE = 16 NUM_DATA_LOAD_THREADS = 16 @@ -84,7 +87,7 @@ def __init__(self, self.devices = ["/cpu:0"] else: self.devices = ["/gpu:{}".format(i) for i in range(num_gpus)] - print("TFMultiGPULearner devices", self.devices) + logger.info("TFMultiGPULearner devices {}".format(self.devices)) assert self.train_batch_size % len(self.devices) == 0 assert self.train_batch_size >= len(self.devices), "batch too small" self.policy = self.local_evaluator.policy_map["default"] @@ -199,7 +202,7 @@ def _init(self, self.sample_batch_size = sample_batch_size if num_gpus > 1 or num_parallel_data_loaders > 1: - print( + logger.info( 
"Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format( num_gpus, num_parallel_data_loaders)) if train_batch_size // max(1, num_gpus) % ( diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 4595415a1eb0..7e01ee9041dc 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +import logging import numpy as np from collections import defaultdict import tensorflow as tf @@ -12,6 +13,8 @@ from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.utils.timer import TimerStat +logger = logging.getLogger(__name__) + class LocalMultiGPUOptimizer(PolicyOptimizer): """A synchronous optimizer that uses multiple local GPUs. @@ -53,7 +56,7 @@ def _init(self, self.update_weights_timer = TimerStat() self.standardize_fields = standardize_fields - print("LocalMultiGPUOptimizer devices", self.devices) + logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices)) if set(self.local_evaluator.policy_map.keys()) != {"default"}: raise ValueError( @@ -126,7 +129,7 @@ def step(self): with self.grad_timer: num_batches = ( int(tuples_per_device) // int(self.per_device_batch_size)) - print("== sgd epochs ==") + logger.debug("== sgd epochs ==") for i in range(self.num_sgd_iter): iter_extra_fetches = defaultdict(list) permutation = np.random.permutation(num_batches) @@ -136,7 +139,7 @@ def step(self): permutation[batch_index] * self.per_device_batch_size) for k, v in batch_fetches.items(): iter_extra_fetches[k].append(v) - print(i, _averaged(iter_extra_fetches)) + logger.debug("{} {}".format(i, _averaged(iter_extra_fetches))) self.num_steps_sampled += samples.count self.num_steps_trained += samples.count diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index 21fcf5f0b7a7..9d83140e9bce 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -2,11 +2,15 @@ from __future__ import division from __future__ import print_function +import logging + import ray from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes from ray.rllib.evaluation.sample_batch import MultiAgentBatch +logger = logging.getLogger(__name__) + class PolicyOptimizer(object): """Policy optimizers encapsulate distributed RL optimization strategies. diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index 20922ff54036..38d5269f0039 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -3,11 +3,14 @@ from __future__ import print_function import ray +import logging from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.evaluation.sample_batch import SampleBatch from ray.rllib.utils.filter import RunningStat from ray.rllib.utils.timer import TimerStat +logger = logging.getLogger(__name__) + class SyncSamplesOptimizer(PolicyOptimizer): """A simple synchronous RL optimizer. 
@@ -52,7 +55,7 @@ def step(self): if "stats" in fetches: self.learner_stats = fetches["stats"] if self.num_sgd_iter > 1: - print(i, fetches) + logger.debug("{} {}".format(i, fetches)) self.grad_timer.push_units_processed(samples.count) self.num_steps_sampled += samples.count diff --git a/python/ray/rllib/utils/actors.py b/python/ray/rllib/utils/actors.py index e865feb431b4..487c3595eac5 100644 --- a/python/ray/rllib/utils/actors.py +++ b/python/ray/rllib/utils/actors.py @@ -2,9 +2,12 @@ from __future__ import division from __future__ import print_function +import logging import os import ray +logger = logging.getLogger(__name__) + class TaskPool(object): """Helper class for tracking the status of many in-flight actor tasks.""" @@ -80,11 +83,12 @@ def split_colocated(actors): def try_create_colocated(cls, args, count): actors = [cls.remote(*args) for _ in range(count)] local, _ = split_colocated(actors) - print("Got {} colocated actors of {}".format(len(local), count)) + logger.info("Got {} colocated actors of {}".format(len(local), count)) return local def create_colocated(cls, args, count): + logger.info("Trying to create {} colocated actors".format(count)) ok = [] i = 1 while len(ok) < count and i < 10: diff --git a/python/ray/rllib/utils/compression.py b/python/ray/rllib/utils/compression.py index 5f28455ee44a..aed0dd598560 100644 --- a/python/ray/rllib/utils/compression.py +++ b/python/ray/rllib/utils/compression.py @@ -2,18 +2,21 @@ from __future__ import division from __future__ import print_function +import logging import time import base64 import numpy as np import pyarrow +logger = logging.getLogger(__name__) + try: import lz4.frame LZ4_ENABLED = True except ImportError: - print("WARNING: lz4 not available, disabling sample compression. " - "This will significantly impact RLlib performance. " - "To install lz4, run `pip install lz4`.") + logger.warn("lz4 not available, disabling sample compression. " + "This will significantly impact RLlib performance. " + "To install lz4, run `pip install lz4`.") LZ4_ENABLED = False diff --git a/python/ray/rllib/utils/policy_client.py b/python/ray/rllib/utils/policy_client.py index 901dc983b098..1bb4b5e13404 100644 --- a/python/ray/rllib/utils/policy_client.py +++ b/python/ray/rllib/utils/policy_client.py @@ -2,14 +2,17 @@ from __future__ import division from __future__ import print_function +import logging import pickle +logger = logging.getLogger(__name__) + try: import requests # `requests` is not part of stdlib. except ImportError: requests = None - print("Couldn't import `requests` library. Be sure to install it on" - " the client side.") + logger.warn("Couldn't import `requests` library. 
Be sure to install it on" + " the client side.") class PolicyClient(object): @@ -109,8 +112,7 @@ def _send(self, data): payload = pickle.dumps(data) response = requests.post(self._address, data=payload) if response.status_code != 200: - print("Request failed", data) - print(response.text) + logger.error("Request failed {}: {}".format(response.text, data)) response.raise_for_status() parsed = pickle.loads(response.content) return parsed diff --git a/python/ray/rllib/utils/tf_run_builder.py b/python/ray/rllib/utils/tf_run_builder.py index 2ea3ba7b8042..ce1c58279e94 100644 --- a/python/ray/rllib/utils/tf_run_builder.py +++ b/python/ray/rllib/utils/tf_run_builder.py @@ -2,12 +2,15 @@ from __future__ import division from __future__ import print_function +import logging import os import time import tensorflow as tf from tensorflow.python.client import timeline +logger = logging.getLogger(__name__) + class TFRunBuilder(object): """Used to incrementally build up a TensorFlow run. @@ -43,7 +46,7 @@ def get(self, to_fetch): self.session, self.fetches, self.debug_name, self.feed_dict, os.environ.get("TF_TIMELINE_DIR")) except Exception as e: - print("Error fetching: {}, feed_dict={}".format( + logger.error("Error fetching: {}, feed_dict={}".format( self.fetches, self.feed_dict)) raise e if isinstance(to_fetch, int): @@ -76,8 +79,8 @@ def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None): debug_name, os.getpid(), _count)) _count += 1 trace_file = open(outf, "w") - print("Wrote tf timeline ({} s) to {}".format(time.time() - start, - os.path.abspath(outf))) + logger.info("Wrote tf timeline ({} s) to {}".format( + time.time() - start, os.path.abspath(outf))) trace_file.write(trace.generate_chrome_trace_format()) else: fetches = sess.run(ops, feed_dict=feed_dict) From eff7cb4458c74fdbb02562c017b2949f147de5bb Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Mon, 22 Oct 2018 12:17:13 -0700 Subject: [PATCH 047/215] [tune] Fix SearchAlg finishing early (#3081) * Fix trial search alg finishing early * Fix lint * fix lint * nit fix --- doc/source/tune-usage.rst | 2 +- python/ray/tune/test/trial_runner_test.py | 28 ++++++++++++++++++++++- python/ray/tune/trial_runner.py | 4 ++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index b52241f039ed..0748b1cac688 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -223,7 +223,7 @@ For TensorFlow model training, this would look something like this `(full tensor .. code-block:: python class MyClass(Trainable): - def _setup(self): + def _setup(self, config): self.saver = tf.train.Saver() self.sess = ... 
             self.iteration = 0
diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py
index 65b8fbe36f62..450e96136fa2 100644
--- a/python/ray/tune/test/trial_runner_test.py
+++ b/python/ray/tune/test/trial_runner_test.py
@@ -20,7 +20,8 @@ from ray.tune.trial import Trial, Resources
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.suggest import grid_search, BasicVariantGenerator
-from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
+from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
+                                         SuggestionAlgorithm)
 from ray.tune.suggest.variant_generator import RecursiveDependencyError
@@ -1385,6 +1386,31 @@ def testSearchAlgStalled(self):
         self.assertTrue(searcher.is_finished())
         self.assertTrue(runner.is_finished())

+    def testSearchAlgFinishes(self):
+        """SearchAlg changing state in `next_trials` does not crash."""
+
+        class FinishFastAlg(SuggestionAlgorithm):
+            def next_trials(self):
+                self._finished = True
+                return []
+
+        ray.init(num_cpus=4, num_gpus=2)
+        experiment_spec = {
+            "run": "__fake",
+            "num_samples": 3,
+            "stop": {
+                "training_iteration": 1
+            }
+        }
+        searcher = FinishFastAlg()
+        experiments = [Experiment.from_json("test", experiment_spec)]
+        searcher.add_configurations(experiments)
+
+        runner = TrialRunner(search_alg=searcher)
+        runner.step()  # This should not fail
+        self.assertTrue(searcher.is_finished())
+        self.assertTrue(runner.is_finished())
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py
index 6423d6a95f10..8e3eb861246f 100644
--- a/python/ray/tune/trial_runner.py
+++ b/python/ray/tune/trial_runner.py
@@ -114,6 +114,10 @@ def step(self):
             self.trial_executor.start_trial(next_trial)
         elif self.trial_executor.get_running_trials():
             self._process_events()
+        elif self.is_finished():
+            # We check `is_finished` again here because the experiment
+            # may have finished while getting the next trial.
+            pass
         else:
             for trial in self._trials:
                 if trial.status == Trial.PENDING:

From 8d8b6e5bfab1993718580205612e5397215a1bd6 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Mon, 22 Oct 2018 22:31:13 -0700
Subject: [PATCH 048/215] Retry connections to redis for async and subscribe
 contexts (#3105)

This fixes a problem that @devin-petersohn observed on the Windows
Subsystem for Linux.

In theory, Redis should already be up by the time the async connect
happens, so no retries should be needed for it. However, on the Windows
Subsystem for Linux the async connect was failing even though the
synchronous one was working. Windows may have different connection
semantics here than Linux.
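The fix reuses one retry loop for all three connects. A rough Python analogue of the control flow (the names, exception type, and retry budget here are illustrative stand-ins, not the actual C++ API):

    import time

    def connect_with_retries(connect_fn, address, port, retries=50, wait_s=0.1):
        # Keep calling the supplied connect function until it yields a usable
        # connection or the retry budget runs out, sleeping a little between
        # attempts -- the same shape as the C++ ConnectWithRetries below.
        for _ in range(retries):
            try:
                return connect_fn(address, port)
            except ConnectionError:
                time.sleep(wait_s)
        raise RuntimeError("Could not connect to redis {}:{}".format(address, port))

    # Applied uniformly: the sync, async, and subscribe contexts all go
    # through it, instead of only the synchronous one.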
---
 src/ray/gcs/redis_context.cc | 36 ++++++++++++++++++------------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc
index 1a8111963256..f9e24cd93775 100644
--- a/src/ray/gcs/redis_context.cc
+++ b/src/ray/gcs/redis_context.cc
@@ -157,27 +157,35 @@ Status AuthenticateRedis(redisAsyncContext *context, const std::string &password
   return Status::OK();
 }
 
-Status RedisContext::Connect(const std::string &address, int port, bool sharding,
-                             const std::string &password = "") {
+template <typename RedisContext, typename RedisConnectFunction>
+Status ConnectWithRetries(const std::string &address, int port,
+                          const RedisConnectFunction &connect_function,
+                          RedisContext **context) {
   int connection_attempts = 0;
-  context_ = redisConnect(address.c_str(), port);
-  while (context_ == nullptr || context_->err) {
+  *context = connect_function(address.c_str(), port);
+  while (*context == nullptr || (*context)->err) {
     if (connection_attempts >= RayConfig::instance().redis_db_connect_retries()) {
-      if (context_ == nullptr) {
+      if (*context == nullptr) {
         RAY_LOG(FATAL) << "Could not allocate redis context.";
       }
-      if (context_->err) {
+      if ((*context)->err) {
         RAY_LOG(FATAL) << "Could not establish connection to redis " << address << ":"
-                       << port;
+                       << port << " (context.err = " << (*context)->err << ")";
       }
       break;
     }
     RAY_LOG(WARNING) << "Failed to connect to Redis, retrying.";
     // Sleep for a little.
     usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000);
-    context_ = redisConnect(address.c_str(), port);
+    *context = connect_function(address.c_str(), port);
     connection_attempts += 1;
   }
+  return Status::OK();
+}
+
+Status RedisContext::Connect(const std::string &address, int port, bool sharding,
+                             const std::string &password = "") {
+  RAY_CHECK_OK(ConnectWithRetries(address, port, redisConnect, &context_));
   RAY_CHECK_OK(AuthenticateRedis(context_, password));
 
   redisReply *reply = reinterpret_cast<redisReply *>(
@@ -186,19 +194,11 @@ Status RedisContext::Connect(const std::string &address, int port, bool sharding
   freeReplyObject(reply);
 
   // Connect to async context
-  async_context_ = redisAsyncConnect(address.c_str(), port);
-  if (async_context_ == nullptr || async_context_->err) {
-    RAY_LOG(FATAL) << "Could not establish connection to redis " << address << ":"
-                   << port;
-  }
+  RAY_CHECK_OK(ConnectWithRetries(address, port, redisAsyncConnect, &async_context_));
   RAY_CHECK_OK(AuthenticateRedis(async_context_, password));
 
   // Connect to subscribe context
-  subscribe_context_ = redisAsyncConnect(address.c_str(), port);
-  if (subscribe_context_ == nullptr || subscribe_context_->err) {
-    RAY_LOG(FATAL) << "Could not establish subscribe connection to redis " << address
-                   << ":" << port;
-  }
+  RAY_CHECK_OK(
+      ConnectWithRetries(address, port, redisAsyncConnect, &subscribe_context_));
   RAY_CHECK_OK(AuthenticateRedis(subscribe_context_, password));
 
   return Status::OK();

From 73a092e08cfbbfd3838c1794ffe686a166ee19e8 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Mon, 22 Oct 2018 22:55:43 -0700
Subject: [PATCH 049/215] update (#3112)

---
 python/ray/rllib/agents/agent.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py
index c082c1d28290..0517e2b09f1d 100644
--- a/python/ray/rllib/agents/agent.py
+++ b/python/ray/rllib/agents/agent.py
@@ -214,7 +214,7 @@ def __init__(self, config=None, env=None, logger_creator=None):
         if logger_creator is None:
             timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
             logdir_prefix = 
"{}_{}_{}".format( - [self._agent_name, self._env_id, timestr]) + self._agent_name, self._env_id, timestr) def default_logger_creator(config): """Creates a Unified logger with a default logdir prefix From 22dd7e0428adf704f79d023c8b584e3d641b3070 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Mon, 22 Oct 2018 23:16:55 -0700 Subject: [PATCH 050/215] Add test for wait reconstruction. (#3110) --- test/runtest.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/runtest.py b/test/runtest.py index 4ddeb57f84c7..a9efb2fb3925 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -2551,3 +2551,22 @@ def test_initialized_local_mode(shutdown_only_with_initialization_check): assert not ray.is_initialized() ray.init(num_cpus=0, local_mode=True) assert ray.is_initialized() + + +@pytest.mark.skipif( + os.environ.get("RAY_USE_XRAY") != "1", + reason="This test only works with xray.") +def test_wait_reconstruction(shutdown_only): + ray.init(num_cpus=1, object_store_memory=10**8) + + @ray.remote + def f(): + return np.zeros(6 * 10**7, dtype=np.uint8) + + x_id = f.remote() + ray.wait([x_id]) + ray.wait([f.remote()]) + assert not ray.worker.global_worker.plasma_client.contains( + ray.pyarrow.plasma.ObjectID(x_id.id())) + ready_ids, _ = ray.wait([x_id]) + assert len(ready_ids) == 1 From 9d2e864cafcea90f70fa90f10d7544995a60efaf Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Mon, 22 Oct 2018 23:41:42 -0700 Subject: [PATCH 051/215] Fix Python linting error. (#3113) --- python/ray/rllib/agents/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 0517e2b09f1d..7169be3a63ff 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -213,8 +213,8 @@ def __init__(self, config=None, env=None, logger_creator=None): # Create a default logger creator if no logger_creator is specified if logger_creator is None: timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") - logdir_prefix = "{}_{}_{}".format( - self._agent_name, self._env_id, timestr) + logdir_prefix = "{}_{}_{}".format(self._agent_name, self._env_id, + timestr) def default_logger_creator(config): """Creates a Unified logger with a default logdir prefix From 9c1826ed69a15240aaef960f20f9cbd290394478 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Tue, 23 Oct 2018 12:46:39 -0700 Subject: [PATCH 052/215] Use XRay backend by default. (#3020) * Use XRay backend by default. * Remove irrelevant valgrind tests. * Fix * Move tests around. * Fix * Fix test * Fix test. * String/unicode fix. * Fix test * Fix unicode issue. * Minor changes * Fix bug in test_global_state.py. * Fix test. * Linting * Try arrow change and other object manager changes. * Use newer plasma client API * Small updates * Revert plasma client api change. * Update * Update arrow and allow SendObjectHeaders to fail. * Update arrow * Update python/ray/experimental/state.py Co-Authored-By: robertnishihara * Address comments. 
--- .travis.yml | 38 +++--- cmake/Modules/ArrowExternalProject.cmake | 4 +- python/ray/experimental/state.py | 20 ++-- python/ray/function_manager.py | 9 ++ python/ray/global_scheduler/test/test.py | 2 +- python/ray/plasma/test/test.py | 5 +- python/ray/profiling.py | 8 +- python/ray/scripts/scripts.py | 18 ++- python/ray/services.py | 25 ++-- python/ray/test/test_global_state.py | 70 +++++------ python/ray/test/test_ray_init.py | 10 +- python/ray/worker.py | 33 +++-- src/ray/object_manager/connection_pool.cc | 2 +- src/ray/object_manager/object_buffer_pool.cc | 1 + src/ray/object_manager/object_directory.cc | 5 +- src/ray/object_manager/object_manager.cc | 21 +++- src/ray/raylet/node_manager.cc | 10 +- test/actor_test.py | 24 ++-- test/component_failures_test.py | 28 ++--- test/failure_test.py | 10 +- test/jenkins_tests/run_multi_node_tests.sh | 120 +++++++++---------- test/multi_node_test.py | 2 +- test/runtest.py | 18 +-- test/stress_tests.py | 7 +- 24 files changed, 263 insertions(+), 227 deletions(-) diff --git a/.travis.yml b/.travis.yml index 70f548f83bf8..3633198f75b6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -69,14 +69,14 @@ matrix: script: - cd build - - bash ../src/common/test/run_valgrind.sh - - bash ../src/plasma/test/run_valgrind.sh - - bash ../src/local_scheduler/test/run_valgrind.sh + # - bash ../src/common/test/run_valgrind.sh + # - bash ../src/plasma/test/run_valgrind.sh + # - bash ../src/local_scheduler/test/run_valgrind.sh - bash ../src/ray/test/run_object_manager_valgrind.sh - cd .. - - python ./python/ray/plasma/test/test.py valgrind - - python ./python/ray/local_scheduler/test/test.py valgrind + # - python ./python/ray/plasma/test/test.py valgrind + # - python ./python/ray/local_scheduler/test/test.py valgrind # - python ./python/ray/global_scheduler/test/test.py valgrind # Build Linux wheels. @@ -107,16 +107,23 @@ matrix: env: - PYTHON=3.5 - RAY_USE_NEW_GCS=on - - RAY_USE_XRAY=1 + # Test legacy Ray. - os: linux dist: trusty - env: PYTHON=3.5 RAY_USE_XRAY=1 + env: PYTHON=3.5 RAY_USE_XRAY=0 install: - ./.travis/install-dependencies.sh - export PATH="$HOME/miniconda/bin:$PATH" - ./.travis/install-ray.sh - ./.travis/install-cython-examples.sh + + - cd build + - bash ../src/common/test/run_tests.sh + - bash ../src/plasma/test/run_tests.sh + - bash ../src/local_scheduler/test/run_tests.sh + - cd .. + script: - export PATH="$HOME/miniconda/bin:$PATH" # The following is needed so cloudpickle can find some of the @@ -128,8 +135,8 @@ matrix: - python -m pytest -v python/ray/common/test/test.py - python -m pytest -v python/ray/common/redis_module/runtest.py - python -m pytest -v python/ray/plasma/test/test.py - # - python -m pytest -v python/ray/local_scheduler/test/test.py - # - python -m pytest -v python/ray/global_scheduler/test/test.py + - python -m pytest -v python/ray/local_scheduler/test/test.py + - python -m pytest -v python/ray/global_scheduler/test/test.py - python -m pytest -v python/ray/test/test_global_state.py - python -m pytest -v python/ray/test/test_queue.py @@ -190,9 +197,6 @@ install: - ./src/ray/util/logging_test --gtest_filter=PrintLogTest* - ./src/ray/util/signal_test - - bash ../src/common/test/run_tests.sh - - bash ../src/plasma/test/run_tests.sh - - bash ../src/local_scheduler/test/run_tests.sh - cd .. script: @@ -203,11 +207,11 @@ script: # module is only found if the test directory is in the PYTHONPATH. 
- export PYTHONPATH="$PYTHONPATH:./test/" - - python -m pytest -v python/ray/common/test/test.py - - python -m pytest -v python/ray/common/redis_module/runtest.py - - python -m pytest -v python/ray/plasma/test/test.py - - python -m pytest -v python/ray/local_scheduler/test/test.py - - python -m pytest -v python/ray/global_scheduler/test/test.py + # - python -m pytest -v python/ray/common/test/test.py + # - python -m pytest -v python/ray/common/redis_module/runtest.py + # - python -m pytest -v python/ray/plasma/test/test.py + # - python -m pytest -v python/ray/local_scheduler/test/test.py + # - python -m pytest -v python/ray/global_scheduler/test/test.py - python -m pytest -v python/ray/test/test_global_state.py - python -m pytest -v python/ray/test/test_queue.py diff --git a/cmake/Modules/ArrowExternalProject.cmake b/cmake/Modules/ArrowExternalProject.cmake index 07c48d97f0cc..08d250b81bfd 100644 --- a/cmake/Modules/ArrowExternalProject.cmake +++ b/cmake/Modules/ArrowExternalProject.cmake @@ -15,10 +15,10 @@ # - PLASMA_SHARED_LIB set(arrow_URL https://github.com/apache/arrow.git) -# The PR for this commit is https://github.com/apache/arrow/pull/2664. We +# The PR for this commit is https://github.com/apache/arrow/pull/2792. We # include the link here to make it easier to find the right commit because # Arrow often rewrites git history and invalidates certain commits. -set(arrow_TAG 3545186d6997b943ffc3d79634f2d08eefbd7322) +set(arrow_TAG 2d0d3d0dc51999fbaafb15d8b8362a1ef3de2ef7) set(ARROW_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/arrow-install) set(ARROW_HOME ${ARROW_INSTALL_PREFIX}) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 906d650d2866..99c3e3ba2202 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -132,13 +132,15 @@ def _initialize_global_state(self, use_raylet = self.redis_client.get("UseRaylet") if use_raylet is not None: - self.use_raylet = int(use_raylet) == 1 - elif os.environ.get("RAY_USE_XRAY") == "1": + self.use_raylet = bool(int(use_raylet)) + elif os.environ.get("RAY_USE_XRAY") == "0": # This environment variable is used in our testing setup. - print("Detected environment variable 'RAY_USE_XRAY'.") - self.use_raylet = True - else: + print("Detected environment variable 'RAY_USE_XRAY' with value " + "{}. This turns OFF xray.".format( + os.environ.get("RAY_USE_XRAY"))) self.use_raylet = False + else: + self.use_raylet = True # Get the rest of the information. self.redis_clients = [] @@ -1310,8 +1312,10 @@ def cluster_resources(self): else: clients = self.client_table() for client in clients: - for key, value in client["Resources"].items(): - resources[key] += value + # Only count resources from live clients. + if client["IsInsertion"]: + for key, value in client["Resources"].items(): + resources[key] += value return dict(resources) @@ -1379,8 +1383,6 @@ def available_resources(self): if local_scheduler_id not in local_scheduler_ids: del available_resources_by_id[local_scheduler_id] else: - # TODO(rliaw): Is this a fair assumption? 
- # Assumes the number of Redis clients does not change subscribe_clients = [ redis_client.pubsub(ignore_subscribe_messages=True) for redis_client in self.redis_clients diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index 0e123bd67ead..f9cb1a08b735 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -5,6 +5,7 @@ import hashlib import inspect import json +import sys import time import traceback from collections import ( @@ -342,6 +343,14 @@ def fetch_and_register_actor(self, actor_class_key): checkpoint_interval = int(checkpoint_interval) actor_method_names = json.loads(decode(actor_method_names)) + # In Python 2, json loads strings as unicode, so convert them back to + # strings. + if sys.version_info < (3, 0): + actor_method_names = [ + method_name.encode("ascii") + for method_name in actor_method_names + ] + # Create a temporary actor with some temporary methods so that if # the actor fails to be unpickled, the temporary actor can be used # (just to produce error messages and to prevent the driver from diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py index 37aad62ee1b0..0e262e705c4d 100644 --- a/python/ray/global_scheduler/test/test.py +++ b/python/ray/global_scheduler/test/test.py @@ -55,7 +55,7 @@ def setUp(self): # Start one Redis server and N pairs of (plasma, local_scheduler) self.node_ip_address = "127.0.0.1" redis_address, redis_shards = services.start_redis( - self.node_ip_address) + self.node_ip_address, use_raylet=False) redis_port = services.get_port(redis_address) time.sleep(0.1) # Create a client for the global state store. diff --git a/python/ray/plasma/test/test.py b/python/ray/plasma/test/test.py index a67f2d255e3a..bc2418f005d2 100644 --- a/python/ray/plasma/test/test.py +++ b/python/ray/plasma/test/test.py @@ -125,7 +125,7 @@ def setUp(self): store_name1, self.p2 = start_plasma_store(use_valgrind=USE_VALGRIND) store_name2, self.p3 = start_plasma_store(use_valgrind=USE_VALGRIND) # Start a Redis server. - redis_address, _ = services.start_redis("127.0.0.1") + redis_address, _ = services.start_redis("127.0.0.1", use_raylet=False) # Start two PlasmaManagers. manager_name1, self.p4, self.port1 = ray.plasma.start_plasma_manager( store_name1, redis_address, use_valgrind=USE_VALGRIND) @@ -483,7 +483,8 @@ def setUp(self): self.store_name, self.p2 = start_plasma_store( use_valgrind=USE_VALGRIND) # Start a Redis server. - self.redis_address, _ = services.start_redis("127.0.0.1") + self.redis_address, _ = services.start_redis( + "127.0.0.1", use_raylet=False) # Start a PlasmaManagers. manager_name, self.p3, self.port1 = ray.plasma.start_plasma_manager( self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) diff --git a/python/ray/profiling.py b/python/ray/profiling.py index e4c2d438fc2a..a16dd9d7ad95 100644 --- a/python/ray/profiling.py +++ b/python/ray/profiling.py @@ -230,8 +230,9 @@ def set_attribute(self, key, value): value: The attribute value. """ if not isinstance(key, str) or not isinstance(value, str): - raise ValueError("The extra_data argument must be a " - "dictionary mapping strings to strings.") + raise ValueError("The arguments 'key' and 'value' must both be " + "strings. 
Instead they are {} and {}.".format( + key, value)) self.extra_data[key] = value def __enter__(self): @@ -250,7 +251,8 @@ def __exit__(self, type, value, tb): for key, value in self.extra_data.items(): if not isinstance(key, str) or not isinstance(value, str): raise ValueError("The extra_data argument must be a " - "dictionary mapping strings to strings.") + "dictionary mapping strings to strings. " + "Instead it is {}.".format(self.extra_data)) if type is not None: extra_data = json.dumps({ diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index b28fc4e179e4..c4fd13b67cf2 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -169,9 +169,9 @@ def cli(logging_level, logging_format): help="the file that contains the autoscaling config") @click.option( "--use-raylet", - is_flag=True, default=None, - help="use the raylet code path") + type=bool, + help="use the raylet code path, this defaults to false") @click.option( "--no-redirect-worker-output", is_flag=True, @@ -207,10 +207,16 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, if redis_address is not None: redis_address = services.address_to_ip(redis_address) - if use_raylet is None and os.environ.get("RAY_USE_XRAY") == "1": - # This environment variable is used in our testing setup. - logger.info("Detected environment variable 'RAY_USE_XRAY'.") - use_raylet = True + if use_raylet is None: + if os.environ.get("RAY_USE_XRAY") == "0": + # This environment variable is used in our testing setup. + logger.info("Detected environment variable 'RAY_USE_XRAY' with " + "value {}. This turns OFF xray.".format( + os.environ.get("RAY_USE_XRAY"))) + use_raylet = False + else: + use_raylet = True + if not use_raylet and redis_password is not None: raise Exception("Setting the 'redis-password' argument is not " "supported in legacy Ray. To run Ray with " diff --git a/python/ray/services.py b/python/ray/services.py index a887b274386f..c78e7556b990 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -430,7 +430,7 @@ def start_redis(node_ip_address, redis_shard_ports=None, num_redis_shards=1, redis_max_clients=None, - use_raylet=False, + use_raylet=True, redirect_output=False, redirect_worker_output=False, cleanup=True, @@ -450,8 +450,7 @@ def start_redis(node_ip_address, shard. redis_max_clients: If this is provided, Ray will attempt to configure Redis with this maxclients number. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. + use_raylet: True if the new raylet code path should be used. redirect_output (bool): True if output should be redirected to a file and false otherwise. redirect_worker_output (bool): True if worker output should be @@ -1100,7 +1099,7 @@ def start_plasma_store(node_ip_address, cleanup=True, plasma_directory=None, huge_pages=False, - use_raylet=False, + use_raylet=True, plasma_store_socket_name=None, redis_password=None): """This method starts an object store process. @@ -1130,8 +1129,7 @@ def start_plasma_store(node_ip_address, be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. + use_raylet: True if the new raylet code path should be used. redis_password (str): The password of the redis server. 
Return: @@ -1359,7 +1357,7 @@ def start_ray_processes(address_info=None, plasma_directory=None, huge_pages=False, autoscaling_config=None, - use_raylet=False, + use_raylet=True, plasma_store_socket_name=None, raylet_socket_name=None, temp_dir=None): @@ -1417,8 +1415,7 @@ def start_ray_processes(address_info=None, huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. autoscaling_config: path to autoscaling config file. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. + use_raylet: True if the new raylet code path should be used. plasma_store_socket_name (str): If provided, it will specify the socket name used by the plasma store. raylet_socket_name (str): If provided, it will specify the socket path @@ -1692,7 +1689,7 @@ def start_ray_node(node_ip_address, resources=None, plasma_directory=None, huge_pages=False, - use_raylet=False, + use_raylet=True, plasma_store_socket_name=None, raylet_socket_name=None, temp_dir=None): @@ -1730,8 +1727,7 @@ def start_ray_node(node_ip_address, be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. + use_raylet: True if the new raylet code path should be used. plasma_store_socket_name (str): If provided, it will specify the socket name used by the plasma store. raylet_socket_name (str): If provided, it will specify the socket path @@ -1788,7 +1784,7 @@ def start_ray_head(address_info=None, plasma_directory=None, huge_pages=False, autoscaling_config=None, - use_raylet=False, + use_raylet=True, plasma_store_socket_name=None, raylet_socket_name=None, temp_dir=None): @@ -1840,8 +1836,7 @@ def start_ray_head(address_info=None, huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. autoscaling_config: path to autoscaling config file. - use_raylet: True if the new raylet code path should be used. This is - not supported yet. + use_raylet: True if the new raylet code path should be used. plasma_store_socket_name (str): If provided, it will specify the socket name used by the plasma store. raylet_socket_name (str): If provided, it will specify the socket path diff --git a/python/ray/test/test_global_state.py b/python/ray/test/test_global_state.py index 7b12ee022790..c5501dc9c525 100644 --- a/python/ray/test/test_global_state.py +++ b/python/ray/test/test_global_state.py @@ -2,57 +2,57 @@ from __future__ import division from __future__ import print_function +import pytest import time import ray -def setup_module(): - if not ray.worker.global_worker.connected: - ray.init(num_cpus=1) +@pytest.fixture +def ray_start(): + # Start the Ray processes. + ray.init(num_cpus=1) + yield None + # The code after the yield will run as teardown code. + ray.shutdown() - # Finish initializing Ray. 
Otherwise available_resources() does not - # reflect resource use of submitted tasks - ray.get(cpu_task.remote(0)) +def test_replenish_resources(ray_start): + cluster_resources = ray.global_state.cluster_resources() + available_resources = ray.global_state.available_resources() + assert cluster_resources == available_resources -@ray.remote(num_cpus=1) -def cpu_task(seconds): - time.sleep(seconds) + @ray.remote + def cpu_task(): + pass + ray.get(cpu_task.remote()) + start = time.time() + resources_reset = False -class TestAvailableResources(object): timeout = 10 - - def test_no_tasks(self): - cluster_resources = ray.global_state.cluster_resources() + while not resources_reset and time.time() - start < timeout: available_resources = ray.global_state.available_resources() - assert cluster_resources == available_resources - - def test_replenish_resources(self): - cluster_resources = ray.global_state.cluster_resources() + resources_reset = (cluster_resources == available_resources) - ray.get(cpu_task.remote(0)) - start = time.time() - resources_reset = False + assert resources_reset - while not resources_reset and time.time() - start < self.timeout: - available_resources = ray.global_state.available_resources() - resources_reset = (cluster_resources == available_resources) - assert resources_reset +def test_uses_resources(ray_start): + cluster_resources = ray.global_state.cluster_resources() - def test_uses_resources(self): - cluster_resources = ray.global_state.cluster_resources() - task_id = cpu_task.remote(1) - start = time.time() - resource_used = False + @ray.remote + def cpu_task(): + time.sleep(1) - while not resource_used and time.time() - start < self.timeout: - available_resources = ray.global_state.available_resources() - resource_used = available_resources[ - "CPU"] == cluster_resources["CPU"] - 1 + cpu_task.remote() + resource_used = False - assert resource_used + start = time.time() + timeout = 10 + while not resource_used and time.time() - start < timeout: + available_resources = ray.global_state.available_resources() + resource_used = available_resources[ + "CPU"] == cluster_resources["CPU"] - 1 - ray.get(task_id) # clean up to reset resources + assert resource_used diff --git a/python/ray/test/test_ray_init.py b/python/ray/test/test_ray_init.py index 62d581003b2e..a64dd4a94256 100644 --- a/python/ray/test/test_ray_init.py +++ b/python/ray/test/test_ray_init.py @@ -25,19 +25,11 @@ def shutdown_only(): class TestRedisPassword(object): - @pytest.mark.skipif( - os.environ.get("RAY_USE_NEW_GCS") != "on" - and os.environ.get("RAY_USE_XRAY"), - reason="Redis authentication works for raylet and old GCS.") - def test_exceptions(self, password, shutdown_only): - with pytest.raises(Exception): - ray.init(redis_password=password) - @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", reason="New GCS API doesn't support Redis authentication yet.") @pytest.mark.skipif( - not os.environ.get("RAY_USE_XRAY"), + os.environ.get("RAY_USE_XRAY") == "0", reason="Redis authentication is not supported in legacy Ray.") def test_redis_password(self, password, shutdown_only): # Workaround for https://github.com/ray-project/ray/issues/3045 diff --git a/python/ray/worker.py b/python/ray/worker.py index e19c433753b7..bce109cd8a34 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1223,7 +1223,7 @@ def actor_handle_deserializer(serialized_obj): def get_address_info_from_redis_helper(redis_address, node_ip_address, - use_raylet=False, + use_raylet=True, redis_password=None): 
redis_ip_address, redis_port = redis_address.split(":") # For this command to work, some other client (on the same machine as @@ -1333,7 +1333,7 @@ def get_address_info_from_redis_helper(redis_address, def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5, - use_raylet=False, + use_raylet=True, redis_password=None): counter = 0 while True: @@ -1497,10 +1497,15 @@ def _init(address_info=None, else: driver_mode = SCRIPT_MODE - if use_raylet is None and os.environ.get("RAY_USE_XRAY") == "1": - # This environment variable is used in our testing setup. - logger.info("Detected environment variable 'RAY_USE_XRAY'.") - use_raylet = True + if use_raylet is None: + if os.environ.get("RAY_USE_XRAY") == "0": + # This environment variable is used in our testing setup. + logger.info("Detected environment variable 'RAY_USE_XRAY' with " + "value {}. This turns OFF xray.".format( + os.environ.get("RAY_USE_XRAY"))) + use_raylet = False + else: + use_raylet = True # Get addresses of existing services. if address_info is None: @@ -1762,10 +1767,16 @@ def init(redis_address=None, else: raise Exception("Perhaps you called ray.init twice by accident?") - if use_raylet is None and os.environ.get("RAY_USE_XRAY") == "1": - # This environment variable is used in our testing setup. - logger.info("Detected environment variable 'RAY_USE_XRAY'.") - use_raylet = True + if use_raylet is None: + if os.environ.get("RAY_USE_XRAY") == "0": + # This environment variable is used in our testing setup. + logger.info("Detected environment variable 'RAY_USE_XRAY' with " + "value {}. This turns OFF xray.".format( + os.environ.get("RAY_USE_XRAY"))) + use_raylet = False + else: + use_raylet = True + if not use_raylet and redis_password is not None: raise Exception("Setting the 'redis_password' argument is not " "supported in legacy Ray. To run Ray with " @@ -1993,7 +2004,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, - use_raylet=False, + use_raylet=True, redis_password=None): """Connect this worker to the local scheduler, to Plasma, and to Redis. 
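The ``use_raylet`` default-resolution logic above is repeated nearly verbatim
in ``scripts.py``, ``_init``, and ``init``. Condensed to its effective rule,
and assuming a hypothetical helper name, it is:

.. code-block:: python

    import os

    def resolve_use_raylet(use_raylet=None):
        # An explicit argument always wins. Otherwise xray (the raylet
        # code path) is on by default, and only the test-only setting
        # RAY_USE_XRAY=0 turns it off.
        if use_raylet is not None:
            return use_raylet
        return os.environ.get("RAY_USE_XRAY") != "0"

This inverts the old behavior, where xray stayed off unless ``RAY_USE_XRAY=1``
was set.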
diff --git a/src/ray/object_manager/connection_pool.cc b/src/ray/object_manager/connection_pool.cc
index 58508f977b9b..2104eaa4a93c 100644
--- a/src/ray/object_manager/connection_pool.cc
+++ b/src/ray/object_manager/connection_pool.cc
@@ -79,7 +79,7 @@ void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id
   auto &connections = it->second;
   int64_t pos =
       std::find(connections.begin(), connections.end(), conn) - connections.begin();
-  if (pos >= (int64_t)connections.size()) {
+  if (pos >= static_cast<int64_t>(connections.size())) {
     return;
   }
   connections.erase(connections.begin() + pos);
diff --git a/src/ray/object_manager/object_buffer_pool.cc b/src/ray/object_manager/object_buffer_pool.cc
index fdb3623592eb..6651f204288b 100644
--- a/src/ray/object_manager/object_buffer_pool.cc
+++ b/src/ray/object_manager/object_buffer_pool.cc
@@ -150,6 +150,7 @@ void ObjectBufferPool::SealChunk(const ObjectID &object_id, const uint64_t chunk
              CreateChunkState::REFERENCED);
   create_buffer_state_[object_id].chunk_state[chunk_index] = CreateChunkState::SEALED;
   create_buffer_state_[object_id].num_seals_remaining--;
+  RAY_CHECK(create_buffer_state_[object_id].num_seals_remaining >= 0);
   RAY_LOG(DEBUG) << "SealChunk" << object_id << " "
                  << create_buffer_state_[object_id].num_seals_remaining;
   if (create_buffer_state_[object_id].num_seals_remaining == 0) {
diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc
index 38c7cdd45ca3..a3e43ebbbfd8 100644
--- a/src/ray/object_manager/object_directory.cc
+++ b/src/ray/object_manager/object_directory.cc
@@ -110,8 +110,9 @@ ray::Status ObjectDirectory::GetInformation(const ClientID &client_id,
         if (result_client_id == ClientID::nil() || !data.is_insertion) {
           fail_callback();
         } else {
-          const auto &info = RemoteConnectionInfo(client_id, data.node_manager_address,
-                                                  (uint16_t)data.object_manager_port);
+          const auto &info =
+              RemoteConnectionInfo(client_id, data.node_manager_address,
+                                   static_cast<uint16_t>(data.object_manager_port));
           success_callback(info);
         }
         return ray::Status::OK();
diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc
index 84be1106602b..e73565020b7b 100644
--- a/src/ray/object_manager/object_manager.cc
+++ b/src/ray/object_manager/object_manager.cc
@@ -97,6 +97,7 @@ void ObjectManager::StopIOService() {
 void ObjectManager::HandleObjectAdded(const ObjectInfoT &object_info) {
   // Notify the object directory that the object has been added to this node.
   ObjectID object_id = ObjectID::from_binary(object_info.object_id);
+  RAY_CHECK(local_objects_.count(object_id) == 0);
   local_objects_[object_id] = object_info;
   ray::Status status =
       object_directory_->ReportObjectAdded(object_id, client_id_, object_info);
@@ -122,7 +123,9 @@ void ObjectManager::HandleObjectAdded(const ObjectInfoT &object_info) {
 }
 
 void ObjectManager::NotifyDirectoryObjectDeleted(const ObjectID &object_id) {
-  local_objects_.erase(object_id);
+  auto it = local_objects_.find(object_id);
+  RAY_CHECK(it != local_objects_.end());
+  local_objects_.erase(it);
   ray::Status status = object_directory_->ReportObjectRemoved(object_id, client_id_);
 }
@@ -352,6 +355,10 @@ void ObjectManager::Push(const ObjectID &object_id, const ClientID &client_id) {
   for (uint64_t chunk_index = 0; chunk_index < num_chunks; ++chunk_index) {
     send_service_.post([this, client_id, object_id, data_size, metadata_size,
                         chunk_index, info]() {
+      // NOTE: When this callback executes, it's possible that the object
+      // will have already been evicted. It's also possible that the
+      // object could be in the process of being transferred to this
+      // object manager from another object manager.
       ExecuteSendObject(client_id, object_id, data_size, metadata_size, chunk_index,
                         info);
     });
@@ -398,15 +405,19 @@ ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id,
   // Fail on status not okay. The object is local, and there is
   // no other anticipated error here.
-  RAY_CHECK_OK(chunk_status.second);
+  ray::Status status = chunk_status.second;
+  if (!chunk_status.second.ok()) {
+    RAY_LOG(WARNING) << "Attempting to push object " << object_id
+                     << " which is not local. It may have been evicted.";
+    RAY_RETURN_NOT_OK(status);
+  }
 
   // Create buffer.
   flatbuffers::FlatBufferBuilder fbb;
-  // TODO(hme): use to_flatbuf
   auto message = object_manager_protocol::CreatePushRequestMessage(
-      fbb, fbb.CreateString(object_id.binary()), chunk_index, data_size, metadata_size);
+      fbb, to_flatbuf(fbb, object_id), chunk_index, data_size, metadata_size);
   fbb.Finish(message);
-  ray::Status status = conn->WriteMessage(
+  status = conn->WriteMessage(
       static_cast<int64_t>(object_manager_protocol::MessageType::PushRequest),
       fbb.GetSize(), fbb.GetBufferPointer());
   if (!status.ok()) {
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index 0db61e7ddc2d..8bbc28273c11 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -401,7 +401,7 @@ void NodeManager::HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &cl
   }
   // Locate the client id in remote client table and update available resources based on
   // the received heartbeat information.
-  auto it = this->cluster_resource_map_.find(client_id);
+  auto it = cluster_resource_map_.find(client_id);
   if (it == cluster_resource_map_.end()) {
     // Haven't received the client registration for this client yet, skip this heartbeat.
     RAY_LOG(INFO) << "[HeartbeatAdded]: received heartbeat from unknown client id "
@@ -1286,8 +1286,7 @@ void NodeManager::AssignTask(Task &task) {
   auto acquired_resources =
       local_available_resources_.Acquire(spec.GetRequiredResources());
   const auto &my_client_id = gcs_client_->client_table().GetLocalClientId();
-  RAY_CHECK(
-      this->cluster_resource_map_[my_client_id].Acquire(spec.GetRequiredResources()));
+  RAY_CHECK(cluster_resource_map_[my_client_id].Acquire(spec.GetRequiredResources()));
 
   if (spec.IsActorCreationTask()) {
     // Check that we are not placing an actor creation task on a node with 0 CPUs.
@@ -1394,8 +1393,9 @@ void NodeManager::FinishAssignedTask(Worker &worker) { local_available_resources_.Release(worker.GetTaskResourceIds()); worker.ResetTaskResourceIds(); - RAY_CHECK(this->cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] - .Release(task.GetTaskSpecification().GetRequiredResources())); + RAY_CHECK( + cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( + task.GetTaskSpecification().GetRequiredResources())); } // If the finished task was an actor task, mark the returned dummy object as diff --git a/test/actor_test.py b/test/actor_test.py index 13aeb78f7fc3..0ba8701f8230 100644 --- a/test/actor_test.py +++ b/test/actor_test.py @@ -1232,7 +1232,7 @@ def blocking_method(self): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") def test_exception_raised_when_actor_node_dies(shutdown_only): ray.worker._init(start_ray_local=True, num_local_schedulers=2, num_cpus=1) @@ -1279,7 +1279,7 @@ def inc(self): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -1329,7 +1329,7 @@ def inc(self): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -1466,7 +1466,7 @@ def __ray_restore__(self, checkpoint): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -1496,7 +1496,7 @@ def test_checkpointing(shutdown_only): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -1527,7 +1527,7 @@ def test_remote_checkpoint(shutdown_only): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -1558,7 +1558,7 @@ def test_lost_checkpoint(shutdown_only): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -1591,7 +1591,7 @@ def test_checkpoint_exception(shutdown_only): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -1663,7 +1663,7 @@ def fork_many_incs(counter, num_incs): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -1822,7 +1822,7 @@ def enqueue(queue, items): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -1860,7 +1860,7 @@ def read(self): @pytest.mark.skipif( - 
os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") def test_fork(setup_queue_actor): queue = setup_queue_actor @@ -1879,7 +1879,7 @@ def fork(queue, key, item): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") def test_fork_consistency(setup_queue_actor): queue = setup_queue_actor diff --git a/test/component_failures_test.py b/test/component_failures_test.py index 3a57452e6115..2957ef626cae 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -36,7 +36,7 @@ def shutdown_only(): # This test checks that when a worker dies in the middle of a get, the plasma # store and raylet will not die. @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -89,7 +89,7 @@ def f(id_in_a_list): # This test checks that when a driver dies in the middle of a get, the plasma # store and raylet will not die. @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -107,8 +107,8 @@ def sleep_forever(): driver = """ import ray ray.init("{}") -ray.get(ray.ObjectID({})) -""".format(address_info["redis_address"], x_id.id()) +ray.get(ray.ObjectID(ray.utils.hex_to_binary("{}"))) +""".format(address_info["redis_address"], x_id.hex()) p = run_string_as_driver_nonblocking(driver) # Make sure the driver is running. @@ -135,7 +135,7 @@ def sleep_forever(): # This test checks that when a worker dies in the middle of a get, the # plasma store and manager will not die. @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY", False), + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -171,7 +171,7 @@ def f(): # This test checks that when a worker dies in the middle of a wait, the plasma # store and raylet will not die. @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -216,7 +216,7 @@ def block_in_wait(object_id_in_list): # This test checks that when a driver dies in the middle of a wait, the plasma # store and raylet will not die. @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -234,8 +234,8 @@ def sleep_forever(): driver = """ import ray ray.init("{}") -ray.wait([ray.ObjectID({})]) -""".format(address_info["redis_address"], x_id.id()) +ray.wait([ray.ObjectID(ray.utils.hex_to_binary("{}"))]) +""".format(address_info["redis_address"], x_id.hex()) p = run_string_as_driver_nonblocking(driver) # Make sure the driver is running. @@ -262,7 +262,7 @@ def sleep_forever(): # This test checks that when a worker dies in the middle of a wait, the # plasma store and manager will not die. 
@pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY", False), + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -342,7 +342,7 @@ def f(x): def _test_component_failed(component_type): """Kill a component on all worker nodes and check workload succeeds.""" # Raylet is able to pass a harder failure test than legacy ray. - use_raylet = os.environ.get("RAY_USE_XRAY") == "1" + use_raylet = os.environ.get("RAY_USE_XRAY") != "0" # Start with 4 workers and 4 cores. num_local_schedulers = 4 @@ -452,7 +452,7 @@ def check_components_alive(component_type, check_component_alive): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only makes sense with xray.") def test_raylet_failed(): # Kill all local schedulers on worker nodes. @@ -466,7 +466,7 @@ def test_raylet_failed(): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not make sense with xray.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", @@ -485,7 +485,7 @@ def test_local_scheduler_failed(): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not make sense with xray.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", diff --git a/test/failure_test.py b/test/failure_test.py index 9cd25962b76b..2d01896b3c86 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -395,7 +395,7 @@ def ray_start_object_store_memory(): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") def test_put_error1(ray_start_object_store_memory): num_objects = 3 @@ -439,7 +439,7 @@ def put_arg_task(): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") def test_put_error2(ray_start_object_store_memory): # This is the same as the previous test, but it calls ray.put directly. @@ -495,7 +495,7 @@ def test_version_mismatch(shutdown_only): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") def test_warning_monitor_died(shutdown_only): ray.init(num_cpus=0) @@ -539,7 +539,7 @@ def __init__(self): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") def test_warning_for_infeasible_tasks(ray_start_regular): # Check that we get warning messages for infeasible tasks. @@ -592,7 +592,7 @@ def ray_start_two_nodes(): # Note that this test will take at least 10 seconds because it must wait for # the monitor to detect enough missed heartbeats. @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") def test_warning_for_dead_node(ray_start_two_nodes): # Wait for the raylet to appear in the client table. 
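The driver-string changes in ``component_failures_test.py`` above replace raw
binary IDs with a hex round-trip when generating driver source code. The
reason: ``x_id.id()`` returns raw bytes, which are not safe to splice into a
Python source template, while hex is plain text. A sketch of the round-trip
used by those tests (assumes a running Ray instance):

.. code-block:: python

    import ray
    ray.init()

    x_id = ray.put(1)
    hex_str = x_id.hex()  # text-safe representation of the binary ID
    # ...embed hex_str in the generated driver source...
    same_id = ray.ObjectID(ray.utils.hex_to_binary(hex_str))
    assert same_id.hex() == hex_str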
diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 821fefa7bdfa..d83765e97c08 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -11,222 +11,222 @@ ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd) DOCKER_SHA=$($ROOT_DIR/../../build-docker.sh --output-sha --no-cache) echo "Using Docker image" $DOCKER_SHA -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v0 \ --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pong-ram-v4 \ --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v0 \ --run A2C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "model": {"free_log_std": true}}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"simple_optimizer": false, "num_sgd_iter": 2, "model": {"use_lstm": true}}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"simple_optimizer": true, "num_sgd_iter": 2, "model": {"use_lstm": true}}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "use_gae": false, "batch_mode": "complete_episodes"}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pendulum-v0 \ --run ES \ --stop '{"training_iteration": 2}' \ --config '{"stepsize": 0.01, "episodes_per_batch": 20, "train_batch_size": 100, "num_workers": 2}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pong-v0 \ --run ES \ --stop '{"training_iteration": 2}' \ --config '{"stepsize": 0.01, "episodes_per_batch": 20, "train_batch_size": 100, "num_workers": 2}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker 
run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run A3C \ --stop '{"training_iteration": 2}' \ -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"lr": 1e-3, "schedule_max_timesteps": 100000, "exploration_fraction": 0.1, "exploration_final_eps": 0.02, "dueling": false, "hiddens": [], "model": {"fcnet_hiddens": [64], "fcnet_activation": "relu"}}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run APEX \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "timesteps_per_iteration": 1000, "gpu": false, "min_iter_time_s": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env FrozenLake-v0 \ --run DQN \ --stop '{"training_iteration": 2}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env FrozenLake-v0 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"num_sgd_iter": 10, "sgd_minibatch_size": 64, "train_batch_size": 1000, "num_workers": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v4 \ --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"lr": 1e-4, "schedule_max_timesteps": 2000000, "buffer_size": 10000, "exploration_fraction": 0.1, "exploration_final_eps": 0.01, "sample_batch_size": 4, "learning_starts": 10000, "target_network_update_freq": 1000, "gamma": 0.99, "prioritized_replay": true}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env MontezumaRevenge-v0 \ --run PPO \ --stop '{"training_iteration": 2}' \ --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "model": {"dim": 40, "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run A3C \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "model": {"use_lstm": true}}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run DQN \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G 
$DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1, "model": {"use_lstm": true, "max_seq_len": 100}}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1, "num_envs_per_worker": 10}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pong-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env FrozenLake-v0 \ --run PG \ --stop '{"training_iteration": 2}' \ --config '{"sample_batch_size": 500, "num_workers": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pendulum-v0 \ --run DDPG \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "model": {"use_lstm": true}}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "num_parallel_data_loaders": 2, "replay_proportion": 1.0}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v0 \ --run IMPALA \ --stop '{"training_iteration": 2}' \ --config '{"num_gpus": 0, "num_workers": 2, "min_iter_time_s": 1, "num_parallel_data_loaders": 2, "replay_proportion": 1.0, "model": {"use_lstm": true}}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env MountainCarContinuous-v0 \ --run DDPG \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ 
+docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ rllib train \ --env MountainCarContinuous-v0 \ --run DDPG \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env Pendulum-v0 \ --run APEX_DDPG \ @@ -234,97 +234,97 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --stop '{"training_iteration": 2}' \ --config '{"num_workers": 2, "optimizer": {"num_replay_buffer_shards": 1}, "learning_starts": 100, "min_iter_time_s": 1}' -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ sh /ray/test/jenkins_tests/multi_node_tests/test_rllib_eval.sh -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_local.py -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_checkpoint_restore.py -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_policy_evaluator.py -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_nested_spaces.py -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_serving_env.py -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_lstm.py -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_multi_agent_env.py -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_supported_spaces.py -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_ray.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/pbt_example.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/hyperband_example.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/async_hyperband_example.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_ray_hyperband.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker 
run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_async_hyperband.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/hyperopt_example.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_keras.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/mnist_pytorch.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/tune/examples/genetic_example.py \ --smoke-test -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/multiagent_cartpole.py --num-iters=2 -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/multiagent_two_trainers.py --num-iters=2 -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/cartpole_lstm.py --run=PPO --stop=200 -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/cartpole_lstm.py --run=IMPALA --stop=100 -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/cartpole_lstm.py --stop=200 --use-prev-action-reward -docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 # No Xray for PyTorch -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run -e RAY_USE_XRAY=0 --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env PongDeterministic-v4 \ --run A3C \ @@ -332,7 +332,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ --config '{"num_workers": 2, "use_pytorch": true, "model": {"use_lstm": false, "grayscale": true, "zero_mean": false, "dim": 84, "channel_major": true}, "preprocessor_pref": "rllib"}' # No Xray for PyTorch -docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ +docker run -e RAY_USE_XRAY=0 --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/train.py \ --env CartPole-v1 \ --run A3C \ diff --git a/test/multi_node_test.py b/test/multi_node_test.py index a1f0bd87be29..1cc783529f88 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -202,7 +202,7 @@ def ray_start_head_with_resources(): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test 
only works with xray.") def test_drivers_release_resources(ray_start_head_with_resources): redis_address = ray_start_head_with_resources diff --git a/test/runtest.py b/test/runtest.py index a9efb2fb3925..c0a5dbc54bc9 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -990,7 +990,7 @@ def get_path2(): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray (nor is it intended to).") def test_logging_api(shutdown_only): ray.init(num_cpus=1) @@ -1038,7 +1038,7 @@ def test_log_span_exception(): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") def test_profiling_api(shutdown_only): ray.init(num_cpus=2) @@ -1198,7 +1198,7 @@ def test_multi_threading_in_worker(): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") def test_free_objects_multi_node(shutdown_only): ray.worker._init( @@ -1639,7 +1639,7 @@ def test(self): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") def test_zero_cpus(shutdown_only): ray.init(num_cpus=0) @@ -1669,7 +1669,7 @@ def method(self): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") != "1", + os.environ.get("RAY_USE_XRAY") == "0", reason="This test only works with xray.") def test_fractional_resources(shutdown_only): ray.init(num_cpus=6, num_gpus=3, resources={"Custom": 1}) @@ -2043,7 +2043,7 @@ def h(i): object_ids = [f.remote(i, j) for j in range(2)] return ray.wait(object_ids, num_returns=len(object_ids)) - if os.environ.get("RAY_USE_XRAY") == "1": + if os.environ.get("RAY_USE_XRAY") != "0": ray.get([h.remote(i) for i in range(4)]) @ray.remote @@ -2350,7 +2350,7 @@ def f(): os.environ.get("RAY_USE_NEW_GCS") == "on", reason="New GCS API doesn't have a Python API yet.") @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray (nor is it intended to).") def test_task_profile_api(shutdown_only): ray.init(num_cpus=1, redirect_output=True) @@ -2419,7 +2419,7 @@ def f(): os.environ.get("RAY_USE_NEW_GCS") == "on", reason="New GCS API doesn't have a Python API yet.") @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") def test_dump_trace_file(shutdown_only): ray.init(num_cpus=1, redirect_output=True) @@ -2463,7 +2463,7 @@ def method(self): os.environ.get("RAY_USE_NEW_GCS") == "on", reason="New GCS API doesn't have a Python API yet.") @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") def test_flush_api(shutdown_only): ray.init(num_cpus=1) diff --git a/test/stress_tests.py b/test/stress_tests.py index 6fc7cc487e58..7c37d9091a4a 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -162,7 +162,7 @@ def ray_start_reconstruction(request): # Start the Redis global state store. 
node_ip_address = "127.0.0.1" - use_raylet = os.environ.get("RAY_USE_XRAY") == "1" + use_raylet = os.environ.get("RAY_USE_XRAY") != "0" redis_address, redis_shards = ray.services.start_redis( node_ip_address, use_raylet=use_raylet) redis_ip_address = ray.services.get_ip_address(redis_address) @@ -186,7 +186,8 @@ def ray_start_reconstruction(request): store_stdout_file=store_stdout_file, store_stderr_file=store_stderr_file, manager_stdout_file=manager_stdout_file, - manager_stderr_file=manager_stderr_file)) + manager_stderr_file=manager_stderr_file, + use_raylet=use_raylet)) # Start the rest of the services in the Ray cluster. address_info = { @@ -401,7 +402,7 @@ def wait_for_errors(error_check): @pytest.mark.skipif( - os.environ.get("RAY_USE_XRAY") == "1", + os.environ.get("RAY_USE_XRAY") != "0", reason="This test does not work with xray yet.") @pytest.mark.skipif( os.environ.get("RAY_USE_NEW_GCS") == "on", From 7c1fd19fd951524a617a43044fcdf810bbb65e34 Mon Sep 17 00:00:00 2001 From: Hanwei Jin Date: Wed, 24 Oct 2018 20:43:39 +0800 Subject: [PATCH 053/215] [Java] support python worker command in raylet (#3092) ## What do these changes do? Support the raylet started by the Java RunManager in launching the Python default_worker.py. This way, when locally testing Java tasks that call Python tasks, the Python workers are started automatically. ## Related issue number --- .../org/ray/runtime/config/RayConfig.java | 7 +++++++ .../org/ray/runtime/runner/RunManager.java | 20 ++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index b172a41114e0..1688fa6db3dd 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -54,6 +54,7 @@ public class RayConfig { public final String plasmaStoreExecutablePath; public final String rayletExecutablePath; public final String driverResourcePath; + public final String pythonWorkerCommand; private void validate() { if (workerMode == WorkerMode.WORKER) { @@ -136,6 +137,12 @@ public RayConfig(Config config) { jvmParameters = ImmutableList.of(); } + if (config.hasPath("ray.worker.python-command")) { + pythonWorkerCommand = config.getString("ray.worker.python-command"); + } else { + pythonWorkerCommand = null; + } + // redis configurations String redisAddress = config.getString("ray.redis.address"); if (!redisAddress.isEmpty()) { diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 3be219dca75c..940d0e78b72f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -185,7 +185,7 @@ private void startRaylet() { "0", // number of initial workers String.valueOf(maximumStartupConcurrency), ResourceUtil.getResourcesStringFromMap(rayConfig.resources), - "", // python worker command + buildPythonWorkerCommand(), // python worker command buildWorkerCommandRaylet() // java worker command ); @@ -247,4 +247,22 @@ private void startObjectStore() { startProcess(command, null, "plasma_store"); } + private String buildPythonWorkerCommand() { + // If no Python worker command is configured, the raylet (which is started from Java) should not start Python workers. + if (rayConfig.pythonWorkerCommand == null) { + return ""; + } + + List<String> cmd = new ArrayList<>(); + cmd.add(rayConfig.pythonWorkerCommand);
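// Illustration only (all values here are assumed): if ray.worker.python-command is
// "python /ray/python/ray/workers/default_worker.py", the flags appended below
// produce a worker command along the lines of:
//   python /ray/python/ray/workers/default_worker.py --node-ip-address=127.0.0.1
//     --object-store-name=/tmp/plasma_store --raylet-name=/tmp/raylet
//     --redis-address=127.0.0.1:6379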
cmd.add("--node-ip-address=" + rayConfig.nodeIp); + cmd.add("--object-store-name=" + rayConfig.objectStoreSocketName); + cmd.add("--raylet-name=" + rayConfig.rayletSocketName); + cmd.add("--redis-address=" + rayConfig.getRedisAddress()); + + String command = cmd.stream().collect(Collectors.joining(" ")); + LOGGER.debug("python worker command: {}", command); + return command; + } + } From 55d161b49f4ab2803463c3b003bc392a49e7faa6 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 24 Oct 2018 13:57:36 -0700 Subject: [PATCH 054/215] [autoscaler] Also grant roles to worker nodes --- python/ray/autoscaler/aws/config.py | 1 + python/ray/autoscaler/gcp/config.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/ray/autoscaler/aws/config.py b/python/ray/autoscaler/aws/config.py index 8e5d3a4daffc..79392e31cb05 100644 --- a/python/ray/autoscaler/aws/config.py +++ b/python/ray/autoscaler/aws/config.py @@ -101,6 +101,7 @@ def _configure_iam_role(config): logger.info("Role not specified for head node, using {}".format( profile.arn)) config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn} + config["worker_nodes"]["IamInstanceProfile"] = {"Arn": profile.arn} return config diff --git a/python/ray/autoscaler/gcp/config.py b/python/ray/autoscaler/gcp/config.py index d6ae2edeb008..a651c3983ac9 100644 --- a/python/ray/autoscaler/gcp/config.py +++ b/python/ray/autoscaler/gcp/config.py @@ -168,12 +168,16 @@ def _configure_iam_role(config): _add_iam_policy_binding(service_account, DEFAULT_SERVICE_ACCOUNT_ROLES) + # NOTE: The amount of access is determined by the scope + IAM + # role of the service account. Even if the cloud-platform scope + # gives (scope) access to the whole cloud-platform, the service + # account is limited by the IAM rights specified below. config["head_node"]["serviceAccounts"] = [{ "email": service_account["email"], - # NOTE: The amount of access is determined by the scope + IAM - # role of the service account. Even if the cloud-platform scope - # gives (scope) access to the whole cloud-platform, the service - # account is limited by the IAM rights specified below. + "scopes": ["https://www.googleapis.com/auth/cloud-platform"] + }] + config["worker_nodes"]["serviceAccounts"] = [{ + "email": service_account["email"], "scopes": ["https://www.googleapis.com/auth/cloud-platform"] }] From 5aa29613dbe1f81ddbf1bcf6b6fd9e5226618829 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Wed, 24 Oct 2018 16:30:00 -0700 Subject: [PATCH 055/215] Fix linting errors. (#3127) --- .travis.yml | 2 +- python/ray/experimental/state.py | 2 +- python/ray/function_manager.py | 2 +- python/ray/rllib/test/test_supported_spaces.py | 2 +- python/ray/scripts/scripts.py | 2 +- python/ray/services.py | 2 +- python/ray/tune/suggest/hyperopt.py | 2 +- python/ray/worker.py | 2 +- python/ray/workers/default_worker.py | 2 +- python/setup.py | 2 +- test/jenkins_tests/multi_node_tests/many_drivers_test.py | 2 +- test/jenkins_tests/multi_node_tests/remove_driver_test.py | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.travis.yml b/.travis.yml index 3633198f75b6..5088b62d92a7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -53,7 +53,7 @@ matrix: - sphinx-build -W -b html -d _build/doctrees source _build/html - cd .. 
# Run Python linting, ignore dict vs {} (C408), others are defaults - - flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504 + - flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 - .travis/format.sh --all - os: linux diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 99c3e3ba2202..9f1215a7e988 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -908,7 +908,7 @@ def dump_catapult_trace(self, repr(arg) for arg in task_table[task_id]["TaskSpec"]["Args"] ] - except Exception as e: + except Exception: print("Could not find task {}".format(task_id)) # filter out tasks not in task_table diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index f9cb1a08b735..72ec53651df7 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -163,7 +163,7 @@ def f(): try: function = pickle.loads(serialized_function) - except Exception as e: + except Exception: # If an exception was thrown when the remote function was imported, # we record the traceback and notify the scheduler of the failure. traceback_str = format_error_message(traceback.format_exc()) diff --git a/python/ray/rllib/test/test_supported_spaces.py b/python/ray/rllib/test/test_supported_spaces.py index 4f1aee0120e9..f40145c34604 100644 --- a/python/ray/rllib/test/test_supported_spaces.py +++ b/python/ray/rllib/test/test_supported_spaces.py @@ -70,7 +70,7 @@ def check_support(alg, config, stats, check_bounds=False): try: a = get_agent_class(alg)(config=config, env="stub_env") a.train() - except UnsupportedSpaceException as e: + except UnsupportedSpaceException: stat = "unsupported" except Exception as e: stat = "ERROR" diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index c4fd13b67cf2..654fba7208d7 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -225,7 +225,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, try: resources = json.loads(resources) - except Exception as e: + except Exception: raise Exception("Unable to parse the --resources argument using " "json.loads. Try using a format like\n\n" " --resources='{\"CustomResource1\": 3, " diff --git a/python/ray/services.py b/python/ray/services.py index c78e7556b990..d57bc04291c1 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -335,7 +335,7 @@ def wait_for_redis_to_start(redis_ip_address, "Waiting for redis server at {}:{} to respond...".format( redis_ip_address, redis_port)) redis_client.client_list() - except redis.ConnectionError as e: + except redis.ConnectionError: # Wait a little bit. 
time.sleep(1) logger.info("Failed to connect to the redis server, retrying.") diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py index 9173b56cc372..2c1c1317616d 100644 --- a/python/ray/tune/suggest/hyperopt.py +++ b/python/ray/tune/suggest/hyperopt.py @@ -10,7 +10,7 @@ hyperopt_logger = logging.getLogger("hyperopt") hyperopt_logger.setLevel(logging.WARNING) import hyperopt as hpo -except Exception as e: +except Exception: hpo = None from ray.tune.error import TuneError diff --git a/python/ray/worker.py b/python/ray/worker.py index bce109cd8a34..266b995f6130 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -2387,7 +2387,7 @@ def register_custom_serializer(cls, # worker. However, determinism is not guaranteed, and the # result may be different on different workers. class_id = _try_to_compute_deterministic_class_id(cls) - except Exception as e: + except Exception: raise serialization.CloudPickleError("Failed to pickle class " "'{}'".format(cls)) else: diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 670ee092d0e5..7fe46218f653 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -106,7 +106,7 @@ # main_loop. If an exception is thrown here, then that means that # there is some error that we didn't anticipate. ray.worker.global_worker.main_loop() - except Exception as e: + except Exception: traceback_str = traceback.format_exc() + error_explanation ray.utils.push_error_to_driver( ray.worker.global_worker, diff --git a/python/setup.py b/python/setup.py index 70d7cd87fadb..29e296a13d90 100644 --- a/python/setup.py +++ b/python/setup.py @@ -98,7 +98,7 @@ def run(self): for filename in optional_ray_files: try: self.move_file(filename) - except Exception as e: + except Exception: print("Failed to copy optional file {}. This is ok." .format(filename)) diff --git a/test/jenkins_tests/multi_node_tests/many_drivers_test.py b/test/jenkins_tests/multi_node_tests/many_drivers_test.py index d00e84a58c0f..94eeb4715e66 100644 --- a/test/jenkins_tests/multi_node_tests/many_drivers_test.py +++ b/test/jenkins_tests/multi_node_tests/many_drivers_test.py @@ -50,7 +50,7 @@ def try_to_create_actor(actor_class, timeout=500): while time.time() - start_time < timeout: try: actor = actor_class.remote() - except Exception as e: + except Exception: time.sleep(0.1) else: return actor diff --git a/test/jenkins_tests/multi_node_tests/remove_driver_test.py b/test/jenkins_tests/multi_node_tests/remove_driver_test.py index 4b61634b3069..08a100670997 100644 --- a/test/jenkins_tests/multi_node_tests/remove_driver_test.py +++ b/test/jenkins_tests/multi_node_tests/remove_driver_test.py @@ -199,7 +199,7 @@ def try_to_create_actor(actor_class, driver_index, actor_index, try: actor = actor_class.remote(driver_index, actor_index, redis_address) - except Exception as e: + except Exception: time.sleep(0.1) else: return actor From d34516f1f8aeee4a0b85dbe1543cca7c1acac6ba Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 25 Oct 2018 21:43:08 -0700 Subject: [PATCH 056/215] Update Gemfile Jekyll version (#3140) --- site/Gemfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/site/Gemfile b/site/Gemfile index 8af267397b31..9ae4bf67ff67 100644 --- a/site/Gemfile +++ b/site/Gemfile @@ -9,7 +9,7 @@ ruby RUBY_VERSION # # This will help ensure the proper Jekyll version is running. # Happy Jekylling! 
-gem "jekyll", "3.4.3" +gem "jekyll", ">= 3.6.3" # This is the default theme for new Jekyll sites. You may change this to anything you like. gem "minima", "~> 2.0" From d3148cc3ab7834f86f528a92977eda2b776f3760 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 25 Oct 2018 22:18:10 -0700 Subject: [PATCH 057/215] [SGD] Provide better error message if model graphs have different numbers of variables (#3139) --- python/ray/experimental/sgd/modified_allreduce.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/ray/experimental/sgd/modified_allreduce.py b/python/ray/experimental/sgd/modified_allreduce.py index a9d6879f99c7..7c446aa974e1 100644 --- a/python/ray/experimental/sgd/modified_allreduce.py +++ b/python/ray/experimental/sgd/modified_allreduce.py @@ -584,7 +584,15 @@ def end_interval(indices, small_ranges, large_indices): if len(small_ranges): new_tower_grads = [] for dev_idx, gv_list in enumerate(tower_grads): - assert len(gv_list) == num_gv + assert len(gv_list) == num_gv, ( + "Possible cause: " + "Networks constructed on different workers " + "don't have the same number of variables. " + "If you use tf.GraphKeys or tf.global_variables() " + "with multiple graphs per worker during network " + "construction, you need to use " + "appropriate scopes, see " + "https://github.com/ray-project/ray/issues/3136") new_gv_list = [] for r in small_ranges: key = '%d:%d' % (dev_idx, len(new_gv_list)) From b4614ae69ab530c28f48f0ff909a815ef988aaca Mon Sep 17 00:00:00 2001 From: bibabolynn <1018527906@qq.com> Date: Fri, 26 Oct 2018 13:36:34 +0800 Subject: [PATCH 058/215] [java] customize path of ray.conf (#3100) users can add custom path of ray.config by using -Dray.config=/path/to/ray.conf --- java/README.rst | 1 + .../java/org/ray/runtime/config/RayConfig.java | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/java/README.rst b/java/README.rst index 95ab961e769d..e01616935787 100644 --- a/java/README.rst +++ b/java/README.rst @@ -7,6 +7,7 @@ Ray will read your configurations in the following order: * Java system properties: e.g., ``-Dray.home=/path/to/ray``. * A ``ray.conf`` file in the classpath: `example `_. +* Customise your own ``ray.conf`` path using system property ``-Dray.config=/path/to/ray.conf`` For all available config items and default values, see `this file `_. 
diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index 1688fa6db3dd..e07b9e89e9e7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -7,11 +7,13 @@ import com.typesafe.config.ConfigException; import com.typesafe.config.ConfigFactory; +import java.io.File; import java.util.List; import java.util.Map; import org.ray.api.id.UniqueId; import org.ray.runtime.util.NetworkUtil; import org.ray.runtime.util.ResourceUtil; +import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -242,9 +244,16 @@ public String toString() { */ public static RayConfig create() { ConfigFactory.invalidateCaches(); - Config config = ConfigFactory.systemProperties() - .withFallback(ConfigFactory.load(CUSTOM_CONFIG_FILE)) - .withFallback(ConfigFactory.load(DEFAULT_CONFIG_FILE)); + Config config = ConfigFactory.systemProperties(); + String configPath = System.getProperty("ray.config"); + if (StringUtil.isNullOrEmpty(configPath)) { + LOGGER.info("Loading config from \"ray.conf\" file in classpath."); + config = config.withFallback(ConfigFactory.load(CUSTOM_CONFIG_FILE)); + } else { + LOGGER.info("Loading config from " + configPath + "."); + config = config.withFallback(ConfigFactory.parseFile(new File(configPath))); + } + config = config.withFallback(ConfigFactory.load(DEFAULT_CONFIG_FILE)); return new RayConfig(config); } From 055daf17a097ad9087021399383dd5181dfd1371 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 26 Oct 2018 12:42:11 -0700 Subject: [PATCH 059/215] [autoscaler] better message if there are more than 10 key pairs --- python/ray/autoscaler/aws/config.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/ray/autoscaler/aws/config.py b/python/ray/autoscaler/aws/config.py index 79392e31cb05..54c07aa29e98 100644 --- a/python/ray/autoscaler/aws/config.py +++ b/python/ray/autoscaler/aws/config.py @@ -115,7 +115,8 @@ def _configure_key_pair(config): ec2 = _resource("ec2", config) # Try a few times to get or create a good key pair. - for i in range(10): + MAX_NUM_KEYS = 20 + for i in range(MAX_NUM_KEYS): key_name, key_path = key_pair(i, config["provider"]["region"]) key = _get_key(key_name, config) @@ -132,7 +133,12 @@ def _configure_key_pair(config): os.chmod(key_path, 0o600) break - assert key, "AWS keypair {} not found for {}".format(key_name, key_path) + if not key: + raise ValueError( + "No matching local key file for any of the key pairs in this " + "account with ids from 0..{}. ".format(key_name) + + "Consider deleting some unused key pairs from your account.") + assert os.path.exists(key_path), \ "Private key file {} not found for {}".format(key_path, key_name) From 658c14282c384aa49a52b7d7bb4ac71c65cf8b86 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Fri, 26 Oct 2018 13:36:58 -0700 Subject: [PATCH 060/215] Remove legacy Ray code. (#3121) * Remove legacy Ray code. * Fix cmake and simplify monitor. * Fix linting * Updates * Fix * Implement some methods. * Remove more plasma manager references. * Fix * Linting * Fix * Fix * Make sure class IDs are strings. * Some path fixes * Fix * Path fixes and update arrow * Fixes. * linting * Fixes * Java fixes * Some java fixes * TaskLanguage -> Language * Minor * Fix python test and remove unused method signature. * Fix java tests * Fix jenkins tests * Remove commented out code.
--- .clang-format | 5 +- .gitignore | 15 - .travis.yml | 86 +- .travis/format.sh | 1 - CMakeLists.txt | 19 +- cmake/Modules/ArrowExternalProject.cmake | 4 +- cmake/Modules/Common.cmake | 3 - doc/source/actors.rst | 3 +- doc/source/conf.py | 101 +- doc/source/fault-tolerance.rst | 2 +- doc/source/internals-overview.rst | 6 +- doc/source/tempfile.rst | 2 - doc/source/tutorial.rst | 4 +- doc/source/using-ray-on-a-cluster.rst | 1 - doc/source/using-ray-on-a-large-cluster.rst | 1 - java/checkstyle-suppressions.xml | 2 +- java/prepare.sh | 6 +- .../org/ray/runtime/config/RayConfig.java | 6 +- .../{TaskLanguage.java => Language.java} | 10 +- .../ray/runtime/raylet/RayletClientImpl.java | 8 +- .../java/org/ray/api/test/RayConfigTest.java | 2 +- java/tutorial/pom.xml | 2 +- python/ray/__init__.py | 2 +- python/ray/actor.py | 2 +- python/ray/common/redis_module/.gitkeep | 0 python/ray/common/redis_module/runtest.py | 451 -- python/ray/common/test/test.py | 181 - .../ray/common/thirdparty/redis/src/.gitkeep | 0 .../ray/{common => core/src/ray}/__init__.py | 0 .../raylet}/__init__.py | 0 python/ray/experimental/sgd/sgd.py | 6 +- python/ray/experimental/sgd/util.py | 13 +- python/ray/experimental/state.py | 540 +- python/ray/gcs_utils.py | 29 - python/ray/global_scheduler/__init__.py | 7 - python/ray/global_scheduler/build/.gitkeep | 0 .../global_scheduler_services.py | 61 - python/ray/global_scheduler/test/test.py | 332 - python/ray/internal/internal_api.py | 7 +- python/ray/local_scheduler/build/.gitkeep | 0 .../local_scheduler_services.py | 132 - python/ray/local_scheduler/test/test.py | 206 - python/ray/monitor.py | 457 +- python/ray/plasma/__init__.py | 7 +- python/ray/plasma/plasma.py | 103 +- python/ray/plasma/test/test.py | 560 -- python/ray/plasma/utils.py | 53 - python/ray/profiling.py | 91 +- python/ray/ray_constants.py | 3 +- .../{local_scheduler => raylet}/__init__.py | 3 +- python/ray/rllib/utils/actors.py | 7 +- python/ray/scripts/scripts.py | 38 +- python/ray/services.py | 349 +- python/ray/tempfile_services.py | 60 - python/ray/test/cluster_utils.py | 18 +- python/ray/test/test_ray_init.py | 3 - python/ray/test/test_utils.py | 17 +- python/ray/tune/ray_trial_executor.py | 17 +- python/ray/tune/test/trial_runner_test.py | 12 +- python/ray/tune/util.py | 2 +- python/ray/utils.py | 43 +- python/ray/worker.py | 395 +- python/ray/workers/default_worker.py | 5 +- python/setup.py | 9 +- src/common/CMakeLists.txt | 131 - src/common/common.cc | 20 - src/common/common.h | 75 - src/common/doc/tasks.md | 32 - src/common/event_loop.cc | 63 - src/common/event_loop.h | 103 - src/common/format/common.fbs | 203 - src/common/io.cc | 416 -- src/common/io.h | 228 - src/common/logging.cc | 107 - src/common/logging.h | 58 - src/common/net.cc | 24 - src/common/net.h | 9 - src/common/redis_module/ray_redis_module.cc | 1886 ----- src/common/shims/windows/getopt.c | 69 - src/common/shims/windows/getopt.h | 4 - src/common/shims/windows/msg.c | 208 - src/common/shims/windows/netdb.h | 4 - src/common/shims/windows/netinet/in.h | 4 - src/common/shims/windows/poll.h | 4 - src/common/shims/windows/socketpair.c | 150 - src/common/shims/windows/strings.h | 4 - src/common/shims/windows/sys/ioctl.h | 4 - src/common/shims/windows/sys/mman.h | 36 - src/common/shims/windows/sys/select.h | 4 - src/common/shims/windows/sys/socket.h | 36 - src/common/shims/windows/sys/time.h | 12 - src/common/shims/windows/sys/un.h | 13 - src/common/shims/windows/sys/wait.h | 4 - src/common/shims/windows/unistd.h | 11 - 
src/common/state/actor_notification_table.cc | 47 - src/common/state/actor_notification_table.h | 74 - src/common/state/db.h | 70 - src/common/state/db_client_table.cc | 90 - src/common/state/db_client_table.h | 120 - src/common/state/driver_table.cc | 23 - src/common/state/driver_table.h | 50 - src/common/state/error_table.cc | 24 - src/common/state/error_table.h | 50 - src/common/state/local_scheduler_table.cc | 48 - src/common/state/local_scheduler_table.h | 98 - src/common/state/object_table.cc | 119 - src/common/state/object_table.h | 242 - src/common/state/redis.cc | 1692 ----- src/common/state/redis.h | 356 - src/common/state/table.cc | 200 - src/common/state/table.h | 216 - src/common/state/task_table.cc | 80 - src/common/state/task_table.h | 190 - src/common/task.cc | 606 -- src/common/task.h | 609 -- src/common/test/db_tests.cc | 246 - src/common/test/example_task.h | 77 - src/common/test/io_tests.cc | 114 - src/common/test/object_table_tests.cc | 919 --- src/common/test/redis_tests.cc | 238 - src/common/test/run_tests.sh | 43 - src/common/test/run_valgrind.sh | 27 - src/common/test/task_table_tests.cc | 460 -- src/common/test/task_tests.cc | 212 - src/common/test/test_common.h | 91 - src/common/thirdparty/download_thirdparty.bat | 15 - src/common/thirdparty/greatest.h | 1023 --- src/common/thirdparty/patches/.gitattributes | 1 - .../patches/windows/python-pyconfig.patch | 25 - .../thirdparty/patches/windows/redis.patch | 772 -- src/global_scheduler/CMakeLists.txt | 14 - src/global_scheduler/global_scheduler.cc | 492 -- src/global_scheduler/global_scheduler.h | 94 - .../global_scheduler_algorithm.cc | 257 - .../global_scheduler_algorithm.h | 126 - src/local_scheduler/CMakeLists.txt | 104 - src/local_scheduler/build/.gitkeep | 0 .../format/local_scheduler.fbs | 130 - src/local_scheduler/local_scheduler.cc | 1555 ---- src/local_scheduler/local_scheduler.h | 176 - .../local_scheduler_algorithm.cc | 1851 ----- .../local_scheduler_algorithm.h | 438 -- src/local_scheduler/local_scheduler_client.cc | 385 - src/local_scheduler/local_scheduler_client.h | 260 - src/local_scheduler/local_scheduler_shared.h | 137 - .../test/local_scheduler_tests.cc | 704 -- src/local_scheduler/test/run_tests.sh | 38 - src/local_scheduler/test/run_valgrind.sh | 41 - src/plasma/CMakeLists.txt | 61 - src/plasma/doc/plasma-doxy-config | 2473 ------- src/plasma/plasma_manager.cc | 1692 ----- src/plasma/plasma_manager.h | 277 - src/plasma/plasma_protocol.cc | 576 -- src/plasma/protocol.h | 77 - src/plasma/setup-env.sh | 5 - src/plasma/setup.py | 35 - src/plasma/test/client_tests.cc | 337 - src/plasma/test/manager_tests.cc | 313 - src/plasma/test/run_tests.sh | 59 - src/plasma/test/run_valgrind.sh | 11 - src/plasma/thirdparty/ae/ae.c | 465 -- src/plasma/thirdparty/ae/ae.h | 123 - src/plasma/thirdparty/ae/ae_epoll.c | 135 - src/plasma/thirdparty/ae/ae_evport.c | 320 - src/plasma/thirdparty/ae/ae_kqueue.c | 138 - src/plasma/thirdparty/ae/ae_select.c | 106 - src/plasma/thirdparty/ae/config.h | 54 - src/plasma/thirdparty/ae/zmalloc.h | 16 - src/plasma/thirdparty/dlmalloc.c | 6281 ----------------- src/plasma/thirdparty/xxhash.c | 889 --- src/plasma/thirdparty/xxhash.h | 293 - src/ray/.clang-format | 5 - src/ray/CMakeLists.txt | 40 +- src/ray/common/client_connection.cc | 2 +- src/{ => ray}/common/common_protocol.cc | 22 +- src/{ => ray}/common/common_protocol.h | 26 +- src/ray/gcs/CMakeLists.txt | 5 +- src/ray/gcs/asio.h | 4 +- src/ray/gcs/asio_test.cc | 5 + src/ray/gcs/client.cc | 4 +- src/ray/gcs/client.h | 2 - 
src/ray/gcs/client_test.cc | 154 +- src/ray/gcs/format/gcs.fbs | 63 + src/ray/gcs/redis_context.cc | 8 +- .../gcs}/redis_module/CMakeLists.txt | 0 .../gcs}/redis_module/chain_module.h | 14 +- src/ray/gcs/redis_module/ray_redis_module.cc | 671 ++ .../gcs}/redis_module/redis_string.h | 4 +- .../gcs}/redis_module/redismodule.h | 0 src/ray/gcs/tables.cc | 2 +- src/ray/gcs/tables.h | 90 - src/ray/gcs/task_table.cc | 71 - src/ray/id.cc | 8 +- src/ray/id.h | 4 - .../object_manager/format/object_manager.fbs | 24 + src/ray/object_manager/object_directory.cc | 6 +- src/ray/object_manager/object_directory.h | 12 +- src/ray/object_manager/object_manager.cc | 18 +- src/ray/object_manager/object_manager.h | 7 +- .../object_manager_client_connection.h | 2 +- .../object_store_notification_manager.cc | 13 +- .../object_store_notification_manager.h | 8 +- .../test/object_manager_stress_test.cc | 6 +- .../test/object_manager_test.cc | 4 +- src/{common/state => ray}/ray_config.h | 28 +- src/ray/raylet/CMakeLists.txt | 54 + src/ray/raylet/format/node_manager.fbs | 6 +- ...org_ray_runtime_raylet_RayletClientImpl.cc | 128 +- .../org_ray_runtime_raylet_RayletClientImpl.h | 77 +- .../raylet}/lib/python/common_extension.cc | 541 +- .../raylet}/lib/python/common_extension.h | 8 +- .../raylet}/lib/python/config_extension.cc | 114 +- .../raylet}/lib/python/config_extension.h | 8 +- .../lib/python/local_scheduler_extension.cc | 313 +- src/ray/raylet/lineage_cache.h | 2 +- src/ray/raylet/local_scheduler_client.cc | 400 ++ src/ray/raylet/local_scheduler_client.h | 180 + src/ray/raylet/main.cc | 2 +- src/ray/raylet/monitor.cc | 1 + src/ray/raylet/node_manager.cc | 52 +- .../raylet/object_manager_integration_test.cc | 4 +- src/ray/raylet/reconstruction_policy_test.cc | 3 +- src/ray/raylet/scheduling_resources.cc | 3 + src/ray/raylet/task_spec.cc | 57 +- src/ray/raylet/task_spec.h | 4 +- src/ray/raylet/worker.cc | 1 - src/ray/test/run_gcs_tests.sh | 8 +- src/ray/test/run_object_manager_tests.sh | 7 +- src/ray/test/run_object_manager_valgrind.sh | 7 +- src/{common => ray}/thirdparty/ae/ae.c | 0 src/{common => ray}/thirdparty/ae/ae.h | 0 src/{common => ray}/thirdparty/ae/ae_epoll.c | 0 src/{common => ray}/thirdparty/ae/ae_evport.c | 0 src/{common => ray}/thirdparty/ae/ae_kqueue.c | 0 src/{common => ray}/thirdparty/ae/ae_select.c | 0 src/{common => ray}/thirdparty/ae/config.h | 0 src/{common => ray}/thirdparty/ae/zmalloc.h | 0 .../thirdparty/hiredis/.gitignore | 0 .../thirdparty/hiredis/.travis.yml | 0 .../thirdparty/hiredis/CHANGELOG.md | 0 .../thirdparty/hiredis/COPYING | 0 .../thirdparty/hiredis/Makefile | 0 .../thirdparty/hiredis/README.md | 0 .../thirdparty/hiredis/adapters/ae.h | 0 .../thirdparty/hiredis/adapters/glib.h | 0 .../thirdparty/hiredis/adapters/ivykis.h | 0 .../thirdparty/hiredis/adapters/libev.h | 0 .../thirdparty/hiredis/adapters/libevent.h | 0 .../thirdparty/hiredis/adapters/libuv.h | 0 .../thirdparty/hiredis/adapters/macosx.h | 0 .../thirdparty/hiredis/adapters/qt.h | 0 .../thirdparty/hiredis/async.c | 0 .../thirdparty/hiredis/async.h | 0 src/{common => ray}/thirdparty/hiredis/dict.c | 0 src/{common => ray}/thirdparty/hiredis/dict.h | 0 .../thirdparty/hiredis/examples/example-ae.c | 0 .../hiredis/examples/example-glib.c | 0 .../hiredis/examples/example-ivykis.c | 0 .../hiredis/examples/example-libev.c | 0 .../hiredis/examples/example-libevent.c | 0 .../hiredis/examples/example-libuv.c | 0 .../hiredis/examples/example-macosx.c | 0 .../hiredis/examples/example-qt.cpp | 0 
.../thirdparty/hiredis/examples/example-qt.h | 0 .../thirdparty/hiredis/examples/example.c | 0 .../thirdparty/hiredis/fmacros.h | 0 .../thirdparty/hiredis/hiredis.c | 0 .../thirdparty/hiredis/hiredis.h | 0 src/{common => ray}/thirdparty/hiredis/net.c | 0 src/{common => ray}/thirdparty/hiredis/net.h | 0 src/{common => ray}/thirdparty/hiredis/read.c | 0 src/{common => ray}/thirdparty/hiredis/read.h | 0 src/{common => ray}/thirdparty/hiredis/sds.c | 0 src/{common => ray}/thirdparty/hiredis/sds.h | 0 src/{common => ray}/thirdparty/hiredis/test.c | 0 .../thirdparty/hiredis/win32.h | 0 src/{common => ray}/thirdparty/sha256.c | 0 src/{common => ray}/thirdparty/sha256.h | 0 test/actor_test.py | 47 +- test/component_failures_test.py | 255 +- test/failure_test.py | 20 +- test/jenkins_tests/multi_node_docker_test.py | 23 +- test/jenkins_tests/run_multi_node_tests.sh | 9 +- test/multi_node_test.py | 3 - test/multi_node_test_2.py | 10 - test/runtest.py | 238 +- test/stress_tests.py | 15 +- test/tempfile_test.py | 19 +- test/xray_test.py | 2 +- 289 files changed, 2471 insertions(+), 40719 deletions(-) rename java/runtime/src/main/java/org/ray/runtime/generated/{TaskLanguage.java => Language.java} (51%) delete mode 100644 python/ray/common/redis_module/.gitkeep delete mode 100644 python/ray/common/redis_module/runtest.py delete mode 100644 python/ray/common/test/test.py delete mode 100644 python/ray/common/thirdparty/redis/src/.gitkeep rename python/ray/{common => core/src/ray}/__init__.py (100%) rename python/ray/core/src/{local_scheduler => ray/raylet}/__init__.py (100%) delete mode 100644 python/ray/global_scheduler/__init__.py delete mode 100644 python/ray/global_scheduler/build/.gitkeep delete mode 100644 python/ray/global_scheduler/global_scheduler_services.py delete mode 100644 python/ray/global_scheduler/test/test.py delete mode 100644 python/ray/local_scheduler/build/.gitkeep delete mode 100644 python/ray/local_scheduler/local_scheduler_services.py delete mode 100644 python/ray/local_scheduler/test/test.py delete mode 100644 python/ray/plasma/test/test.py delete mode 100644 python/ray/plasma/utils.py rename python/ray/{local_scheduler => raylet}/__init__.py (76%) delete mode 100644 src/common/CMakeLists.txt delete mode 100644 src/common/common.cc delete mode 100644 src/common/common.h delete mode 100644 src/common/doc/tasks.md delete mode 100644 src/common/event_loop.cc delete mode 100644 src/common/event_loop.h delete mode 100644 src/common/format/common.fbs delete mode 100644 src/common/io.cc delete mode 100644 src/common/io.h delete mode 100644 src/common/logging.cc delete mode 100644 src/common/logging.h delete mode 100644 src/common/net.cc delete mode 100644 src/common/net.h delete mode 100644 src/common/redis_module/ray_redis_module.cc delete mode 100644 src/common/shims/windows/getopt.c delete mode 100644 src/common/shims/windows/getopt.h delete mode 100644 src/common/shims/windows/msg.c delete mode 100644 src/common/shims/windows/netdb.h delete mode 100644 src/common/shims/windows/netinet/in.h delete mode 100644 src/common/shims/windows/poll.h delete mode 100644 src/common/shims/windows/socketpair.c delete mode 100644 src/common/shims/windows/strings.h delete mode 100644 src/common/shims/windows/sys/ioctl.h delete mode 100644 src/common/shims/windows/sys/mman.h delete mode 100644 src/common/shims/windows/sys/select.h delete mode 100644 src/common/shims/windows/sys/socket.h delete mode 100644 src/common/shims/windows/sys/time.h delete mode 100644 src/common/shims/windows/sys/un.h delete 
mode 100644 src/common/shims/windows/sys/wait.h delete mode 100644 src/common/shims/windows/unistd.h delete mode 100644 src/common/state/actor_notification_table.cc delete mode 100644 src/common/state/actor_notification_table.h delete mode 100644 src/common/state/db.h delete mode 100644 src/common/state/db_client_table.cc delete mode 100644 src/common/state/db_client_table.h delete mode 100644 src/common/state/driver_table.cc delete mode 100644 src/common/state/driver_table.h delete mode 100644 src/common/state/error_table.cc delete mode 100644 src/common/state/error_table.h delete mode 100644 src/common/state/local_scheduler_table.cc delete mode 100644 src/common/state/local_scheduler_table.h delete mode 100644 src/common/state/object_table.cc delete mode 100644 src/common/state/object_table.h delete mode 100644 src/common/state/redis.cc delete mode 100644 src/common/state/redis.h delete mode 100644 src/common/state/table.cc delete mode 100644 src/common/state/table.h delete mode 100644 src/common/state/task_table.cc delete mode 100644 src/common/state/task_table.h delete mode 100644 src/common/task.cc delete mode 100644 src/common/task.h delete mode 100644 src/common/test/db_tests.cc delete mode 100644 src/common/test/example_task.h delete mode 100644 src/common/test/io_tests.cc delete mode 100644 src/common/test/object_table_tests.cc delete mode 100644 src/common/test/redis_tests.cc delete mode 100644 src/common/test/run_tests.sh delete mode 100644 src/common/test/run_valgrind.sh delete mode 100644 src/common/test/task_table_tests.cc delete mode 100644 src/common/test/task_tests.cc delete mode 100644 src/common/test/test_common.h delete mode 100644 src/common/thirdparty/download_thirdparty.bat delete mode 100644 src/common/thirdparty/greatest.h delete mode 100644 src/common/thirdparty/patches/.gitattributes delete mode 100644 src/common/thirdparty/patches/windows/python-pyconfig.patch delete mode 100644 src/common/thirdparty/patches/windows/redis.patch delete mode 100644 src/global_scheduler/CMakeLists.txt delete mode 100644 src/global_scheduler/global_scheduler.cc delete mode 100644 src/global_scheduler/global_scheduler.h delete mode 100644 src/global_scheduler/global_scheduler_algorithm.cc delete mode 100644 src/global_scheduler/global_scheduler_algorithm.h delete mode 100644 src/local_scheduler/CMakeLists.txt delete mode 100644 src/local_scheduler/build/.gitkeep delete mode 100644 src/local_scheduler/format/local_scheduler.fbs delete mode 100644 src/local_scheduler/local_scheduler.cc delete mode 100644 src/local_scheduler/local_scheduler.h delete mode 100644 src/local_scheduler/local_scheduler_algorithm.cc delete mode 100644 src/local_scheduler/local_scheduler_algorithm.h delete mode 100644 src/local_scheduler/local_scheduler_client.cc delete mode 100644 src/local_scheduler/local_scheduler_client.h delete mode 100644 src/local_scheduler/local_scheduler_shared.h delete mode 100644 src/local_scheduler/test/local_scheduler_tests.cc delete mode 100644 src/local_scheduler/test/run_tests.sh delete mode 100644 src/local_scheduler/test/run_valgrind.sh delete mode 100644 src/plasma/CMakeLists.txt delete mode 100644 src/plasma/doc/plasma-doxy-config delete mode 100644 src/plasma/plasma_manager.cc delete mode 100644 src/plasma/plasma_manager.h delete mode 100644 src/plasma/plasma_protocol.cc delete mode 100644 src/plasma/protocol.h delete mode 100644 src/plasma/setup-env.sh delete mode 100644 src/plasma/setup.py delete mode 100644 src/plasma/test/client_tests.cc delete mode 100644 
src/plasma/test/manager_tests.cc delete mode 100644 src/plasma/test/run_tests.sh delete mode 100644 src/plasma/test/run_valgrind.sh delete mode 100644 src/plasma/thirdparty/ae/ae.c delete mode 100644 src/plasma/thirdparty/ae/ae.h delete mode 100644 src/plasma/thirdparty/ae/ae_epoll.c delete mode 100644 src/plasma/thirdparty/ae/ae_evport.c delete mode 100644 src/plasma/thirdparty/ae/ae_kqueue.c delete mode 100644 src/plasma/thirdparty/ae/ae_select.c delete mode 100644 src/plasma/thirdparty/ae/config.h delete mode 100644 src/plasma/thirdparty/ae/zmalloc.h delete mode 100644 src/plasma/thirdparty/dlmalloc.c delete mode 100644 src/plasma/thirdparty/xxhash.c delete mode 100644 src/plasma/thirdparty/xxhash.h delete mode 100644 src/ray/.clang-format rename src/{ => ray}/common/common_protocol.cc (82%) rename src/{ => ray}/common/common_protocol.h (82%) rename src/{common => ray/gcs}/redis_module/CMakeLists.txt (100%) rename src/{common => ray/gcs}/redis_module/chain_module.h (84%) create mode 100644 src/ray/gcs/redis_module/ray_redis_module.cc rename src/{common => ray/gcs}/redis_module/redis_string.h (91%) rename src/{common => ray/gcs}/redis_module/redismodule.h (100%) delete mode 100644 src/ray/gcs/task_table.cc rename src/{common/state => ray}/ray_config.h (93%) rename src/{local_scheduler => ray/raylet}/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc (64%) rename src/{local_scheduler => ray/raylet}/lib/java/org_ray_runtime_raylet_RayletClientImpl.h (51%) rename src/{common => ray/raylet}/lib/python/common_extension.cc (53%) rename src/{common => ray/raylet}/lib/python/common_extension.h (86%) rename src/{common => ray/raylet}/lib/python/config_extension.cc (63%) rename src/{common => ray/raylet}/lib/python/config_extension.h (96%) rename src/{local_scheduler => ray/raylet}/lib/python/local_scheduler_extension.cc (54%) create mode 100644 src/ray/raylet/local_scheduler_client.cc create mode 100644 src/ray/raylet/local_scheduler_client.h rename src/{common => ray}/thirdparty/ae/ae.c (100%) rename src/{common => ray}/thirdparty/ae/ae.h (100%) rename src/{common => ray}/thirdparty/ae/ae_epoll.c (100%) rename src/{common => ray}/thirdparty/ae/ae_evport.c (100%) rename src/{common => ray}/thirdparty/ae/ae_kqueue.c (100%) rename src/{common => ray}/thirdparty/ae/ae_select.c (100%) rename src/{common => ray}/thirdparty/ae/config.h (100%) rename src/{common => ray}/thirdparty/ae/zmalloc.h (100%) rename src/{common => ray}/thirdparty/hiredis/.gitignore (100%) rename src/{common => ray}/thirdparty/hiredis/.travis.yml (100%) rename src/{common => ray}/thirdparty/hiredis/CHANGELOG.md (100%) rename src/{common => ray}/thirdparty/hiredis/COPYING (100%) rename src/{common => ray}/thirdparty/hiredis/Makefile (100%) rename src/{common => ray}/thirdparty/hiredis/README.md (100%) rename src/{common => ray}/thirdparty/hiredis/adapters/ae.h (100%) rename src/{common => ray}/thirdparty/hiredis/adapters/glib.h (100%) rename src/{common => ray}/thirdparty/hiredis/adapters/ivykis.h (100%) rename src/{common => ray}/thirdparty/hiredis/adapters/libev.h (100%) rename src/{common => ray}/thirdparty/hiredis/adapters/libevent.h (100%) rename src/{common => ray}/thirdparty/hiredis/adapters/libuv.h (100%) rename src/{common => ray}/thirdparty/hiredis/adapters/macosx.h (100%) rename src/{common => ray}/thirdparty/hiredis/adapters/qt.h (100%) rename src/{common => ray}/thirdparty/hiredis/async.c (100%) rename src/{common => ray}/thirdparty/hiredis/async.h (100%) rename src/{common => ray}/thirdparty/hiredis/dict.c (100%) 
rename src/{common => ray}/thirdparty/hiredis/dict.h (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example-ae.c (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example-glib.c (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example-ivykis.c (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example-libev.c (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example-libevent.c (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example-libuv.c (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example-macosx.c (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example-qt.cpp (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example-qt.h (100%) rename src/{common => ray}/thirdparty/hiredis/examples/example.c (100%) rename src/{common => ray}/thirdparty/hiredis/fmacros.h (100%) rename src/{common => ray}/thirdparty/hiredis/hiredis.c (100%) rename src/{common => ray}/thirdparty/hiredis/hiredis.h (100%) rename src/{common => ray}/thirdparty/hiredis/net.c (100%) rename src/{common => ray}/thirdparty/hiredis/net.h (100%) rename src/{common => ray}/thirdparty/hiredis/read.c (100%) rename src/{common => ray}/thirdparty/hiredis/read.h (100%) rename src/{common => ray}/thirdparty/hiredis/sds.c (100%) rename src/{common => ray}/thirdparty/hiredis/sds.h (100%) rename src/{common => ray}/thirdparty/hiredis/test.c (100%) rename src/{common => ray}/thirdparty/hiredis/win32.h (100%) rename src/{common => ray}/thirdparty/sha256.c (100%) rename src/{common => ray}/thirdparty/sha256.h (100%) diff --git a/.clang-format b/.clang-format index c5ab0983b753..5c0f059e15f3 100644 --- a/.clang-format +++ b/.clang-format @@ -1,6 +1,5 @@ -BasedOnStyle: Chromium -ColumnLimit: 80 +BasedOnStyle: Google +ColumnLimit: 90 DerivePointerAlignment: false IndentCaseLabels: false PointerAlignment: Right -SpaceAfterCStyleCast: true diff --git a/.gitignore b/.gitignore index abd60923e631..f8130b3a2f85 100644 --- a/.gitignore +++ b/.gitignore @@ -4,22 +4,10 @@ /python/build /python/dist /python/flatbuffers-1.7.1/ -/src/common/thirdparty/redis -/src/thirdparty/arrow /flatbuffers-1.7.1/ -/src/thirdparty/boost/ -/src/thirdparty/boost_1_65_1/ -/src/thirdparty/boost_1_60_0/ -/src/thirdparty/catapult/ -/src/thirdparty/flatbuffers/ -/src/thirdparty/parquet-cpp /thirdparty/pkg/ # Files generated by flatc should be ignored -/src/common/format/*.py -/src/common/format/*_generated.h -/src/plasma/format/ -/src/local_scheduler/format/*_generated.h /src/ray/gcs/format/*_generated.h /src/ray/object_manager/format/*_generated.h /src/ray/raylet/format/*_generated.h @@ -54,9 +42,6 @@ python/.eggs *.dylib *.dll -# Cython-generated files -*.c - # Incremental linking files *.ilk diff --git a/.travis.yml b/.travis.yml index 5088b62d92a7..a2d1b106a65b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -53,7 +53,7 @@ matrix: - sphinx-build -W -b html -d _build/doctrees source _build/html - cd .. 
# Run Python linting, ignore dict vs {} (C408), others are defaults - - flake8 --exclude=python/ray/core/src/common/flatbuffers_ep-prefix/,python/ray/core/generated/,src/common/format/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 + - flake8 --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 - .travis/format.sh --all - os: linux @@ -69,16 +69,9 @@ matrix: script: - cd build - # - bash ../src/common/test/run_valgrind.sh - # - bash ../src/plasma/test/run_valgrind.sh - # - bash ../src/local_scheduler/test/run_valgrind.sh - bash ../src/ray/test/run_object_manager_valgrind.sh - cd .. - # - python ./python/ray/plasma/test/test.py valgrind - # - python ./python/ray/local_scheduler/test/test.py valgrind - # - python ./python/ray/global_scheduler/test/test.py valgrind - # Build Linux wheels. - os: linux dist: trusty @@ -108,75 +101,6 @@ matrix: - PYTHON=3.5 - RAY_USE_NEW_GCS=on - # Test legacy Ray. - - os: linux - dist: trusty - env: PYTHON=3.5 RAY_USE_XRAY=0 - install: - - ./.travis/install-dependencies.sh - - export PATH="$HOME/miniconda/bin:$PATH" - - ./.travis/install-ray.sh - - ./.travis/install-cython-examples.sh - - - cd build - - bash ../src/common/test/run_tests.sh - - bash ../src/plasma/test/run_tests.sh - - bash ../src/local_scheduler/test/run_tests.sh - - cd .. - - script: - - export PATH="$HOME/miniconda/bin:$PATH" - # The following is needed so cloudpickle can find some of the - # class definitions: The main module of tests that are run - # with pytest have the same name as the test file -- and this - # module is only found if the test directory is in the PYTHONPATH. - - export PYTHONPATH="$PYTHONPATH:./test/" - - - python -m pytest -v python/ray/common/test/test.py - - python -m pytest -v python/ray/common/redis_module/runtest.py - - python -m pytest -v python/ray/plasma/test/test.py - - python -m pytest -v python/ray/local_scheduler/test/test.py - - python -m pytest -v python/ray/global_scheduler/test/test.py - - - python -m pytest -v python/ray/test/test_global_state.py - - python -m pytest -v python/ray/test/test_queue.py - - python -m pytest -v python/ray/test/test_ray_init.py - - python -m pytest -v test/xray_test.py - - - python -m pytest -v test/runtest.py - - python -m pytest -v test/array_test.py - - python -m pytest -v test/actor_test.py - - python -m pytest -v test/autoscaler_test.py - - python -m pytest -v test/tensorflow_test.py - - python -m pytest -v test/failure_test.py - - python -m pytest -v test/microbenchmarks.py - - python -m pytest -v test/stress_tests.py - - pytest test/component_failures_test.py - - python test/multi_node_test.py - - python -m pytest -v test/multi_node_test_2.py - - python -m pytest -v test/recursion_test.py - - pytest test/monitor_test.py - - python -m pytest -v test/cython_test.py - - python -m pytest -v test/credis_test.py - - # ray tune tests - - python python/ray/tune/test/dependency_test.py - - python -m pytest -v python/ray/tune/test/trial_runner_test.py - - python -m pytest -v python/ray/tune/test/trial_scheduler_test.py - - python -m pytest -v python/ray/tune/test/experiment_test.py - - python -m pytest -v python/ray/tune/test/tune_server_test.py - - python -m pytest -v python/ray/tune/test/ray_trial_executor_test.py - - python -m pytest -v python/ray/tune/test/automl_searcher_test.py - - # ray rllib tests - - python -m pytest -v python/ray/rllib/test/test_catalog.py - - python -m pytest -v 
python/ray/rllib/test/test_filters.py - - python -m pytest -v python/ray/rllib/test/test_optimizers.py - - python -m pytest -v python/ray/rllib/test/test_evaluators.py - - # ray temp file tests - - python -m pytest -v test/tempfile_test.py - install: - ./.travis/install-dependencies.sh @@ -207,12 +131,6 @@ script: # module is only found if the test directory is in the PYTHONPATH. - export PYTHONPATH="$PYTHONPATH:./test/" - # - python -m pytest -v python/ray/common/test/test.py - # - python -m pytest -v python/ray/common/redis_module/runtest.py - # - python -m pytest -v python/ray/plasma/test/test.py - # - python -m pytest -v python/ray/local_scheduler/test/test.py - # - python -m pytest -v python/ray/global_scheduler/test/test.py - - python -m pytest -v python/ray/test/test_global_state.py - python -m pytest -v python/ray/test/test_queue.py - python -m pytest -v python/ray/test/test_ray_init.py @@ -227,7 +145,7 @@ script: - python -m pytest -v test/microbenchmarks.py - python -m pytest -v test/stress_tests.py - python -m pytest -v test/component_failures_test.py - - python test/multi_node_test.py + - python -m pytest -v test/multi_node_test.py - python -m pytest -v test/multi_node_test_2.py - python -m pytest -v test/recursion_test.py - python -m pytest -v test/monitor_test.py diff --git a/.travis/format.sh b/.travis/format.sh index ca92d5196d56..e4c609a85dbd 100755 --- a/.travis/format.sh +++ b/.travis/format.sh @@ -30,7 +30,6 @@ YAPF_EXCLUDES=( '--exclude' 'python/build/*' '--exclude' 'python/ray/pyarrow_files/*' '--exclude' 'python/ray/core/src/ray/gcs/*' - '--exclude' 'python/ray/common/thirdparty/*' ) # Format specified files diff --git a/CMakeLists.txt b/CMakeLists.txt index d02e88a5c420..3980de00477b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,18 +82,15 @@ include_directories(SYSTEM ${PLASMA_INCLUDE_DIR}) include_directories("${CMAKE_CURRENT_LIST_DIR}/src/") add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/ray/) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/common/) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/plasma/) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/local_scheduler/) -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/global_scheduler/) # final target copy_ray add_custom_target(copy_ray ALL) # copy plasma_store_server add_custom_command(TARGET copy_ray POST_BUILD + COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/src/plasma COMMAND ${CMAKE_COMMAND} -E - copy ${ARROW_HOME}/bin/plasma_store_server ${CMAKE_CURRENT_BINARY_DIR}/src/plasma) + copy ${ARROW_HOME}/bin/plasma_store_server ${CMAKE_CURRENT_BINARY_DIR}/src/plasma/) if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") # add pyarrow as the dependency @@ -102,12 +99,9 @@ if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") # NOTE: The lists below must be kept in sync with ray/python/setup.py. 
set(ray_file_list - "src/common/thirdparty/redis/src/redis-server" - "src/common/redis_module/libray_redis_module.so" - "src/plasma/plasma_manager" - "src/local_scheduler/local_scheduler" - "src/local_scheduler/liblocal_scheduler_library_python.so" - "src/global_scheduler/global_scheduler" + "src/ray/thirdparty/redis/src/redis-server" + "src/ray/gcs/redis_module/libray_redis_module.so" + "src/ray/raylet/liblocal_scheduler_library_python.so" "src/ray/raylet/raylet_monitor" "src/ray/raylet/raylet") @@ -154,5 +148,6 @@ if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") # copy libplasma_java files add_custom_command(TARGET copy_ray POST_BUILD - COMMAND bash -c "cp ${ARROW_LIBRARY_DIR}/libplasma_java.* ${CMAKE_CURRENT_BINARY_DIR}/src/plasma") + COMMAND bash -c "mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/src/plasma" + COMMAND bash -c "cp ${ARROW_LIBRARY_DIR}/libplasma_java.* ${CMAKE_CURRENT_BINARY_DIR}/src/plasma/") endif() diff --git a/cmake/Modules/ArrowExternalProject.cmake b/cmake/Modules/ArrowExternalProject.cmake index 08d250b81bfd..b57104f1f23f 100644 --- a/cmake/Modules/ArrowExternalProject.cmake +++ b/cmake/Modules/ArrowExternalProject.cmake @@ -15,10 +15,10 @@ # - PLASMA_SHARED_LIB set(arrow_URL https://github.com/apache/arrow.git) -# The PR for this commit is https://github.com/apache/arrow/pull/2792. We +# The PR for this commit is https://github.com/apache/arrow/pull/2826. We # include the link here to make it easier to find the right commit because # Arrow often rewrites git history and invalidates certain commits. -set(arrow_TAG 2d0d3d0dc51999fbaafb15d8b8362a1ef3de2ef7) +set(arrow_TAG b4f7ed6d6ed5cdb6dd136bac3181a438f35c8ea0) set(ARROW_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/external/arrow-install) set(ARROW_HOME ${ARROW_INSTALL_PREFIX}) diff --git a/cmake/Modules/Common.cmake b/cmake/Modules/Common.cmake index cc2a5d5ff992..7d33f13e9d45 100644 --- a/cmake/Modules/Common.cmake +++ b/cmake/Modules/Common.cmake @@ -41,6 +41,3 @@ if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") message (WARNING "NOT FIND JNI") endif() endif() - -include_directories(${CMAKE_SOURCE_DIR}/src/common) -include_directories(${CMAKE_SOURCE_DIR}/src/common/thirdparty) diff --git a/doc/source/actors.rst b/doc/source/actors.rst index c7594592f512..0d8b3c94285b 100644 --- a/doc/source/actors.rst +++ b/doc/source/actors.rst @@ -65,8 +65,7 @@ When ``a1.increment.remote()`` is called, the following events happens. 1. A task is created. 2. The task is assigned directly to the local scheduler responsible for the - actor by the driver's local scheduler. Thus, this scheduling procedure - bypasses the global scheduler. + actor by the driver's local scheduler. 3. An object ID is returned. We can then call ``ray.get`` on the object ID to retrieve the actual value. diff --git a/doc/source/conf.py b/doc/source/conf.py index 27d0c1200d9c..2d212d23b9e8 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -18,44 +18,38 @@ # These lines added to enable Sphinx to work without installing Ray. 
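# Minimal sketch of the mocking pattern used below (module names assumed):
#   import sys
#   import mock
#   sys.modules["tensorflow"] = mock.Mock()
#   import tensorflow  # now resolves to the Mock object
# This lets Sphinx import ray's modules even when heavy dependencies such as
# TensorFlow, gym, or the compiled ray.raylet extension are not installed.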
import mock -MOCK_MODULES = ["gym", - "gym.spaces", - "scipy", - "scipy.signal", - "tensorflow", - "tensorflow.contrib", - "tensorflow.contrib.layers", - "tensorflow.contrib.slim", - "tensorflow.contrib.rnn", - "tensorflow.core", - "tensorflow.core.util", - "tensorflow.python", - "tensorflow.python.client", - "tensorflow.python.util", - "ray.local_scheduler", - "ray.plasma", - "ray.core", - "ray.core.generated", - "ray.core.generated.DriverTableMessage", - "ray.core.generated.LocalSchedulerInfoMessage", - "ray.core.generated.ResultTableReply", - "ray.core.generated.SubscribeToDBClientTableReply", - "ray.core.generated.SubscribeToNotificationsReply", - "ray.core.generated.TaskInfo", - "ray.core.generated.TaskReply", - "ray.core.generated.TaskExecutionDependencies", - "ray.core.generated.ClientTableData", - "ray.core.generated.GcsTableEntry", - "ray.core.generated.HeartbeatTableData", - "ray.core.generated.DriverTableData", - "ray.core.generated.ErrorTableData", - "ray.core.generated.ProfileTableData", - "ray.core.generated.ObjectTableData", - "ray.core.generated.ray.protocol.Task", - "ray.core.generated.TablePrefix", - "ray.core.generated.TablePubsub",] +MOCK_MODULES = [ + "gym", + "gym.spaces", + "scipy", + "scipy.signal", + "tensorflow", + "tensorflow.contrib", + "tensorflow.contrib.layers", + "tensorflow.contrib.slim", + "tensorflow.contrib.rnn", + "tensorflow.core", + "tensorflow.core.util", + "tensorflow.python", + "tensorflow.python.client", + "tensorflow.python.util", + "ray.raylet", + "ray.plasma", + "ray.core", + "ray.core.generated", + "ray.core.generated.ClientTableData", + "ray.core.generated.GcsTableEntry", + "ray.core.generated.HeartbeatTableData", + "ray.core.generated.DriverTableData", + "ray.core.generated.ErrorTableData", + "ray.core.generated.ProfileTableData", + "ray.core.generated.ObjectTableData", + "ray.core.generated.ray.protocol.Task", + "ray.core.generated.TablePrefix", + "ray.core.generated.TablePubsub", +] for mod_name in MOCK_MODULES: - sys.modules[mod_name] = mock.Mock() + sys.modules[mod_name] = mock.Mock() # ray.rllib.models.action_dist.py and # ray.rllib.models.lstm.py will use tf.VERSION sys.modules["tensorflow"].VERSION = "9.9.9" @@ -89,7 +83,7 @@ source_suffix = ['.rst', '.md'] source_parsers = { - '.md': CommonMarkParser, + '.md': CommonMarkParser, } # The encoding of source files. @@ -259,25 +253,24 @@ # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # Additional stuff for the LaTeX preamble. + #'preamble': '', -# Latex figure (float) alignment -#'figure_align': 'htbp', + # Latex figure (float) alignment + #'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'Ray.tex', u'Ray Documentation', - u'The Ray Team', 'manual'), + (master_doc, 'Ray.tex', u'Ray Documentation', u'The Ray Team', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -300,29 +293,23 @@ # If false, no module index is generated. 
#latex_domain_indices = True - # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'ray', u'Ray Documentation', - [author], 1) -] +man_pages = [(master_doc, 'ray', u'Ray Documentation', [author], 1)] # If true, show URL addresses after external links. #man_show_urls = False - # -- Options for Texinfo output ------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Ray', u'Ray Documentation', - author, 'Ray', 'One line description of project.', - 'Miscellaneous'), + (master_doc, 'Ray', u'Ray Documentation', author, 'Ray', + 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. diff --git a/doc/source/fault-tolerance.rst b/doc/source/fault-tolerance.rst index a4692f904fee..112927ce34ee 100644 --- a/doc/source/fault-tolerance.rst +++ b/doc/source/fault-tolerance.rst @@ -47,7 +47,7 @@ Process Failures ~~~~~~~~~~~~~~~~ 1. Ray does not recover from the failure of any of the following processes: - a Redis server, the global scheduler, the monitor process. + a Redis server and the monitor process. 2. If a driver fails, that driver will not be restarted and the job will not complete. diff --git a/doc/source/internals-overview.rst b/doc/source/internals-overview.rst index 69ac1895a55c..a2516de1d10c 100644 --- a/doc/source/internals-overview.rst +++ b/doc/source/internals-overview.rst @@ -15,8 +15,8 @@ Running Ray standalone Ray can be used standalone by calling ``ray.init()`` within a script. When the call to ``ray.init()`` happens, all of the relevant processes are started. -These include a local scheduler, a global scheduler, an object store and -manager, a Redis server, and a number of worker processes. +These include a local scheduler, an object store and manager, a Redis server, +and a number of worker processes. When the script exits, these processes will be killed. @@ -112,7 +112,7 @@ When a driver or worker invokes a remote function, a number of things happen. - The task object is then sent to the local scheduler on the same node as the driver or worker. - The local scheduler makes a decision to either schedule the task locally or to - pass the task on to a global scheduler. + pass the task on to another local scheduler. - If all of the task's object dependencies are present in the local object store and there are enough CPU and GPU resources available to execute the diff --git a/doc/source/tempfile.rst b/doc/source/tempfile.rst index 7e489348c4bf..99daf2833ee2 100644 --- a/doc/source/tempfile.rst +++ b/doc/source/tempfile.rst @@ -45,8 +45,6 @@ A typical layout of temporary files could look like this: │   ├── log_monitor.out │   ├── monitor.err │   ├── monitor.out - │   ├── plasma_manager_0.err # array of plasma managers' outputs - │   ├── plasma_manager_0.out │   ├── plasma_store_0.err # array of plasma stores' outputs │   ├── plasma_store_0.out │   ├── raylet_0.err # array of raylets' outputs. 
Control it with `--no-redirect-worker-output` (in Ray's command line) or `redirect_worker_output` (in ray.init())
diff --git a/doc/source/tutorial.rst b/doc/source/tutorial.rst
index 0493b6916990..81de87a571ce 100644
--- a/doc/source/tutorial.rst
+++ b/doc/source/tutorial.rst
@@ -9,7 +9,7 @@ To use Ray, you need to understand the following:
 Overview
 --------

-Ray is a Python-based distributed execution engine. The same code can be run on
+Ray is a distributed execution engine. The same code can be run on
 a single machine to achieve efficient multiprocessing, and it can be used on
 a cluster for large computations.

@@ -21,8 +21,6 @@ When using Ray, several processes are involved.
    allows workers to efficiently share objects on the same node with minimal
    copying and deserialization.
 - One **local scheduler** per node assigns tasks to workers on the same node.
-- A **global scheduler** receives tasks from local schedulers and assigns them
-  to other local schedulers.
 - A **driver** is the Python process that the user controls. For example, if
   the user is running a script or using a Python shell, then the driver is the
   Python process that runs the script or the shell. A driver is similar to a
   worker in
diff --git a/doc/source/using-ray-on-a-cluster.rst b/doc/source/using-ray-on-a-cluster.rst
index 29c2585ac7cf..611e47b79db2 100644
--- a/doc/source/using-ray-on-a-cluster.rst
+++ b/doc/source/using-ray-on-a-cluster.rst
@@ -51,7 +51,6 @@ Now we've started all of the Ray processes on each node. This includes
 - An object store on each machine.
 - A local scheduler on each machine.
 - Multiple Redis servers (on the head node).
-- One global scheduler (on the head node).

 To run some commands, start up Python on one of the nodes in the cluster, and
 do the following.
diff --git a/doc/source/using-ray-on-a-large-cluster.rst b/doc/source/using-ray-on-a-large-cluster.rst
index c3d6d8a8d238..b87c8c05f512 100644
--- a/doc/source/using-ray-on-a-large-cluster.rst
+++ b/doc/source/using-ray-on-a-large-cluster.rst
@@ -154,7 +154,6 @@ Now you have started all of the Ray processes on each node. These include:
 - An object store on each machine.
 - A local scheduler on each machine.
 - Multiple Redis servers (on the head node).
-- One global scheduler (on the head node).
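As a concrete sketch of the connect step the cluster docs below describe (the Redis address is a placeholder for the head node's ``<ip>:<port>``):

.. code-block:: python

    import ray

    # Join an existing cluster instead of starting a new one.
    ray.init(redis_address="192.168.0.1:6379")

    @ray.remote
    def f():
        return "hello from a worker"

    print(ray.get(f.remote()))  # runs on some node in the cluster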
To confirm that the Ray cluster setup is working, start up Python on one of the nodes in the cluster and enter the following commands to connect to the Ray diff --git a/java/checkstyle-suppressions.xml b/java/checkstyle-suppressions.xml index 619c24e1466f..0422332258df 100644 --- a/java/checkstyle-suppressions.xml +++ b/java/checkstyle-suppressions.xml @@ -10,5 +10,5 @@ - + diff --git a/java/prepare.sh b/java/prepare.sh index 807301a74edb..9554e500a8ed 100755 --- a/java/prepare.sh +++ b/java/prepare.sh @@ -42,15 +42,15 @@ fi # echo "ray_dir = $ray_dir" declare -a nativeBinaries=( - "./src/common/thirdparty/redis/src/redis-server" + "./src/ray/thirdparty/redis/src/redis-server" "./src/plasma/plasma_store_server" "./src/ray/raylet/raylet" "./src/ray/raylet/raylet_monitor" ) declare -a nativeLibraries=( - "./src/common/redis_module/libray_redis_module.so" - "./src/local_scheduler/liblocal_scheduler_library_java.*" + "./src/ray/gcs/redis_module/libray_redis_module.so" + "./src/ray/raylet/liblocal_scheduler_library_java.*" "./src/plasma/libplasma_java.*" "./src/ray/raylet/*lib.a" ) diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index e07b9e89e9e7..d4d90f24ece2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -165,12 +165,12 @@ public RayConfig(Config config) { // library path this.libraryPath = new ImmutableList.Builder().add( rayHome + "/build/src/plasma", - rayHome + "/build/src/local_scheduler" + rayHome + "/build/src/ray/raylet" ).addAll(customLibraryPath).build(); redisServerExecutablePath = rayHome + - "/build/src/common/thirdparty/redis/src/redis-server"; - redisModulePath = rayHome + "/build/src/common/redis_module/libray_redis_module.so"; + "/build/src/ray/thirdparty/redis/src/redis-server"; + redisModulePath = rayHome + "/build/src/ray/gcs/redis_module/libray_redis_module.so"; plasmaStoreExecutablePath = rayHome + "/build/src/plasma/plasma_store_server"; rayletExecutablePath = rayHome + "/build/src/ray/raylet/raylet"; diff --git a/java/runtime/src/main/java/org/ray/runtime/generated/TaskLanguage.java b/java/runtime/src/main/java/org/ray/runtime/generated/Language.java similarity index 51% rename from java/runtime/src/main/java/org/ray/runtime/generated/TaskLanguage.java rename to java/runtime/src/main/java/org/ray/runtime/generated/Language.java index e5e53614aa8a..34604374dd44 100644 --- a/java/runtime/src/main/java/org/ray/runtime/generated/TaskLanguage.java +++ b/java/runtime/src/main/java/org/ray/runtime/generated/Language.java @@ -2,13 +2,13 @@ package org.ray.runtime.generated; -public final class TaskLanguage { - private TaskLanguage() { } +public final class Language { + private Language() { } public static final int PYTHON = 0; - public static final int JAVA = 1; + public static final int CPP = 1; + public static final int JAVA = 2; - public static final String[] names = { "PYTHON", "JAVA", }; + public static final String[] names = { "PYTHON", "CPP", "JAVA", }; public static String name(int e) { return names[e]; } } - diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index b84fe22db0ac..28f0cd97ce19 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -13,9 
+13,9 @@ import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.generated.Arg; +import org.ray.runtime.generated.Language; import org.ray.runtime.generated.ResourcePair; import org.ray.runtime.generated.TaskInfo; -import org.ray.runtime.generated.TaskLanguage; import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.UniqueIdUtil; @@ -229,7 +229,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { actorIdOffset, actorHandleIdOffset, actorCounter, false, functionIdOffset, argsOffset, returnsOffset, requiredResourcesOffset, - requiredPlacementResourcesOffset, TaskLanguage.JAVA, + requiredPlacementResourcesOffset, Language.JAVA, functionDescriptorOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); @@ -256,8 +256,8 @@ public void destroy() { /// 1) pushd $Dir/java/runtime/target/classes /// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.RayletClientImpl /// 3) clang-format -i org_ray_runtime_raylet_RayletClientImpl.h - /// 4) cp org_ray_runtime_raylet_RayletClientImpl.h $Dir/src/local_scheduler/lib/java/ - /// 5) vim $Dir/src/local_scheduler/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc + /// 4) cp org_ray_runtime_raylet_RayletClientImpl.h $Dir/src/ray/raylet/lib/java/ + /// 5) vim $Dir/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc /// 6) popd private static native long nativeInit(String localSchedulerSocket, byte[] workerId, diff --git a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java index ac7e01124632..71e3d0dfff8e 100644 --- a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java @@ -23,7 +23,7 @@ public void testCreateRayConfig() { Assert.assertEquals(System.getProperty("user.dir"), rayConfig.rayHome); Assert.assertEquals(System.getProperty("user.dir") + - "/build/src/common/thirdparty/redis/src/redis-server", rayConfig.redisServerExecutablePath); + "/build/src/ray/thirdparty/redis/src/redis-server", rayConfig.redisServerExecutablePath); Assert.assertEquals("path/to/ray/driver/resource/path", rayConfig.driverResourcePath); diff --git a/java/tutorial/pom.xml b/java/tutorial/pom.xml index 198f6f0a3a51..48a03dc1ca8e 100644 --- a/java/tutorial/pom.xml +++ b/java/tutorial/pom.xml @@ -40,7 +40,7 @@ ${basedir}/../ray.config.ini -ea - -Djava.library.path=${basedir}/../../build/src/plasma:${basedir}/../../build/src/local_scheduler + -Djava.library.path=${basedir}/../../build/src/plasma:${basedir}/../../build/src/ray/raylet -noverify -DlogOutput=console diff --git a/python/ray/__init__.py b/python/ray/__init__.py index b97af4b587da..95255bba1288 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -46,7 +46,7 @@ e.args += (helpful_message, ) raise -from ray.local_scheduler import ObjectID, _config # noqa: E402 +from ray.raylet import ObjectID, _config # noqa: E402 from ray.profiling import profile # noqa: E402 from ray.worker import (error_info, init, connect, disconnect, get, put, wait, remote, get_gpu_ids, get_resource_ids, get_webui_url, diff --git a/python/ray/actor.py b/python/ray/actor.py index d1f034cc6057..86b87e7d564c 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -9,7 +9,7 @@ import ray.cloudpickle as pickle from ray.function_manager import FunctionActorManager -import ray.local_scheduler +import ray.raylet import 
ray.ray_constants as ray_constants import ray.signature as signature import ray.worker diff --git a/python/ray/common/redis_module/.gitkeep b/python/ray/common/redis_module/.gitkeep deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/ray/common/redis_module/runtest.py b/python/ray/common/redis_module/runtest.py deleted file mode 100644 index 7a7d25c6bedc..000000000000 --- a/python/ray/common/redis_module/runtest.py +++ /dev/null @@ -1,451 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import redis -import sys -import time -import unittest - -import ray.gcs_utils -import ray.services - - -def integerToAsciiHex(num, numbytes): - retstr = b"" - # Support 32 and 64 bit architecture. - assert (numbytes == 4 or numbytes == 8) - for i in range(numbytes): - curbyte = num & 0xff - if sys.version_info >= (3, 0): - retstr += bytes([curbyte]) - else: - retstr += chr(curbyte) - num = num >> 8 - - return retstr - - -def get_next_message(pubsub_client, timeout_seconds=10): - """Block until the next message is available on the pubsub channel.""" - start_time = time.time() - while True: - message = pubsub_client.get_message() - if message is not None: - return message - time.sleep(0.1) - if time.time() - start_time > timeout_seconds: - raise Exception("Timed out while waiting for next message.") - - -class TestGlobalStateStore(unittest.TestCase): - def setUp(self): - unused_primary_redis_addr, redis_shards = ray.services.start_redis( - "localhost", use_credis="RAY_USE_NEW_GCS" in os.environ) - self.redis = redis.StrictRedis( - host="localhost", port=redis_shards[0].split(":")[-1], db=0) - - def tearDown(self): - ray.services.cleanup() - - def testInvalidObjectTableAdd(self): - # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called - # with the wrong arguments. - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", - "one", "hash2", "manager_id1") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1, - "hash2", "manager_id1", - "extra argument") - # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an - # object ID that is already present with a different hash. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1"}) - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id2") - # Check that the second manager was added, even though the hash was - # mismatched. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Check that it is fine if we add the same object ID multiple times - # with the most recent hash. 
- self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash2", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, - "hash2", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - - def testObjectTableAddAndLookup(self): - # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not - # been added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(response, None) - # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Add a manager that already exists again and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Check that we properly handle NULL characters. In the past, NULL - # characters were handled improperly causing a "hash mismatch" error if - # two object IDs that agreed up to the NULL character were inserted - # with different hashes. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1, - "hash2", "manager_id1") - # Check that NULL characters in the hash are handled properly. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, - "\x00hash1", "manager_id1") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1, - "\x00hash2", "manager_id1") - - def testObjectTableAddAndRemove(self): - # Try removing a manager from an object ID that has not been added yet. - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not - # been added yet. - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(response, None) - # Add some managers and try again. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Remove a manager that doesn't exist, and make sure we still have the - # same set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id1", b"manager_id2"}) - # Remove a manager that does exist. Make sure it gets removed the first - # time and does nothing the second time. 
- self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id2"}) - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id1") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), {b"manager_id2"}) - # Remove the last manager, and make sure we have an empty set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id2") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), set()) - # Remove a manager from an empty set, and make sure we now have an - # empty set. - self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", - "manager_id3") - response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", - "object_id1") - self.assertEqual(set(response), set()) - - def testObjectTableSubscribeToNotifications(self): - # Define a helper method for checking the contents of object - # notifications. - def check_object_notification(notification_message, object_id, - object_size, manager_ids): - notification_object = (ray.gcs_utils.SubscribeToNotificationsReply. - GetRootAsSubscribeToNotificationsReply( - notification_message, 0)) - self.assertEqual(notification_object.ObjectId(), object_id) - self.assertEqual(notification_object.ObjectSize(), object_size) - self.assertEqual(notification_object.ManagerIdsLength(), - len(manager_ids)) - for i in range(len(manager_ids)): - self.assertEqual( - notification_object.ManagerIds(i), manager_ids[i]) - - data_size = 0xf1f0 - p = self.redis.pubsub() - # Subscribe to an object ID. - p.psubscribe("{}manager_id1".format( - ray.gcs_utils.OBJECT_CHANNEL_PREFIX)) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", - data_size, "hash1", "manager_id2") - # Receive the acknowledgement message. - self.assertEqual(get_next_message(p)["data"], 1) - # Request a notification and receive the data. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id1") - # Verify that the notification is correct. - check_object_notification( - get_next_message(p)["data"], b"object_id1", data_size, - [b"manager_id2"]) - - # Request a notification for an object that isn't there. Then add the - # object and receive the data. Only the first call to - # RAY.OBJECT_TABLE_ADD should trigger notifications. - self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id2", "object_id3") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", - data_size, "hash1", "manager_id1") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", - data_size, "hash1", "manager_id2") - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", - data_size, "hash1", "manager_id3") - # Verify that the notification is correct. - check_object_notification( - get_next_message(p)["data"], b"object_id3", data_size, - [b"manager_id1"]) - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", - data_size, "hash1", "manager_id3") - # Verify that the notification is correct. - check_object_notification( - get_next_message(p)["data"], b"object_id2", data_size, - [b"manager_id3"]) - # Request notifications for object_id3 again. 
- self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", - "manager_id1", "object_id3") - # Verify that the notification is correct. - check_object_notification( - get_next_message(p)["data"], b"object_id3", data_size, - [b"manager_id1", b"manager_id2", b"manager_id3"]) - - def testResultTableAddAndLookup(self): - def check_result_table_entry(message, task_id, is_put): - result_table_reply = ( - ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply( - message, 0)) - self.assertEqual(result_table_reply.TaskId(), task_id) - self.assertEqual(result_table_reply.IsPut(), is_put) - - # Try looking up something in the result table before anything is - # added. - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - self.assertIsNone(response) - # Adding the object to the object table should have no effect. - self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, - "hash1", "manager_id1") - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - self.assertIsNone(response) - # Add the result to the result table. The lookup now returns the task - # ID. - task_id = b"task_id1" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", - task_id, 0) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - check_result_table_entry(response, task_id, False) - # Doing it again should still work. - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id1") - check_result_table_entry(response, task_id, False) - # Try another result table lookup. This should succeed. - task_id = b"task_id2" - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", - task_id, 1) - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", - "object_id2") - check_result_table_entry(response, task_id, True) - - def testInvalidTaskTableAdd(self): - # Check that Redis returns an error when RAY.TASK_TABLE_ADD is called - # with the wrong arguments. - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello") - with self.assertRaises(redis.ResponseError): - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3, - "node_id") - with self.assertRaises(redis.ResponseError): - # Non-integer scheduling states should not be added. - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", - "invalid_state", "node_id", "task_spec") - with self.assertRaises(redis.ResponseError): - # Should not be able to update a non-existent task. 
- self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10, - "node_id", b"") - - def testTaskTableAddAndLookup(self): - TASK_STATUS_WAITING = 1 - TASK_STATUS_SCHEDULED = 2 - TASK_STATUS_QUEUED = 4 - - # make sure somebody will get a notification (checked in the redis - # module) - p = self.redis.pubsub() - p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX)) - - def check_task_reply(message, task_args, updated=False): - (task_status, local_scheduler_id, execution_dependencies_string, - spillback_count, task_spec) = task_args - task_reply_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply( - message, 0) - self.assertEqual(task_reply_object.State(), task_status) - self.assertEqual(task_reply_object.LocalSchedulerId(), - local_scheduler_id) - self.assertEqual(task_reply_object.SpillbackCount(), - spillback_count) - self.assertEqual(task_reply_object.TaskSpec(), task_spec) - self.assertEqual(task_reply_object.Updated(), updated) - - # Check that task table adds, updates, and lookups work correctly. - task_args = [TASK_STATUS_WAITING, b"node_id", b"", 0, b"task_spec"] - response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", - *task_args) - response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - check_task_reply(response, task_args) - - task_args[0] = TASK_STATUS_SCHEDULED - self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", - *task_args[:4]) - response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - check_task_reply(response, task_args) - - # If the current value, test value, and set value are all the same, the - # update happens, and the response is still the same task. - task_args = [task_args[0]] + task_args - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *task_args[:3]) - check_task_reply(response, task_args[1:], updated=True) - # Check that the task entry is still the same. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") - check_task_reply(get_response, task_args[1:]) - - # If the current value is the same as the test value, and the set value - # is different, the update happens, and the response is the entire - # task. - task_args[1] = TASK_STATUS_QUEUED - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *task_args[:3]) - check_task_reply(response, task_args[1:], updated=True) - # Check that the update happened. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") - check_task_reply(get_response, task_args[1:]) - - # If the current value is no longer the same as the test value, the - # response is the same task as before the test-and-set. - new_task_args = task_args[:] - new_task_args[1] = TASK_STATUS_WAITING - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *new_task_args[:3]) - check_task_reply(response, task_args[1:], updated=False) - # Check that the update did not happen. - get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") - self.assertEqual(get_response2, get_response) - - # If the test value is a bitmask that matches the current value, the - # update happens. 
- task_args = new_task_args - task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *task_args[:3]) - check_task_reply(response, task_args[1:], updated=True) - - # If the test value is a bitmask that does not match the current value, - # the update does not happen, and the response is the same task as - # before the test-and-set. - new_task_args = task_args[:] - new_task_args[0] = TASK_STATUS_SCHEDULED - old_response = response - response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", - "task_id", *new_task_args[:3]) - check_task_reply(response, task_args[1:], updated=False) - # Check that the update did not happen. - get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", - "task_id") - self.assertNotEqual(get_response, old_response) - check_task_reply(get_response, task_args[1:]) - - def check_task_subscription(self, p, scheduling_state, local_scheduler_id): - task_args = [ - b"task_id", scheduling_state, - local_scheduler_id.encode("ascii"), b"", 0, b"task_spec" - ] - self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args) - # Receive the data. - message = get_next_message(p)["data"] - # Check that the notification object is correct. - notification_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply( - message, 0) - self.assertEqual(notification_object.TaskId(), task_args[0]) - self.assertEqual(notification_object.State(), task_args[1]) - self.assertEqual(notification_object.LocalSchedulerId(), task_args[2]) - self.assertEqual(notification_object.ExecutionDependencies(), - task_args[3]) - self.assertEqual(notification_object.TaskSpec(), task_args[-1]) - - def testTaskTableSubscribe(self): - scheduling_state = 1 - local_scheduler_id = "local_scheduler_id" - # Subscribe to the task table. - p = self.redis.pubsub() - p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 1) - self.check_task_subscription(p, scheduling_state, local_scheduler_id) - # unsubscribe to make sure there is only one subscriber at a given time - p.punsubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 0) - - p.psubscribe("{prefix}*:{state}".format( - prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 1) - self.check_task_subscription(p, scheduling_state, local_scheduler_id) - p.punsubscribe("{prefix}*:{state}".format( - prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 0) - - p.psubscribe("{prefix}{local_scheduler_id}:*".format( - prefix=ray.gcs_utils.TASK_PREFIX, - local_scheduler_id=local_scheduler_id)) - # Receive acknowledgment. - self.assertEqual(get_next_message(p)["data"], 1) - self.check_task_subscription(p, scheduling_state, local_scheduler_id) - p.punsubscribe("{prefix}{local_scheduler_id}:*".format( - prefix=ray.gcs_utils.TASK_PREFIX, - local_scheduler_id=local_scheduler_id)) - # Receive acknowledgment. 
- self.assertEqual(get_next_message(p)["data"], 0) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/python/ray/common/test/test.py b/python/ray/common/test/test.py deleted file mode 100644 index cd36b697bbaa..000000000000 --- a/python/ray/common/test/test.py +++ /dev/null @@ -1,181 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import pickle -import sys -import unittest - -import ray.local_scheduler as local_scheduler -import ray.ray_constants as ray_constants - - -def random_object_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_function_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_driver_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_task_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -BASE_SIMPLE_OBJECTS = [ - 0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"", - 990 * u"h", - np.ones(3), - np.array([True, False]), None, True, False -] - -if sys.version_info < (3, 0): - BASE_SIMPLE_OBJECTS += [ - long(0), # noqa: E501,F821 - long(1), # noqa: E501,F821 - long(100000), # noqa: E501,F821 - long(1 << 100) # noqa: E501,F821 - ] - -LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS] -TUPLE_SIMPLE_OBJECTS = [(obj, ) for obj in BASE_SIMPLE_OBJECTS] -DICT_SIMPLE_OBJECTS = [{(): obj} for obj in BASE_SIMPLE_OBJECTS] - -SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS + LIST_SIMPLE_OBJECTS + - TUPLE_SIMPLE_OBJECTS + DICT_SIMPLE_OBJECTS) - -# Create some complex objects that cannot be serialized by value in tasks. - -lst = [] -lst.append(lst) - - -class Foo(object): - def __init__(self): - pass - - -BASE_COMPLEX_OBJECTS = [ - 15000 * "h", 15000 * u"h", lst, - Foo(), 100 * [100 * [10 * [1]]], - np.array([Foo()]) -] - -LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS] -TUPLE_COMPLEX_OBJECTS = [(obj, ) for obj in BASE_COMPLEX_OBJECTS] -DICT_COMPLEX_OBJECTS = [{(): obj} for obj in BASE_COMPLEX_OBJECTS] - -COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + LIST_COMPLEX_OBJECTS + - TUPLE_COMPLEX_OBJECTS + DICT_COMPLEX_OBJECTS) - - -class TestSerialization(unittest.TestCase): - def test_serialize_by_value(self): - - for val in SIMPLE_OBJECTS: - self.assertTrue(local_scheduler.check_simple_value(val)) - for val in COMPLEX_OBJECTS: - self.assertFalse(local_scheduler.check_simple_value(val)) - - -class TestObjectID(unittest.TestCase): - def test_create_object_id(self): - random_object_id() - - def test_cannot_pickle_object_ids(self): - object_ids = [random_object_id() for _ in range(256)] - - def f(): - return object_ids - - def g(val=object_ids): - return 1 - - def h(): - object_ids[0] - return 1 - - # Make sure that object IDs cannot be pickled (including functions that - # close over object IDs). 
- self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0])) - self.assertRaises(Exception, lambda: pickle.dumps(object_ids)) - self.assertRaises(Exception, lambda: pickle.dumps(f)) - self.assertRaises(Exception, lambda: pickle.dumps(g)) - self.assertRaises(Exception, lambda: pickle.dumps(h)) - - def test_equality_comparisons(self): - x1 = local_scheduler.ObjectID(ray_constants.ID_SIZE * b"a") - x2 = local_scheduler.ObjectID(ray_constants.ID_SIZE * b"a") - y1 = local_scheduler.ObjectID(ray_constants.ID_SIZE * b"b") - y2 = local_scheduler.ObjectID(ray_constants.ID_SIZE * b"b") - self.assertEqual(x1, x2) - self.assertEqual(y1, y2) - self.assertNotEqual(x1, y1) - - random_strings = [ - np.random.bytes(ray_constants.ID_SIZE) for _ in range(256) - ] - object_ids1 = [ - local_scheduler.ObjectID(random_strings[i]) for i in range(256) - ] - object_ids2 = [ - local_scheduler.ObjectID(random_strings[i]) for i in range(256) - ] - self.assertEqual(len(set(object_ids1)), 256) - self.assertEqual(len(set(object_ids1 + object_ids2)), 256) - self.assertEqual(set(object_ids1), set(object_ids2)) - - def test_hashability(self): - x = random_object_id() - y = random_object_id() - {x: y} - {x, y} - - -class TestTask(unittest.TestCase): - def check_task(self, task, function_id, num_return_vals, args): - self.assertEqual(function_id.id(), task.function_id().id()) - retrieved_args = task.arguments() - self.assertEqual(num_return_vals, len(task.returns())) - self.assertEqual(len(args), len(retrieved_args)) - for i in range(len(retrieved_args)): - if isinstance(retrieved_args[i], local_scheduler.ObjectID): - self.assertEqual(retrieved_args[i].id(), args[i].id()) - else: - self.assertEqual(retrieved_args[i], args[i]) - - def test_create_and_serialize_task(self): - # TODO(rkn): The function ID should be a FunctionID object, not an - # ObjectID. 
- driver_id = random_driver_id() - parent_id = random_task_id() - function_id = random_function_id() - object_ids = [random_object_id() for _ in range(256)] - args_list = [[], 1 * [1], 10 * [1], 100 * [1], 1000 * [1], 1 * ["a"], - 10 * ["a"], 100 * ["a"], 1000 * ["a"], [ - 1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2] - ], object_ids[:1], object_ids[:2], object_ids[:3], - object_ids[:4], object_ids[:5], object_ids[:10], - object_ids[:100], object_ids[:256], [1, object_ids[0]], [ - object_ids[0], "a" - ], [1, object_ids[0], "a"], [ - object_ids[0], 1, object_ids[1], "a" - ], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], - object_ids + 100 * ["a"] + object_ids] - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(driver_id, function_id, args, - num_return_vals, parent_id, 0) - self.check_task(task, function_id, num_return_vals, args) - data = local_scheduler.task_to_string(task) - task2 = local_scheduler.task_from_string(data) - self.check_task(task2, function_id, num_return_vals, args) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/python/ray/common/thirdparty/redis/src/.gitkeep b/python/ray/common/thirdparty/redis/src/.gitkeep deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/ray/common/__init__.py b/python/ray/core/src/ray/__init__.py similarity index 100% rename from python/ray/common/__init__.py rename to python/ray/core/src/ray/__init__.py diff --git a/python/ray/core/src/local_scheduler/__init__.py b/python/ray/core/src/ray/raylet/__init__.py similarity index 100% rename from python/ray/core/src/local_scheduler/__init__.py rename to python/ray/core/src/ray/raylet/__init__.py diff --git a/python/ray/experimental/sgd/sgd.py b/python/ray/experimental/sgd/sgd.py index c569c036f1b1..82c61613a9fc 100644 --- a/python/ray/experimental/sgd/sgd.py +++ b/python/ray/experimental/sgd/sgd.py @@ -108,8 +108,6 @@ def __init__(self, if plasma_op: store_socket = ( ray.worker.global_worker.plasma_client.store_socket_name) - manager_socket = ( - ray.worker.global_worker.plasma_client.manager_socket_name) if not plasma.tf_plasma_op: plasma.build_plasma_tensorflow_op() @@ -130,7 +128,7 @@ def __init__(self, [grad], self.plasma_in_grads_oids[j], plasma_store_socket_name=store_socket, - plasma_manager_socket_name=manager_socket) + plasma_manager_socket_name="") self.plasma_in_grads.append(plasma_grad) # For applying grads <- plasma @@ -149,7 +147,7 @@ def __init__(self, self.plasma_out_grads_oids[j], dtype=tf.float32, plasma_store_socket_name=store_socket, - plasma_manager_socket_name=manager_socket) + plasma_manager_socket_name="") grad_ph = tf.reshape(grad_ph, self.packed_grads_and_vars[0][j][0].shape) logger.debug("Packed tensor {}".format(grad_ph)) diff --git a/python/ray/experimental/sgd/util.py b/python/ray/experimental/sgd/util.py index ca72bb5e9ef4..6c4f89719402 100644 --- a/python/ray/experimental/sgd/util.py +++ b/python/ray/experimental/sgd/util.py @@ -14,15 +14,10 @@ def fetch(oids): - if ray.global_state.use_raylet: - local_sched_client = ray.worker.global_worker.local_scheduler_client - for o in oids: - ray_obj_id = ray.ObjectID(o) - local_sched_client.reconstruct_objects([ray_obj_id], True) - else: - for o in oids: - plasma_id = ray.pyarrow.plasma.ObjectID(o) - ray.worker.global_worker.plasma_client.fetch([plasma_id]) + local_sched_client = ray.worker.global_worker.local_scheduler_client + for o in oids: + ray_obj_id = ray.ObjectID(o) + local_sched_client.reconstruct_objects([ray_obj_id], 
True) def run_timeline(sess, ops, feed_dict=None, write_timeline=False, name=""): diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 9f1215a7e988..cf3d182838e3 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -6,8 +6,6 @@ from collections import defaultdict import heapq import json -import numbers -import os import redis import sys import time @@ -18,25 +16,6 @@ from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) -# This mapping from integer to task state string must be kept up-to-date with -# the scheduling_state enum in task.h. -TASK_STATUS_WAITING = 1 -TASK_STATUS_SCHEDULED = 2 -TASK_STATUS_QUEUED = 4 -TASK_STATUS_RUNNING = 8 -TASK_STATUS_DONE = 16 -TASK_STATUS_LOST = 32 -TASK_STATUS_RECONSTRUCTING = 64 -TASK_STATUS_MAPPING = { - TASK_STATUS_WAITING: "WAITING", - TASK_STATUS_SCHEDULED: "SCHEDULED", - TASK_STATUS_QUEUED: "QUEUED", - TASK_STATUS_RUNNING: "RUNNING", - TASK_STATUS_DONE: "DONE", - TASK_STATUS_LOST: "LOST", - TASK_STATUS_RECONSTRUCTING: "RECONSTRUCTING", -} - class GlobalState(object): """A class used to interface with the Ray control state. @@ -47,7 +26,6 @@ class GlobalState(object): Attributes: redis_client: The Redis client used to query the primary redis server. redis_clients: Redis clients for each of the Redis shards. - use_raylet: True if we are using the raylet code path. """ def __init__(self): @@ -57,8 +35,6 @@ def __init__(self): self.redis_client = None # Clients for the redis shards, storing the object table & task table. self.redis_clients = None - # True if we are using the raylet code path and false otherwise. - self.use_raylet = None def _check_connected(self): """Check that the object has been initialized before it is used. @@ -130,18 +106,6 @@ def _initialize_global_state(self, "ip_address_ports = {}".format( num_redis_shards, ip_address_ports)) - use_raylet = self.redis_client.get("UseRaylet") - if use_raylet is not None: - self.use_raylet = bool(int(use_raylet)) - elif os.environ.get("RAY_USE_XRAY") == "0": - # This environment variable is used in our testing setup. - print("Detected environment variable 'RAY_USE_XRAY' with value " - "{}. This turns OFF xray.".format( - os.environ.get("RAY_USE_XRAY"))) - self.use_raylet = False - else: - self.use_raylet = True - # Get the rest of the information. self.redis_clients = [] for ip_address_port in ip_address_ports: @@ -195,51 +159,23 @@ def _object_table(self, object_id): object_id = ray.ObjectID(hex_to_binary(object_id)) # Return information about a single object ID. - if not self.use_raylet: - # Use the non-raylet code path. 
- object_locations = self._execute_command( - object_id, "RAY.OBJECT_TABLE_LOOKUP", object_id.id()) - if object_locations is not None: - manager_ids = [ - binary_to_hex(manager_id) - for manager_id in object_locations - ] - else: - manager_ids = None - - result_table_response = self._execute_command( - object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id()) - result_table_message = ( - ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply( - result_table_response, 0)) - - result = { - "ManagerIDs": manager_ids, - "TaskID": binary_to_hex(result_table_message.TaskId()), - "IsPut": bool(result_table_message.IsPut()), - "DataSize": result_table_message.DataSize(), - "Hash": binary_to_hex(result_table_message.Hash()) - } + message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", + ray.gcs_utils.TablePrefix.OBJECT, "", + object_id.id()) + result = [] + gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + message, 0) - else: - # Use the raylet code path. - message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.OBJECT, - "", object_id.id()) - result = [] - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) - - for i in range(gcs_entry.EntriesLength()): - entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( - gcs_entry.Entries(i), 0) - object_info = { - "DataSize": entry.ObjectSize(), - "Manager": entry.Manager(), - "IsEviction": entry.IsEviction(), - "NumEvictions": entry.NumEvictions() - } - result.append(object_info) + for i in range(gcs_entry.EntriesLength()): + entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( + gcs_entry.Entries(i), 0) + object_info = { + "DataSize": entry.ObjectSize(), + "Manager": entry.Manager(), + "IsEviction": entry.IsEviction(), + "NumEvictions": entry.NumEvictions() + } + result.append(object_info) return result @@ -259,25 +195,12 @@ def object_table(self, object_id=None): return self._object_table(object_id) else: # Return the entire object table. - if not self.use_raylet: - object_info_keys = self._keys( - ray.gcs_utils.OBJECT_INFO_PREFIX + "*") - object_location_keys = self._keys( - ray.gcs_utils.OBJECT_LOCATION_PREFIX + "*") - object_ids_binary = set([ - key[len(ray.gcs_utils.OBJECT_INFO_PREFIX):] - for key in object_info_keys - ] + [ - key[len(ray.gcs_utils.OBJECT_LOCATION_PREFIX):] - for key in object_location_keys - ]) - else: - object_keys = self._keys( - ray.gcs_utils.TablePrefix_OBJECT_string + "*") - object_ids_binary = { - key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] - for key in object_keys - } + object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string + + "*") + object_ids_binary = { + key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] + for key in object_keys + } results = {} for object_id_binary in object_ids_binary: @@ -294,21 +217,21 @@ def _task_table(self, task_id): Returns: A dictionary with information about the task ID in question. - TASK_STATUS_MAPPING should be used to parse the "State" field - into a human-readable string. """ - if not self.use_raylet: - # Use the non-raylet code path. 
- task_table_response = self._execute_command( - task_id, "RAY.TASK_TABLE_GET", task_id.id()) - if task_table_response is None: - raise Exception("There is no entry for task ID {} in the task " - "table.".format(binary_to_hex(task_id.id()))) - task_table_message = ray.gcs_utils.TaskReply.GetRootAsTaskReply( - task_table_response, 0) - task_spec = task_table_message.TaskSpec() - task_spec = ray.local_scheduler.task_from_string(task_spec) + message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", + ray.gcs_utils.TablePrefix.RAYLET_TASK, + "", task_id.id()) + gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + message, 0) + info = [] + for i in range(gcs_entries.EntriesLength()): + task_table_message = ray.gcs_utils.Task.GetRootAsTask( + gcs_entries.Entries(i), 0) + + execution_spec = task_table_message.TaskExecutionSpec() + task_spec = task_table_message.TaskSpecification() + task_spec = ray.raylet.task_from_string(task_spec) task_spec_info = { "DriverID": binary_to_hex(task_spec.driver_id().id()), "TaskID": binary_to_hex(task_spec.task_id().id()), @@ -326,80 +249,19 @@ def _task_table(self, task_id): "RequiredResources": task_spec.required_resources() } - execution_dependencies_message = ( - ray.gcs_utils.TaskExecutionDependencies. - GetRootAsTaskExecutionDependencies( - task_table_message.ExecutionDependencies(), 0)) - execution_dependencies = [ - ray.ObjectID( - execution_dependencies_message.ExecutionDependencies(i)) - for i in range(execution_dependencies_message. - ExecutionDependenciesLength()) - ] - - # TODO(rkn): The return fields ExecutionDependenciesString and - # ExecutionDependencies are redundant, so we should remove - # ExecutionDependencies. However, it is currently used in - # monitor.py. - - return { - "State": task_table_message.State(), - "LocalSchedulerID": binary_to_hex( - task_table_message.LocalSchedulerId()), - "ExecutionDependenciesString": task_table_message. - ExecutionDependencies(), - "ExecutionDependencies": execution_dependencies, - "SpillbackCount": task_table_message.SpillbackCount(), + info.append({ + "ExecutionSpec": { + "Dependencies": [ + execution_spec.Dependencies(i) + for i in range(execution_spec.DependenciesLength()) + ], + "LastTimestamp": execution_spec.LastTimestamp(), + "NumForwards": execution_spec.NumForwards() + }, "TaskSpec": task_spec_info - } - - else: - # Use the raylet code path. 
- message = self._execute_command( - task_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.RAYLET_TASK, "", task_id.id()) - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) - - info = [] - for i in range(gcs_entries.EntriesLength()): - task_table_message = ray.gcs_utils.Task.GetRootAsTask( - gcs_entries.Entries(i), 0) - - execution_spec = task_table_message.TaskExecutionSpec() - task_spec = task_table_message.TaskSpecification() - task_spec = ray.local_scheduler.task_from_string(task_spec) - task_spec_info = { - "DriverID": binary_to_hex(task_spec.driver_id().id()), - "TaskID": binary_to_hex(task_spec.task_id().id()), - "ParentTaskID": binary_to_hex( - task_spec.parent_task_id().id()), - "ParentCounter": task_spec.parent_counter(), - "ActorID": binary_to_hex(task_spec.actor_id().id()), - "ActorCreationID": binary_to_hex( - task_spec.actor_creation_id().id()), - "ActorCreationDummyObjectID": binary_to_hex( - task_spec.actor_creation_dummy_object_id().id()), - "ActorCounter": task_spec.actor_counter(), - "FunctionID": binary_to_hex(task_spec.function_id().id()), - "Args": task_spec.arguments(), - "ReturnObjectIDs": task_spec.returns(), - "RequiredResources": task_spec.required_resources() - } - - info.append({ - "ExecutionSpec": { - "Dependencies": [ - execution_spec.Dependencies(i) - for i in range(execution_spec.DependenciesLength()) - ], - "LastTimestamp": execution_spec.LastTimestamp(), - "NumForwards": execution_spec.NumForwards() - }, - "TaskSpec": task_spec_info - }) + }) - return info + return info def task_table(self, task_id=None): """Fetch and parse the task table information for one or more task IDs. @@ -416,19 +278,12 @@ def task_table(self, task_id=None): task_id = ray.ObjectID(hex_to_binary(task_id)) return self._task_table(task_id) else: - if not self.use_raylet: - task_table_keys = self._keys(ray.gcs_utils.TASK_PREFIX + "*") - task_ids_binary = [ - key[len(ray.gcs_utils.TASK_PREFIX):] - for key in task_table_keys - ] - else: - task_table_keys = self._keys( - ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") - task_ids_binary = [ - key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] - for key in task_table_keys - ] + task_table_keys = self._keys( + ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") + task_ids_binary = [ + key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] + for key in task_table_keys + ] results = {} for task_id_binary in task_ids_binary: @@ -464,95 +319,54 @@ def client_table(self): Information about the Ray clients in the cluster. 
""" self._check_connected() - if not self.use_raylet: - db_client_keys = self.redis_client.keys( - ray.gcs_utils.DB_CLIENT_PREFIX + "*") - node_info = {} - for key in db_client_keys: - client_info = self.redis_client.hgetall(key) - node_ip_address = decode(client_info[b"node_ip_address"]) - if node_ip_address not in node_info: - node_info[node_ip_address] = [] - client_info_parsed = {} - assert b"client_type" in client_info - assert b"deleted" in client_info - assert b"ray_client_id" in client_info - for field, value in client_info.items(): - if field == b"node_ip_address": - pass - elif field == b"client_type": - client_info_parsed["ClientType"] = decode(value) - elif field == b"deleted": - client_info_parsed["Deleted"] = bool( - int(decode(value))) - elif field == b"ray_client_id": - client_info_parsed["DBClientID"] = binary_to_hex(value) - elif field == b"manager_address": - client_info_parsed["AuxAddress"] = decode(value) - elif field == b"local_scheduler_socket_name": - client_info_parsed["LocalSchedulerSocketName"] = ( - decode(value)) - elif client_info[b"client_type"] == b"local_scheduler": - # The remaining fields are resource types. - client_info_parsed[decode(field)] = float( - decode(value)) - else: - client_info_parsed[decode(field)] = decode(value) - - node_info[node_ip_address].append(client_info_parsed) - - return node_info - else: - # This is the raylet code path. - NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" - message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "", - NIL_CLIENT_ID) - - # Handle the case where no clients are returned. This should only - # occur potentially immediately after the cluster is started. - if message is None: - return [] - - node_info = {} - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) - - # Since GCS entries are append-only, we override so that - # only the latest entries are kept. - for i in range(gcs_entry.EntriesLength()): - client = ( - ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - gcs_entry.Entries(i), 0)) - - resources = { - decode(client.ResourcesTotalLabel(i)): - client.ResourcesTotalCapacity(i) - for i in range(client.ResourcesTotalLabelLength()) - } - client_id = ray.utils.binary_to_hex(client.ClientId()) - - # If this client is being removed, then it must - # have previously been inserted, and - # it cannot have previously been removed. - if not client.IsInsertion(): - assert client_id in node_info, "Client removed not found!" - assert node_info[client_id]["IsInsertion"], ( - "Unexpected duplicate removal of client.") - - node_info[client_id] = { - "ClientID": client_id, - "IsInsertion": client.IsInsertion(), - "NodeManagerAddress": decode(client.NodeManagerAddress()), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName()), - "RayletSocketName": decode(client.RayletSocketName()), - "Resources": resources - } - return list(node_info.values()) + NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" + message = self.redis_client.execute_command( + "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "", + NIL_CLIENT_ID) + + # Handle the case where no clients are returned. This should only + # occur potentially immediately after the cluster is started. 
+        if message is None:
+            return []
+
+        node_info = {}
+        gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
+            message, 0)
+
+        # Since GCS entries are append-only, we overwrite so that
+        # only the latest entries are kept.
+        for i in range(gcs_entry.EntriesLength()):
+            client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
+                gcs_entry.Entries(i), 0))
+
+            resources = {
+                decode(client.ResourcesTotalLabel(i)):
+                client.ResourcesTotalCapacity(i)
+                for i in range(client.ResourcesTotalLabelLength())
+            }
+            client_id = ray.utils.binary_to_hex(client.ClientId())
+
+            # If this client is being removed, then it must
+            # have previously been inserted, and
+            # it cannot have previously been removed.
+            if not client.IsInsertion():
+                assert client_id in node_info, "Removed client not found!"
+                assert node_info[client_id]["IsInsertion"], (
+                    "Unexpected duplicate removal of client.")
+
+            node_info[client_id] = {
+                "ClientID": client_id,
+                "IsInsertion": client.IsInsertion(),
+                "NodeManagerAddress": decode(client.NodeManagerAddress()),
+                "NodeManagerPort": client.NodeManagerPort(),
+                "ObjectManagerPort": client.ObjectManagerPort(),
+                "ObjectStoreSocketName": decode(
+                    client.ObjectStoreSocketName()),
+                "RayletSocketName": decode(client.RayletSocketName()),
+                "Resources": resources
+            }
+        return list(node_info.values())

     def log_files(self):
         """Fetch and return a dictionary of log file names to outputs.

@@ -755,10 +569,6 @@ def _profile_table(self, component_id):
         return profile_events

     def profile_table(self):
-        if not self.use_raylet:
-            raise Exception("This method is only supported in the raylet "
-                            "code path.")
-
         profile_table_keys = self._keys(
             ray.gcs_utils.TablePrefix_PROFILE_string + "*")
         component_identifiers_binary = [
@@ -1207,23 +1017,6 @@ def _add_missing_timestamps(self, info):
             info[key] = cur
             latest_timestamp = cur

-    def local_schedulers(self):
-        """Get a list of live local schedulers.
-
-        Returns:
-            A list of the live local schedulers.
-        """
-        if self.use_raylet:
-            raise Exception("The local_schedulers() method is deprecated.")
-        clients = self.client_table()
-        local_schedulers = []
-        for ip_address, client_list in clients.items():
-            for client in client_list:
-                if (client["ClientType"] == "local_scheduler"
-                        and not client["Deleted"]):
-                    local_schedulers.append(client)
-        return local_schedulers
-
     def workers(self):
         """Get a dictionary mapping worker ID to worker information."""
         worker_keys = self.redis_client.keys("Worker*")
@@ -1237,8 +1030,6 @@ def workers(self):
                 "local_scheduler_socket": (decode(
                     worker_info[b"local_scheduler_socket"])),
                 "node_ip_address": decode(worker_info[b"node_ip_address"]),
-                "plasma_manager_socket": decode(
-                    worker_info[b"plasma_manager_socket"]),
                 "plasma_store_socket": decode(
                     worker_info[b"plasma_store_socket"])
             }
@@ -1298,24 +1089,12 @@ def cluster_resources(self):
             resource in the cluster.
         """
         resources = defaultdict(int)
-        if not self.use_raylet:
-            local_schedulers = self.local_schedulers()
-
-            for local_scheduler in local_schedulers:
-                for key, value in local_scheduler.items():
-                    if key not in [
-                            "ClientType", "Deleted", "DBClientID",
-                            "AuxAddress", "LocalSchedulerSocketName"
-                    ]:
-                        resources[key] += value
-
-        else:
-            clients = self.client_table()
-            for client in clients:
-                # Only count resources from live clients.
+        clients = self.client_table()
+        for client in clients:
+            # Only count resources from live clients.
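+            # A client whose most recent log entry is a removal has
+            # IsInsertion == False, so its resources are skipped here.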
+ if client["IsInsertion"]: + for key, value in client["Resources"].items(): + resources[key] += value return dict(resources) @@ -1340,93 +1119,48 @@ def available_resources(self): """ available_resources_by_id = {} - if not self.use_raylet: - subscribe_client = self.redis_client.pubsub() - subscribe_client.subscribe( - ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL) + subscribe_clients = [ + redis_client.pubsub(ignore_subscribe_messages=True) + for redis_client in self.redis_clients + ] + for subscribe_client in subscribe_clients: + subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) - local_scheduler_ids = { - local_scheduler["DBClientID"] - for local_scheduler in self.local_schedulers() - } + client_ids = self._live_client_ids() - while set(available_resources_by_id.keys()) != local_scheduler_ids: + while set(available_resources_by_id.keys()) != client_ids: + for subscribe_client in subscribe_clients: + # Parse client message raw_message = subscribe_client.get_message() - if raw_message is None: + if (raw_message is None or raw_message["channel"] != + ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - # Ignore subscribtion success message from Redis - # This is a long in python 2 and an int in python 3 - if isinstance(data, numbers.Number): - continue - message = (ray.gcs_utils.LocalSchedulerInfoMessage. - GetRootAsLocalSchedulerInfoMessage(data, 0)) - num_resources = message.DynamicResourcesLength() + gcs_entries = ( + ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + data, 0)) + heartbeat_data = gcs_entries.Entries(0) + message = (ray.gcs_utils.HeartbeatTableData. + GetRootAsHeartbeatTableData(heartbeat_data, 0)) + # Calculate available resources for this client + num_resources = message.ResourcesAvailableLabelLength() dynamic_resources = {} for i in range(num_resources): - dyn = message.DynamicResources(i) - resource_id = decode(dyn.Key()) - dynamic_resources[resource_id] = dyn.Value() + resource_id = decode(message.ResourcesAvailableLabel(i)) + dynamic_resources[resource_id] = ( + message.ResourcesAvailableCapacity(i)) - # Update available resources for this local scheduler - client_id = binary_to_hex(message.DbClientId()) + # Update available resources for this client + client_id = ray.utils.binary_to_hex(message.ClientId()) available_resources_by_id[client_id] = dynamic_resources - # Update local schedulers in cluster - local_scheduler_ids = { - local_scheduler["DBClientID"] - for local_scheduler in self.local_schedulers() - } - - # Remove disconnected local schedulers - for local_scheduler_id in available_resources_by_id.keys(): - if local_scheduler_id not in local_scheduler_ids: - del available_resources_by_id[local_scheduler_id] - else: - subscribe_clients = [ - redis_client.pubsub(ignore_subscribe_messages=True) - for redis_client in self.redis_clients - ] - for subscribe_client in subscribe_clients: - subscribe_client.subscribe( - ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) - + # Update clients in cluster client_ids = self._live_client_ids() - while set(available_resources_by_id.keys()) != client_ids: - for subscribe_client in subscribe_clients: - # Parse client message - raw_message = subscribe_client.get_message() - if (raw_message is None or raw_message["channel"] != - ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): - continue - data = raw_message["data"] - gcs_entries = ( - ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0)) - heartbeat_data = gcs_entries.Entries(0) - message = (ray.gcs_utils.HeartbeatTableData. 
-                               GetRootAsHeartbeatTableData(heartbeat_data, 0))
-                    # Calculate available resources for this client
-                    num_resources = message.ResourcesAvailableLabelLength()
-                    dynamic_resources = {}
-                    for i in range(num_resources):
-                        resource_id = decode(
-                            message.ResourcesAvailableLabel(i))
-                        dynamic_resources[resource_id] = (
-                            message.ResourcesAvailableCapacity(i))
-
-                    # Update available resources for this client
-                    client_id = ray.utils.binary_to_hex(message.ClientId())
-                    available_resources_by_id[client_id] = dynamic_resources
-
-                # Update clients in cluster
-                client_ids = self._live_client_ids()
-
-                # Remove disconnected clients
-                for client_id in available_resources_by_id.keys():
-                    if client_id not in client_ids:
-                        del available_resources_by_id[client_id]
+            # Remove disconnected clients. Iterate over a copy of the keys
+            # since entries may be deleted from the dictionary inside the
+            # loop.
+            for client_id in list(available_resources_by_id.keys()):
+                if client_id not in client_ids:
+                    del available_resources_by_id[client_id]

         # Calculate total available resources
         total_available_resources = defaultdict(int)
@@ -1479,10 +1213,6 @@ def error_messages(self, job_id=None):
             A dictionary mapping job ID to a list of the error messages
             for that job.
         """
-        if not self.use_raylet:
-            raise Exception("The error_messages method is only supported in "
-                            "the raylet code path.")
-
         if job_id is not None:
             return self._error_messages(job_id)

diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py
index 2616e064d850..bbdbe04cf7fd 100644
--- a/python/ray/gcs_utils.py
+++ b/python/ray/gcs_utils.py
@@ -4,19 +4,6 @@

 import flatbuffers

-from ray.core.generated.ResultTableReply import ResultTableReply
-from ray.core.generated.SubscribeToNotificationsReply \
-    import SubscribeToNotificationsReply
-from ray.core.generated.TaskExecutionDependencies import \
-    TaskExecutionDependencies
-from ray.core.generated.TaskReply import TaskReply
-from ray.core.generated.DriverTableMessage import DriverTableMessage
-from ray.core.generated.LocalSchedulerInfoMessage import \
-    LocalSchedulerInfoMessage
-from ray.core.generated.SubscribeToDBClientTableReply import \
-    SubscribeToDBClientTableReply
-from ray.core.generated.TaskInfo import TaskInfo
-
 import ray.core.generated.ErrorTableData
 from ray.core.generated.GcsTableEntry import GcsTableEntry
@@ -32,29 +19,13 @@
 from ray.core.generated.TablePubsub import TablePubsub

 __all__ = [
-    "SubscribeToNotificationsReply", "ResultTableReply",
-    "TaskExecutionDependencies", "TaskReply", "DriverTableMessage",
-    "LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo",
     "GcsTableEntry", "ClientTableData", "ErrorTableData",
     "HeartbeatTableData", "DriverTableData", "ProfileTableData",
     "ObjectTableData", "Task", "TablePrefix", "TablePubsub",
     "construct_error_message"
 ]

-# These prefixes must be kept up-to-date with the definitions in
-# ray_redis_module.cc
-DB_CLIENT_PREFIX = "CL:" -TASK_PREFIX = "TT:" -OBJECT_CHANNEL_PREFIX = "OC:" -OBJECT_INFO_PREFIX = "OI:" -OBJECT_LOCATION_PREFIX = "OL:" FUNCTION_PREFIX = "RemoteFunction:" -# These prefixes must be kept up-to-date with the definitions in -# common/state/redis.cc -LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers" -PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers" -DRIVER_DEATH_CHANNEL = b"driver_deaths" - # xray heartbeats XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii") diff --git a/python/ray/global_scheduler/__init__.py b/python/ray/global_scheduler/__init__.py deleted file mode 100644 index 25e4d2cf6490..000000000000 --- a/python/ray/global_scheduler/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from .global_scheduler_services import start_global_scheduler - -__all__ = ["start_global_scheduler"] diff --git a/python/ray/global_scheduler/build/.gitkeep b/python/ray/global_scheduler/build/.gitkeep deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/ray/global_scheduler/global_scheduler_services.py b/python/ray/global_scheduler/global_scheduler_services.py deleted file mode 100644 index 7e3d019ffa98..000000000000 --- a/python/ray/global_scheduler/global_scheduler_services.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import subprocess -import time - - -def start_global_scheduler(redis_address, - node_ip_address, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None): - """Start a global scheduler process. - - Args: - redis_address (str): The address of the Redis instance. - node_ip_address: The IP address of the node that this scheduler will - run on. - use_valgrind (bool): True if the global scheduler should be started - inside of valgrind. If this is True, use_profiler must be False. - use_profiler (bool): True if the global scheduler should be started - inside a profiler. If this is True, use_valgrind must be False. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - - Return: - The process ID of the global scheduler process. 
- """ - if use_valgrind and use_profiler: - raise Exception("Cannot use valgrind and profiler at the same time.") - global_scheduler_executable = os.path.join( - os.path.abspath(os.path.dirname(__file__)), - "../core/src/global_scheduler/global_scheduler") - command = [ - global_scheduler_executable, "-r", redis_address, "-h", node_ip_address - ] - if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - return pid diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py deleted file mode 100644 index 0e262e705c4d..000000000000 --- a/python/ray/global_scheduler/test/test.py +++ /dev/null @@ -1,332 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import os -import random -import signal -import sys -import time -import unittest - -# The ray import must come before the pyarrow import because ray modifies the -# python path so that the right version of pyarrow is found. -import ray.global_scheduler as global_scheduler -import ray.local_scheduler as local_scheduler -import ray.plasma as plasma -from ray.plasma.utils import create_object -from ray import services -from ray.experimental import state -import ray.ray_constants as ray_constants -import pyarrow as pa - -USE_VALGRIND = False -PLASMA_STORE_MEMORY = 1000000000 -NUM_CLUSTER_NODES = 2 - -NIL_WORKER_ID = ray_constants.ID_SIZE * b"\xff" -NIL_OBJECT_ID = ray_constants.ID_SIZE * b"\xff" -NIL_ACTOR_ID = ray_constants.ID_SIZE * b"\xff" - - -def random_driver_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_task_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_function_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_object_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def new_port(): - return random.randint(10000, 65535) - - -class TestGlobalScheduler(unittest.TestCase): - def setUp(self): - # Start one Redis server and N pairs of (plasma, local_scheduler) - self.node_ip_address = "127.0.0.1" - redis_address, redis_shards = services.start_redis( - self.node_ip_address, use_raylet=False) - redis_port = services.get_port(redis_address) - time.sleep(0.1) - # Create a client for the global state store. - self.state = state.GlobalState() - self.state._initialize_global_state(self.node_ip_address, redis_port) - - # Start one global scheduler. - self.p1 = global_scheduler.start_global_scheduler( - redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND) - self.plasma_store_pids = [] - self.plasma_manager_pids = [] - self.local_scheduler_pids = [] - self.plasma_clients = [] - self.local_scheduler_clients = [] - - for i in range(NUM_CLUSTER_NODES): - # Start the Plasma store. Plasma store name is randomly generated. - plasma_store_name, p2 = plasma.start_plasma_store() - self.plasma_store_pids.append(p2) - # Start the Plasma manager. 
- # Assumption: Plasma manager name and port are randomly generated - # by the plasma module. - manager_info = plasma.start_plasma_manager(plasma_store_name, - redis_address) - plasma_manager_name, p3, plasma_manager_port = manager_info - self.plasma_manager_pids.append(p3) - plasma_address = "{}:{}".format(self.node_ip_address, - plasma_manager_port) - plasma_client = pa.plasma.connect(plasma_store_name, - plasma_manager_name, 64) - self.plasma_clients.append(plasma_client) - # Start the local scheduler. - local_scheduler_name, p4 = local_scheduler.start_local_scheduler( - plasma_store_name, - plasma_manager_name=plasma_manager_name, - plasma_address=plasma_address, - redis_address=redis_address, - static_resources={"CPU": 10}) - # Connect to the scheduler. - local_scheduler_client = local_scheduler.LocalSchedulerClient( - local_scheduler_name, NIL_WORKER_ID, False, random_task_id(), - False) - self.local_scheduler_clients.append(local_scheduler_client) - self.local_scheduler_pids.append(p4) - - def tearDown(self): - # Check that the processes are still alive. - self.assertEqual(self.p1.poll(), None) - for p2 in self.plasma_store_pids: - self.assertEqual(p2.poll(), None) - for p3 in self.plasma_manager_pids: - self.assertEqual(p3.poll(), None) - for p4 in self.local_scheduler_pids: - self.assertEqual(p4.poll(), None) - - redis_processes = services.all_processes[ - services.PROCESS_TYPE_REDIS_SERVER] - for redis_process in redis_processes: - self.assertEqual(redis_process.poll(), None) - - # Kill the global scheduler. - if USE_VALGRIND: - self.p1.send_signal(signal.SIGTERM) - self.p1.wait() - if self.p1.returncode != 0: - os._exit(-1) - else: - self.p1.kill() - # Kill local schedulers, plasma managers, and plasma stores. - for p2 in self.local_scheduler_pids: - p2.kill() - for p3 in self.plasma_manager_pids: - p3.kill() - for p4 in self.plasma_store_pids: - p4.kill() - # Kill Redis. In the event that we are using valgrind, this needs to - # happen after we kill the global scheduler. - while redis_processes: - redis_process = redis_processes.pop() - redis_process.kill() - - def get_plasma_manager_id(self): - """Get the db_client_id with client_type equal to plasma_manager. - - Iterates over all the client table keys, gets the db_client_id for the - client with client_type matching plasma_manager. Strips the client - table prefix. TODO(atumanov): write a separate function to get all - plasma manager client IDs. - - Returns: - The db_client_id if one is found and otherwise None. - """ - db_client_id = None - - client_list = self.state.client_table()[self.node_ip_address] - for client in client_list: - if client["ClientType"] == "plasma_manager": - db_client_id = client["DBClientID"] - break - - return db_client_id - - def test_task_default_resources(self): - task1 = local_scheduler.Task( - random_driver_id(), random_function_id(), [random_object_id()], 0, - random_task_id(), 0) - self.assertEqual(task1.required_resources(), {"CPU": 1}) - task2 = local_scheduler.Task( - random_driver_id(), random_function_id(), [random_object_id()], 0, - random_task_id(), 0, local_scheduler.ObjectID(NIL_ACTOR_ID), - local_scheduler.ObjectID(NIL_OBJECT_ID), - local_scheduler.ObjectID(NIL_ACTOR_ID), - local_scheduler.ObjectID(NIL_ACTOR_ID), 0, 0, [], { - "CPU": 1, - "GPU": 2 - }) - self.assertEqual(task2.required_resources(), {"CPU": 1, "GPU": 2}) - - def test_redis_only_single_task(self): - # Tests global scheduler functionality by interacting with Redis and - # checking task state transitions in Redis only. 
TODO(atumanov): - # implement. - - # Check precondition for this test: - # There should be 2n+1 db clients: the global scheduler + one local - # scheduler and one plasma per node. - self.assertEqual( - len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) - db_client_id = self.get_plasma_manager_id() - assert (db_client_id is not None) - - @unittest.skipIf( - os.environ.get("RAY_USE_NEW_GCS", False), - "New GCS API doesn't have a Python API yet.") - def test_integration_single_task(self): - # There should be three db clients, the global scheduler, the local - # scheduler, and the plasma manager. - self.assertEqual( - len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) - - num_return_vals = [0, 1, 2, 3, 5, 10] - # Insert the object into Redis. - data_size = 0xf1f0 - metadata_size = 0x40 - plasma_client = self.plasma_clients[0] - object_dep, memory_buffer, metadata = create_object( - plasma_client, data_size, metadata_size, seal=True) - - # Sleep before submitting task to local scheduler. - time.sleep(0.1) - # Submit a task to Redis. - task = local_scheduler.Task( - random_driver_id(), random_function_id(), - [local_scheduler.ObjectID(object_dep.binary())], - num_return_vals[0], random_task_id(), 0) - self.local_scheduler_clients[0].submit(task) - time.sleep(0.1) - # There should now be a task in Redis, and it should get assigned to - # the local scheduler - num_retries = 10 - while num_retries > 0: - task_entries = self.state.task_table() - self.assertLessEqual(len(task_entries), 1) - if len(task_entries) == 1: - task_id, task = task_entries.popitem() - task_status = task["State"] - self.assertTrue(task_status in [ - state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED, - state.TASK_STATUS_QUEUED - ]) - if task_status == state.TASK_STATUS_QUEUED: - break - else: - print(task_status) - print("The task has not been scheduled yet, trying again.") - num_retries -= 1 - time.sleep(1) - - if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED: - # Failed to submit and schedule a single task -- bail. - self.tearDown() - sys.exit(1) - - def integration_many_tasks_helper(self, timesync=True): - # There should be three db clients, the global scheduler, the local - # scheduler, and the plasma manager. - self.assertEqual( - len(self.state.client_table()[self.node_ip_address]), - 2 * NUM_CLUSTER_NODES + 1) - num_return_vals = [0, 1, 2, 3, 5, 10] - - # Submit a bunch of tasks to Redis. - num_tasks = 1000 - for _ in range(num_tasks): - # Create a new object for each task. - data_size = np.random.randint(1 << 12) - metadata_size = np.random.randint(1 << 9) - plasma_client = self.plasma_clients[0] - object_dep, memory_buffer, metadata = create_object( - plasma_client, data_size, metadata_size, seal=True) - if timesync: - # Give 10ms for object info handler to fire (long enough to - # yield CPU). - time.sleep(0.010) - task = local_scheduler.Task( - random_driver_id(), random_function_id(), - [local_scheduler.ObjectID(object_dep.binary())], - num_return_vals[0], random_task_id(), 0) - self.local_scheduler_clients[0].submit(task) - # Check that there are the correct number of tasks in Redis and that - # they all get assigned to the local scheduler. - num_retries = 20 - num_tasks_done = 0 - while num_retries > 0: - task_entries = self.state.task_table() - self.assertLessEqual(len(task_entries), num_tasks) - # First, check if all tasks made it to Redis. 
- if len(task_entries) == num_tasks: - task_statuses = [ - task_entry["State"] - for task_entry in task_entries.values() - ] - self.assertTrue( - all(status in [ - state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED, - state.TASK_STATUS_QUEUED - ] for status in task_statuses)) - num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED) - num_tasks_scheduled = task_statuses.count( - state.TASK_STATUS_SCHEDULED) - num_tasks_waiting = task_statuses.count( - state.TASK_STATUS_WAITING) - print("tasks in Redis = {}, tasks waiting = {}, " - "tasks scheduled = {}, " - "tasks queued = {}, retries left = {}".format( - len(task_entries), num_tasks_waiting, - num_tasks_scheduled, num_tasks_done, num_retries)) - if all(status == state.TASK_STATUS_QUEUED - for status in task_statuses): - # We're done, so pass. - break - num_retries -= 1 - time.sleep(0.1) - - # Tasks can either be queued or in the global scheduler due to - # spillback. - self.assertEqual(num_tasks_done + num_tasks_waiting, num_tasks) - - @unittest.skipIf( - os.environ.get("RAY_USE_NEW_GCS", False), - "New GCS API doesn't have a Python API yet.") - def test_integration_many_tasks_handler_sync(self): - self.integration_many_tasks_helper(timesync=True) - - @unittest.skipIf( - os.environ.get("RAY_USE_NEW_GCS", False), - "New GCS API doesn't have a Python API yet.") - def test_integration_many_tasks(self): - # More realistic case: should handle out of order object and task - # notifications. - self.integration_many_tasks_helper(timesync=False) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - # Pop the argument so we don't mess with unittest's own argument - # parser. - if sys.argv[-1] == "valgrind": - arg = sys.argv.pop() - USE_VALGRIND = True - print("Using valgrind for tests") - unittest.main(verbosity=2) diff --git a/python/ray/internal/internal_api.py b/python/ray/internal/internal_api.py index 062d633ee44b..7772974319ae 100644 --- a/python/ray/internal/internal_api.py +++ b/python/ray/internal/internal_api.py @@ -2,7 +2,7 @@ from __future__ import division from __future__ import print_function -import ray.local_scheduler +import ray.raylet import ray.worker from ray import profiling @@ -42,7 +42,4 @@ def free(object_ids, local_only=False, worker=None): if len(object_ids) == 0: return - if worker.use_raylet: - worker.local_scheduler_client.free(object_ids, local_only) - else: - raise Exception("Free is not supported in legacy backend.") + worker.local_scheduler_client.free(object_ids, local_only) diff --git a/python/ray/local_scheduler/build/.gitkeep b/python/ray/local_scheduler/build/.gitkeep deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/ray/local_scheduler/local_scheduler_services.py b/python/ray/local_scheduler/local_scheduler_services.py deleted file mode 100644 index c576014e25ce..000000000000 --- a/python/ray/local_scheduler/local_scheduler_services.py +++ /dev/null @@ -1,132 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import multiprocessing -import os -import subprocess -import sys -import time - -from ray.tempfile_services import (get_local_scheduler_socket_name, - get_temp_root) - - -def start_local_scheduler(plasma_store_name, - plasma_manager_name=None, - worker_path=None, - plasma_address=None, - node_ip_address="127.0.0.1", - redis_address=None, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None, - static_resources=None, - num_workers=0): - """Start a local scheduler 
process. - - Args: - plasma_store_name (str): The name of the plasma store socket to connect - to. - plasma_manager_name (str): The name of the plasma manager to connect - to. This does not need to be provided, but if it is, then the Redis - address must be provided as well. - worker_path (str): The path of the worker script to use when the local - scheduler starts up new workers. - plasma_address (str): The address of the plasma manager to connect to. - This is only used by the global scheduler to figure out which - plasma managers are connected to which local schedulers. - node_ip_address (str): The address of the node that this local - scheduler is running on. - redis_address (str): The address of the Redis instance to connect to. - If this is not provided, then the local scheduler will not connect - to Redis. - use_valgrind (bool): True if the local scheduler should be started - inside of valgrind. If this is True, use_profiler must be False. - use_profiler (bool): True if the local scheduler should be started - inside a profiler. If this is True, use_valgrind must be False. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - static_resources: A dictionary specifying the local scheduler's - resource capacities. This maps resource names (strings) to - integers or floats. - num_workers (int): The number of workers that the local scheduler - should start. - - Return: - A tuple of the name of the local scheduler socket and the process ID of - the local scheduler process. - """ - if (plasma_manager_name is None) != (redis_address is None): - raise Exception("If one of the plasma_manager_name and the " - "redis_address is provided, then both must be " - "provided.") - if use_valgrind and use_profiler: - raise Exception("Cannot use valgrind and profiler at the same time.") - local_scheduler_executable = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "../core/src/local_scheduler/local_scheduler") - local_scheduler_name = get_local_scheduler_socket_name() - command = [ - local_scheduler_executable, "-s", local_scheduler_name, "-p", - plasma_store_name, "-h", node_ip_address, "-n", - str(num_workers) - ] - if plasma_manager_name is not None: - command += ["-m", plasma_manager_name] - if worker_path is not None: - assert plasma_store_name is not None - assert plasma_manager_name is not None - assert redis_address is not None - start_worker_command = ("{} {} " - "--node-ip-address={} " - "--object-store-name={} " - "--object-store-manager-name={} " - "--local-scheduler-name={} " - "--redis-address={} " - "--temp-dir={}".format( - sys.executable, worker_path, - node_ip_address, plasma_store_name, - plasma_manager_name, local_scheduler_name, - redis_address, get_temp_root())) - command += ["-w", start_worker_command] - if redis_address is not None: - command += ["-r", redis_address] - if plasma_address is not None: - command += ["-a", plasma_address] - if static_resources is not None: - resource_argument = "" - for resource_name, resource_quantity in static_resources.items(): - assert (isinstance(resource_quantity, int) - or isinstance(resource_quantity, float)) - resource_argument = ",".join([ - resource_name + "," + str(resource_quantity) - for resource_name, resource_quantity in static_resources.items() - ]) - else: - resource_argument = 
"CPU,{}".format(multiprocessing.cpu_count()) - command += ["-c", resource_argument] - - if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - return local_scheduler_name, pid diff --git a/python/ray/local_scheduler/test/test.py b/python/ray/local_scheduler/test/test.py deleted file mode 100644 index b35d609de6e0..000000000000 --- a/python/ray/local_scheduler/test/test.py +++ /dev/null @@ -1,206 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import os -import signal -import sys -import threading -import time -import unittest - -import ray.local_scheduler as local_scheduler -import ray.plasma as plasma -import ray.ray_constants as ray_constants -import pyarrow as pa - -USE_VALGRIND = False - -NIL_WORKER_ID = ray_constants.ID_SIZE * b"\xff" - - -def random_object_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_driver_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_task_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def random_function_id(): - return local_scheduler.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -class TestLocalSchedulerClient(unittest.TestCase): - def setUp(self): - # Start Plasma store. - plasma_store_name, self.p1 = plasma.start_plasma_store() - self.plasma_client = pa.plasma.connect(plasma_store_name, "", 0) - # Start a local scheduler. - scheduler_name, self.p2 = local_scheduler.start_local_scheduler( - plasma_store_name, use_valgrind=USE_VALGRIND) - # Connect to the scheduler. - self.local_scheduler_client = local_scheduler.LocalSchedulerClient( - scheduler_name, NIL_WORKER_ID, False, random_task_id(), False) - - def tearDown(self): - # Check that the processes are still alive. - self.assertEqual(self.p1.poll(), None) - self.assertEqual(self.p2.poll(), None) - - # Kill Plasma. - self.p1.kill() - # Kill the local scheduler. - if USE_VALGRIND: - self.p2.send_signal(signal.SIGTERM) - self.p2.wait() - if self.p2.returncode != 0: - os._exit(-1) - else: - self.p2.kill() - - def test_submit_and_get_task(self): - function_id = random_function_id() - object_ids = [random_object_id() for i in range(256)] - # Create and seal the objects in the object store so that we can - # schedule all of the subsequent tasks. - for object_id in object_ids: - self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0) - self.plasma_client.seal(pa.plasma.ObjectID(object_id.id())) - # Define some arguments to use for the tasks. 
- args_list = [[], [{}], [()], 1 * [1], 10 * [1], 100 * [1], 1000 * [1], - 1 * ["a"], 10 * ["a"], 100 * ["a"], 1000 * ["a"], [ - 1, 1.3, 1 << 100, "hi", u"hi", [1, 2] - ], object_ids[:1], object_ids[:2], object_ids[:3], - object_ids[:4], object_ids[:5], object_ids[:10], - object_ids[:100], object_ids[:256], [1, object_ids[0]], [ - object_ids[0], "a" - ], [1, object_ids[0], "a"], [ - object_ids[0], 1, object_ids[1], "a" - ], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5], - object_ids + 100 * ["a"] + object_ids] - - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, - args, num_return_vals, - random_task_id(), 0) - # Submit a task. - self.local_scheduler_client.submit(task) - # Get the task. - new_task = self.local_scheduler_client.get_task() - self.assertEqual(task.function_id().id(), - new_task.function_id().id()) - retrieved_args = new_task.arguments() - returns = new_task.returns() - self.assertEqual(len(args), len(retrieved_args)) - self.assertEqual(num_return_vals, len(returns)) - for i in range(len(retrieved_args)): - if isinstance(args[i], local_scheduler.ObjectID): - self.assertEqual(args[i].id(), retrieved_args[i].id()) - else: - self.assertEqual(args[i], retrieved_args[i]) - - # Submit all of the tasks. - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - task = local_scheduler.Task(random_driver_id(), function_id, - args, num_return_vals, - random_task_id(), 0) - self.local_scheduler_client.submit(task) - # Get all of the tasks. - for args in args_list: - for num_return_vals in [0, 1, 2, 3, 5, 10, 100]: - new_task = self.local_scheduler_client.get_task() - - def test_scheduling_when_objects_ready(self): - # Create a task and submit it. - object_id = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [object_id], 0, random_task_id(), 0) - self.local_scheduler_client.submit(task) - - # Launch a thread to get the task. - def get_task(): - self.local_scheduler_client.get_task() - - t = threading.Thread(target=get_task) - t.start() - # Sleep to give the thread time to call get_task. - time.sleep(0.1) - # Create and seal the object ID in the object store. This should - # trigger a scheduling event. - self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0) - self.plasma_client.seal(pa.plasma.ObjectID(object_id.id())) - # Wait until the thread finishes so that we know the task was - # scheduled. - t.join() - - def test_scheduling_when_objects_evicted(self): - # Create a task with two dependencies and submit it. - object_id1 = random_object_id() - object_id2 = random_object_id() - task = local_scheduler.Task(random_driver_id(), random_function_id(), - [object_id1, object_id2], 0, - random_task_id(), 0) - self.local_scheduler_client.submit(task) - - # Launch a thread to get the task. - def get_task(): - self.local_scheduler_client.get_task() - - t = threading.Thread(target=get_task) - t.start() - - # Make one of the dependencies available. - buf = self.plasma_client.create(pa.plasma.ObjectID(object_id1.id()), 1) - self.plasma_client.seal(pa.plasma.ObjectID(object_id1.id())) - # Release the object. - del buf - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - # Force eviction of the first dependency. - self.plasma_client.evict(plasma.DEFAULT_PLASMA_STORE_MEMORY) - # Check that the thread is still waiting for a task. 
- time.sleep(0.1) - self.assertTrue(t.is_alive()) - # Check that the first object dependency was evicted. - object1 = self.plasma_client.get_buffers( - [pa.plasma.ObjectID(object_id1.id())], timeout_ms=0) - self.assertEqual(object1, [None]) - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - - # Create the second dependency. - self.plasma_client.create(pa.plasma.ObjectID(object_id2.id()), 1) - self.plasma_client.seal(pa.plasma.ObjectID(object_id2.id())) - # Check that the thread is still waiting for a task. - time.sleep(0.1) - self.assertTrue(t.is_alive()) - - # Create the first dependency again. Both dependencies are now - # available. - self.plasma_client.create(pa.plasma.ObjectID(object_id1.id()), 1) - self.plasma_client.seal(pa.plasma.ObjectID(object_id1.id())) - - # Wait until the thread finishes so that we know the task was - # scheduled. - t.join() - - -if __name__ == "__main__": - if len(sys.argv) > 1: - # Pop the argument so we don't mess with unittest's own argument - # parser. - if sys.argv[-1] == "valgrind": - arg = sys.argv.pop() - USE_VALGRIND = True - print("Using valgrind for tests") - unittest.main(verbosity=2) diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 6212de23e694..8094e3d5ecc9 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -3,11 +3,9 @@ from __future__ import print_function import argparse -import binascii import logging import os import time -from collections import Counter, defaultdict import traceback import redis @@ -20,27 +18,6 @@ import ray.ray_constants as ray_constants from ray.services import get_ip_address, get_port from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary -from ray.worker import NIL_ACTOR_ID - -# These variables must be kept in sync with the C codebase. -# common/common.h -NIL_ID = b"\xff" * ray_constants.ID_SIZE - -# common/task.h -TASK_STATUS_LOST = 32 - -# common/redis_module/ray_redis_module.cc -OBJECT_INFO_PREFIX = b"OI:" -OBJECT_LOCATION_PREFIX = b"OL:" -TASK_TABLE_PREFIX = b"TT:" -DB_CLIENT_PREFIX = b"CL:" -DB_CLIENT_TABLE_NAME = b"db_clients" - -# local_scheduler/local_scheduler.h -LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler" - -# plasma/plasma_manager.cc -PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager" # Set up logging. logger = logging.getLogger(__name__) @@ -55,19 +32,8 @@ class Monitor(object): Attributes: redis: A connection to the Redis server. - use_raylet: A bool indicating whether to use the raylet code path or - not. subscribe_client: A pubsub client for the Redis server. This is used to receive notifications about failed components. - dead_local_schedulers: A set of the local scheduler IDs of all of the - local schedulers that were up at one point and have died since - then. - live_plasma_managers: A counter mapping live plasma manager IDs to the - number of heartbeats that have passed since we last heard from that - plasma manager. A plasma manager is live if we received a heartbeat - from it at any point, and if it has not timed out. - dead_plasma_managers: A set of the plasma manager IDs of all the plasma - managers that were up at one point and have died since then. 
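+        shard_subscribe_clients: Pubsub clients for each of the Redis
+            shards, used to receive xray heartbeat notifications.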
""" def __init__(self, @@ -79,26 +45,16 @@ def __init__(self, self.state = ray.experimental.state.GlobalState() self.state._initialize_global_state( redis_address, redis_port, redis_password=redis_password) - self.use_raylet = self.state.use_raylet self.redis = redis.StrictRedis( host=redis_address, port=redis_port, db=0, password=redis_password) # Setup subscriptions to the primary Redis server and the Redis shards. self.primary_subscribe_client = self.redis.pubsub( ignore_subscribe_messages=True) - if self.use_raylet: - self.shard_subscribe_clients = [] - for redis_client in self.state.redis_clients: - subscribe_client = redis_client.pubsub( - ignore_subscribe_messages=True) - self.shard_subscribe_clients.append(subscribe_client) - else: - # We don't need to subscribe to the shards in legacy Ray. - self.shard_subscribe_clients = [] - # Initialize data structures to keep track of the active database - # clients. - self.dead_local_schedulers = set() - self.live_plasma_managers = Counter() - self.dead_plasma_managers = set() + self.shard_subscribe_clients = [] + for redis_client in self.state.redis_clients: + subscribe_client = redis_client.pubsub( + ignore_subscribe_messages=True) + self.shard_subscribe_clients.append(subscribe_client) # Keep a mapping from local scheduler client ID to IP address to use # for updating the load metrics. self.local_scheduler_id_to_ip_map = {} @@ -152,170 +108,6 @@ def subscribe(self, channel, primary=True): for subscribe_client in self.shard_subscribe_clients: subscribe_client.subscribe(channel) - def cleanup_task_table(self): - """Clean up global state for failed local schedulers. - - This marks any tasks that were scheduled on dead local schedulers as - TASK_STATUS_LOST. A local scheduler is deemed dead if it is in - self.dead_local_schedulers. - """ - tasks = self.state.task_table() - num_tasks_updated = 0 - for task_id, task in tasks.items(): - # See if the corresponding local scheduler is alive. - if task["LocalSchedulerID"] not in self.dead_local_schedulers: - continue - - # Remove dummy objects returned by actor tasks from any plasma - # manager. Although the objects may still exist in that object - # store, this deletion makes them effectively unreachable by any - # local scheduler connected to a different store. - # TODO(swang): Actually remove the objects from the object store, - # so that the reconstructed actor can reuse the same object store. - if hex_to_binary(task["TaskSpec"]["ActorID"]) != NIL_ACTOR_ID: - dummy_object_id = task["TaskSpec"]["ReturnObjectIDs"][-1] - obj = self.state.object_table(dummy_object_id) - manager_ids = obj["ManagerIDs"] - if manager_ids is not None: - # The dummy object should exist on at most one plasma - # manager, the manager associated with the local scheduler - # that died. - assert len(manager_ids) <= 1 - # Remove the dummy object from the plasma manager - # associated with the dead local scheduler, if any. - for manager in manager_ids: - ok = self.state._execute_command( - dummy_object_id, "RAY.OBJECT_TABLE_REMOVE", - dummy_object_id.id(), hex_to_binary(manager)) - if ok != b"OK": - logger.warn("Failed to remove object location for " - "dead plasma manager.") - - # If the task is scheduled on a dead local scheduler, mark the - # task as lost. 
- key = binary_to_object_id(hex_to_binary(task_id)) - ok = self.state._execute_command( - key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id), - ray.experimental.state.TASK_STATUS_LOST, NIL_ID, - task["ExecutionDependenciesString"], task["SpillbackCount"]) - if ok != b"OK": - logger.warn("Failed to update lost task for dead scheduler.") - num_tasks_updated += 1 - - if num_tasks_updated > 0: - logger.warn("Marked {} tasks as lost.".format(num_tasks_updated)) - - def cleanup_object_table(self): - """Clean up global state for failed plasma managers. - - This removes dead plasma managers from any location entries in the - object table. A plasma manager is deemed dead if it is in - self.dead_plasma_managers. - """ - # TODO(swang): Also kill the associated plasma store, since it's no - # longer reachable without a plasma manager. - objects = self.state.object_table() - num_objects_removed = 0 - for object_id, obj in objects.items(): - manager_ids = obj["ManagerIDs"] - if manager_ids is None: - continue - for manager in manager_ids: - if manager in self.dead_plasma_managers: - # If the object was on a dead plasma manager, remove that - # location entry. - ok = self.state._execute_command( - object_id, "RAY.OBJECT_TABLE_REMOVE", object_id.id(), - hex_to_binary(manager)) - if ok != b"OK": - logger.warn("Failed to remove object location for " - "dead plasma manager.") - num_objects_removed += 1 - if num_objects_removed > 0: - logger.warn("Marked {} objects as lost." - .format(num_objects_removed)) - - def scan_db_client_table(self): - """Scan the database client table for dead clients. - - After subscribing to the client table, it's necessary to call this - before reading any messages from the subscription channel. This ensures - that we do not miss any notifications for deleted clients that occurred - before we subscribed. - """ - # Exit if we are using the raylet code path because client_table is - # implemented differently. TODO(rkn): Fix this. - if self.use_raylet: - return - - clients = self.state.client_table() - for node_ip_address, node_clients in clients.items(): - for client in node_clients: - db_client_id = client["DBClientID"] - client_type = client["ClientType"] - if client["Deleted"]: - if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: - self.dead_local_schedulers.add(db_client_id) - elif client_type == PLASMA_MANAGER_CLIENT_TYPE: - self.dead_plasma_managers.add(db_client_id) - - def db_client_notification_handler(self, unused_channel, data): - """Handle a notification from the db_client table from Redis. - - This handler processes notifications from the db_client table. - Notifications should be parsed using the SubscribeToDBClientTableReply - flatbuffer. Deletions are processed, insertions are ignored. Cleanup of - the associated state in the state tables should be handled by the - caller. - """ - notification_object = (ray.gcs_utils.SubscribeToDBClientTableReply. - GetRootAsSubscribeToDBClientTableReply(data, 0)) - db_client_id = binary_to_hex(notification_object.DbClientId()) - client_type = notification_object.ClientType() - is_insertion = notification_object.IsInsertion() - - # If the update was an insertion, we ignore it. - if is_insertion: - return - - # If the update was a deletion, add them to our accounting for dead - # local schedulers and plasma managers. 
- logger.warn("Removed {}, client ID {}".format(client_type, - db_client_id)) - if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: - if db_client_id not in self.dead_local_schedulers: - self.dead_local_schedulers.add(db_client_id) - elif client_type == PLASMA_MANAGER_CLIENT_TYPE: - if db_client_id not in self.dead_plasma_managers: - self.dead_plasma_managers.add(db_client_id) - # Stop tracking this plasma manager's heartbeats, since it's - # already dead. - del self.live_plasma_managers[db_client_id] - - def local_scheduler_info_handler(self, unused_channel, data): - """Handle a local scheduler heartbeat from Redis.""" - - message = (ray.gcs_utils.LocalSchedulerInfoMessage. - GetRootAsLocalSchedulerInfoMessage(data, 0)) - num_resources = message.DynamicResourcesLength() - static_resources = {} - dynamic_resources = {} - for i in range(num_resources): - dyn = message.DynamicResources(i) - static = message.StaticResources(i) - dynamic_resources[dyn.Key().decode("utf-8")] = dyn.Value() - static_resources[static.Key().decode("utf-8")] = static.Value() - - # Update the load metrics for this local scheduler. - client_id = binascii.hexlify(message.DbClientId()).decode("utf-8") - ip = self.local_scheduler_id_to_ip_map.get(client_id) - if ip: - self.load_metrics.update(ip, static_resources, dynamic_resources) - else: - logger.warning( - "Warning: could not find ip for client {} in {}.".format( - client_id, self.local_scheduler_id_to_ip_map)) - def xray_heartbeat_handler(self, unused_channel, data): """Handle an xray heartbeat message from Redis.""" @@ -342,160 +134,6 @@ def xray_heartbeat_handler(self, unused_channel, data): print("Warning: could not find ip for client {} in {}.".format( client_id, self.local_scheduler_id_to_ip_map)) - def plasma_manager_heartbeat_handler(self, unused_channel, data): - """Handle a plasma manager heartbeat from Redis. - - This resets the number of heartbeats that we've missed from this plasma - manager. - """ - # The first ray_constants.ID_SIZE characters are the client ID. - db_client_id = data[:ray_constants.ID_SIZE] - # Reset the number of heartbeats that we've missed from this plasma - # manager. - self.live_plasma_managers[db_client_id] = 0 - - def _entries_for_driver_in_shard(self, driver_id, redis_shard_index): - """Collect IDs of control-state entries for a driver from a shard. - - Args: - driver_id: The ID of the driver. - redis_shard_index: The index of the Redis shard to query. - - Returns: - Lists of IDs: (returned_object_ids, task_ids, put_objects). The - first two are relevant to the driver and are safe to delete. - The last contains all "put" objects in this redis shard; each - element is an (object_id, corresponding task_id) pair. - """ - # TODO(zongheng): consider adding save & restore functionalities. - redis = self.state.redis_clients[redis_shard_index] - task_table_infos = {} # task id -> TaskInfo messages - - # Scan the task table & filter to get the list of tasks belong to this - # driver. Use a cursor in order not to block the redis shards. - for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"): - entry = redis.hgetall(key) - task_info = ray.gcs_utils.TaskInfo.GetRootAsTaskInfo( - entry[b"TaskSpec"], 0) - if driver_id != task_info.DriverId(): - # Ignore tasks that aren't from this driver. - continue - task_table_infos[task_info.TaskId()] = task_info - - # Get the list of objects returned by these tasks. Note these might - # not belong to this redis shard. 
- returned_object_ids = [] - for task_info in task_table_infos.values(): - returned_object_ids.extend([ - task_info.Returns(i) for i in range(task_info.ReturnsLength()) - ]) - - # Also record all the ray.put()'d objects. - put_objects = [] - for key in redis.scan_iter(match=OBJECT_INFO_PREFIX + b"*"): - entry = redis.hgetall(key) - if entry[b"is_put"] == "0": - continue - object_id = key.split(OBJECT_INFO_PREFIX)[1] - task_id = entry[b"task"] - put_objects.append((object_id, task_id)) - - return returned_object_ids, task_table_infos.keys(), put_objects - - def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index): - redis = self.state.redis_clients[shard_index] - # Clean up (in the future, save) entries for non-empty objects. - object_ids_locs = set() - object_ids_infos = set() - for object_id in object_ids: - # OL. - obj_loc = redis.zrange(OBJECT_LOCATION_PREFIX + object_id, 0, -1) - if obj_loc: - object_ids_locs.add(object_id) - # OI. - obj_info = redis.hgetall(OBJECT_INFO_PREFIX + object_id) - if obj_info: - object_ids_infos.add(object_id) - - # Form the redis keys to delete. - keys = [TASK_TABLE_PREFIX + k for k in task_ids] - keys.extend([OBJECT_LOCATION_PREFIX + k for k in object_ids_locs]) - keys.extend([OBJECT_INFO_PREFIX + k for k in object_ids_infos]) - - if not keys: - return - # Remove with best effort. - num_deleted = redis.delete(*keys) - logger.info( - "Removed {} dead redis entries of the driver from redis shard {}.". - format(num_deleted, shard_index)) - if num_deleted != len(keys): - logger.warning( - "Failed to remove {} relevant redis entries" - " from redis shard {}.".format(len(keys) - num_deleted)) - - def _clean_up_entries_for_driver(self, driver_id): - """Remove this driver's object/task entries from all redis shards. - - Specifically, removes control-state entries of: - * all objects (OI and OL entries) created by `ray.put()` from the - driver - * all tasks belonging to the driver. - """ - # TODO(zongheng): handle function_table, client_table, log_files -- - # these are in the metadata redis server, not in the shards. - driver_object_ids = [] - driver_task_ids = [] - all_put_objects = [] - - # Collect relevant ids. - # TODO(zongheng): consider parallelizing this loop. - for shard_index in range(len(self.state.redis_clients)): - returned_object_ids, task_ids, put_objects = \ - self._entries_for_driver_in_shard(driver_id, shard_index) - driver_object_ids.extend(returned_object_ids) - driver_task_ids.extend(task_ids) - all_put_objects.extend(put_objects) - - # For the put objects, keep those from relevant tasks. - driver_task_ids_set = set(driver_task_ids) - for object_id, task_id in all_put_objects: - if task_id in driver_task_ids_set: - driver_object_ids.append(object_id) - - # Partition IDs and distribute to shards. - object_ids_per_shard = defaultdict(list) - task_ids_per_shard = defaultdict(list) - - def ToShardIndex(index): - return binary_to_object_id(index).redis_shard_hash() % len( - self.state.redis_clients) - - for object_id in driver_object_ids: - object_ids_per_shard[ToShardIndex(object_id)].append(object_id) - for task_id in driver_task_ids: - task_ids_per_shard[ToShardIndex(task_id)].append(task_id) - - # TODO(zongheng): consider parallelizing this loop. 
- for shard_index in range(len(self.state.redis_clients)): - self._clean_up_entries_from_shard( - object_ids_per_shard[shard_index], - task_ids_per_shard[shard_index], shard_index) - - def driver_removed_handler(self, unused_channel, data): - """Handle a notification that a driver has been removed. - - This releases any GPU resources that were reserved for that driver in - Redis. - """ - message = ray.gcs_utils.DriverTableMessage.GetRootAsDriverTableMessage( - data, 0) - driver_id = message.DriverId() - logger.info("Driver {} has been removed.".format( - binary_to_hex(driver_id))) - - self._clean_up_entries_for_driver(driver_id) - def _xray_clean_up_entries_for_driver(self, driver_id): """Remove this driver's object/task entries from redis. @@ -529,7 +167,7 @@ def _xray_clean_up_entries_for_driver(self, driver_id): driver_object_id_bins = set() for object_id, object_table_object in object_table_objects.items(): assert len(object_table_object) > 0 - task_id_bin = ray.local_scheduler.compute_task_id(object_id).id() + task_id_bin = ray.raylet.compute_task_id(object_id).id() if task_id_bin in driver_task_id_bins: driver_object_id_bins.add(object_id.id()) @@ -602,20 +240,7 @@ def process_messages(self, max_messages=10000): # Determine the appropriate message handler. message_handler = None - if channel == ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL: - # The message was a heartbeat from a plasma manager. - message_handler = self.plasma_manager_heartbeat_handler - elif channel == ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL: - # The message was a heartbeat from a local scheduler - message_handler = self.local_scheduler_info_handler - elif channel == DB_CLIENT_TABLE_NAME: - # The message was a notification from the db_client table. - message_handler = self.db_client_notification_handler - elif channel == ray.gcs_utils.DRIVER_DEATH_CHANNEL: - # The message was a notification that a driver was removed. - logger.info("message-handler: driver_removed_handler") - message_handler = self.driver_removed_handler - elif channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL: + if channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL: # Similar functionality as local scheduler info channel message_handler = self.xray_heartbeat_handler elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL: @@ -629,10 +254,7 @@ def process_messages(self, max_messages=10000): message_handler(channel, data) def update_local_scheduler_map(self): - if self.use_raylet: - local_schedulers = self.state.client_table() - else: - local_schedulers = self.state.local_schedulers() + local_schedulers = self.state.client_table() self.local_scheduler_id_to_ip_map = {} for local_scheduler_info in local_schedulers: client_id = local_scheduler_info.get("DBClientID") or \ @@ -680,33 +302,11 @@ def run(self): clients and cleaning up state accordingly. """ # Initialize the subscription channel. - self.subscribe(DB_CLIENT_TABLE_NAME) - self.subscribe(ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL) - self.subscribe(ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL) - self.subscribe(ray.gcs_utils.DRIVER_DEATH_CHANNEL) self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL, primary=False) self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL) - # Scan the database table for dead database clients. NOTE: This must be - # called before reading any messages from the subscription channel. - # This ensures that we start in a consistent state, since we may have - # missed notifications that were sent before we connected to the - # subscription channel. 
- self.scan_db_client_table() - # If there were any dead clients at startup, clean up the associated - # state in the state tables. - if len(self.dead_local_schedulers) > 0: - self.cleanup_task_table() - if len(self.dead_plasma_managers) > 0: - self.cleanup_object_table() - - num_plasma_managers = len(self.live_plasma_managers) + len( - self.dead_plasma_managers) - - logger.debug("{} dead local schedulers, {} plasma managers total, {} " - "dead plasma managers".format( - len(self.dead_local_schedulers), num_plasma_managers, - len(self.dead_plasma_managers))) + # TODO(rkn): If there were any dead clients at startup, we should clean + # up the associated state in the state tables. # Handle messages from the subscription channels. while True: @@ -720,43 +320,9 @@ def run(self): self._maybe_flush_gcs() - # Record how many dead local schedulers and plasma managers we had - # at the beginning of this round. - num_dead_local_schedulers = len(self.dead_local_schedulers) - num_dead_plasma_managers = len(self.dead_plasma_managers) - # Process a round of messages. self.process_messages() - # If any new local schedulers or plasma managers were marked as - # dead in this round, clean up the associated state. - if len(self.dead_local_schedulers) > num_dead_local_schedulers: - self.cleanup_task_table() - if len(self.dead_plasma_managers) > num_dead_plasma_managers: - self.cleanup_object_table() - - # Handle plasma managers that timed out during this round. - plasma_manager_ids = list(self.live_plasma_managers.keys()) - for plasma_manager_id in plasma_manager_ids: - if ((self.live_plasma_managers[plasma_manager_id]) >= - ray._config.num_heartbeats_timeout()): - logger.warn("Timed out {}" - .format(PLASMA_MANAGER_CLIENT_TYPE)) - # Remove the plasma manager from the managers whose - # heartbeats we're tracking. - del self.live_plasma_managers[plasma_manager_id] - # Remove the plasma manager from the db_client table. The - # corresponding state in the object table will be cleaned - # up once we receive the notification for this db_client - # deletion. - self.redis.execute_command("RAY.DISCONNECT", - plasma_manager_id) - - # Increment the number of heartbeats that we've missed from each - # plasma manager. - for plasma_manager_id in self.live_plasma_managers: - self.live_plasma_managers[plasma_manager_id] += 1 - # Wait for a heartbeat interval before processing the next round of # messages. 
time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3) @@ -827,6 +393,5 @@ def run(self): message = "The monitor failed with the following error:\n{}".format( traceback_str) ray.utils.push_error_to_driver_through_redis( - redis_client, monitor.use_raylet, ray_constants.MONITOR_DIED_ERROR, - message) + redis_client, ray_constants.MONITOR_DIED_ERROR, message) raise e diff --git a/python/ray/plasma/__init__.py b/python/ray/plasma/__init__.py index 1ecd0c2af2dc..6c6c18b7c555 100644 --- a/python/ray/plasma/__init__.py +++ b/python/ray/plasma/__init__.py @@ -2,9 +2,6 @@ from __future__ import division from __future__ import print_function -from ray.plasma.plasma import (start_plasma_store, start_plasma_manager, - DEFAULT_PLASMA_STORE_MEMORY) +from ray.plasma.plasma import start_plasma_store, DEFAULT_PLASMA_STORE_MEMORY -__all__ = [ - "start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY" -] +__all__ = ["start_plasma_store", "DEFAULT_PLASMA_STORE_MEMORY"] diff --git a/python/ray/plasma/plasma.py b/python/ray/plasma/plasma.py index 262aeebfb448..53b2434260c8 100644 --- a/python/ray/plasma/plasma.py +++ b/python/ray/plasma/plasma.py @@ -3,17 +3,13 @@ from __future__ import print_function import os -import random import subprocess import sys import time -from ray.tempfile_services import (get_object_store_socket_name, - get_plasma_manager_socket_name) +from ray.tempfile_services import get_object_store_socket_name -__all__ = [ - "start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY" -] +__all__ = ["start_plasma_store", "DEFAULT_PLASMA_STORE_MEMORY"] PLASMA_WAIT_TIMEOUT = 2**30 @@ -97,98 +93,3 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) time.sleep(0.1) return plasma_store_name, pid - - -def new_port(): - return random.randint(10000, 65535) - - -def start_plasma_manager(store_name, - redis_address, - node_ip_address="127.0.0.1", - plasma_manager_port=None, - num_retries=20, - use_valgrind=False, - run_profiler=False, - stdout_file=None, - stderr_file=None): - """Start a plasma manager and return the ports it listens on. - - Args: - store_name (str): The name of the plasma store socket. - redis_address (str): The address of the Redis server. - node_ip_address (str): The IP address of the node. - plasma_manager_port (int): The port to use for the plasma manager. If - this is not provided, a port will be generated at random. - use_valgrind (bool): True if the Plasma manager should be started - inside of valgrind and False otherwise. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - - Returns: - A tuple of the Plasma manager socket name, the process ID of the - Plasma manager process, and the port that the manager is - listening on. - - Raises: - Exception: An exception is raised if the manager could not be started. 
- """ - plasma_manager_executable = os.path.join( - os.path.abspath(os.path.dirname(__file__)), - "../core/src/plasma/plasma_manager") - plasma_manager_name = get_plasma_manager_socket_name() - if plasma_manager_port is not None: - if num_retries != 1: - raise Exception("num_retries must be 1 if port is specified.") - else: - plasma_manager_port = new_port() - process = None - counter = 0 - while counter < num_retries: - if counter > 0: - print("Plasma manager failed to start, retrying now.") - command = [ - plasma_manager_executable, - "-s", - store_name, - "-m", - plasma_manager_name, - "-h", - node_ip_address, - "-p", - str(plasma_manager_port), - "-r", - redis_address, - ] - if use_valgrind: - process = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) - elif run_profiler: - process = subprocess.Popen( - (["valgrind", "--tool=callgrind"] + command), - stdout=stdout_file, - stderr=stderr_file) - else: - process = subprocess.Popen( - command, stdout=stdout_file, stderr=stderr_file) - # This sleep is critical. If the plasma_manager fails to start because - # the port is already in use, then we need it to fail within 0.1 - # seconds. - if use_valgrind: - time.sleep(1) - else: - time.sleep(0.1) - # See if the process has terminated - if process.poll() is None: - return plasma_manager_name, process, plasma_manager_port - # Generate a new port and try again. - plasma_manager_port = new_port() - counter += 1 - raise Exception("Couldn't start plasma manager.") diff --git a/python/ray/plasma/test/test.py b/python/ray/plasma/test/test.py deleted file mode 100644 index bc2418f005d2..000000000000 --- a/python/ray/plasma/test/test.py +++ /dev/null @@ -1,560 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from numpy.testing import assert_equal -import os -import random -import signal -import subprocess -import sys -import threading -import time -import unittest - -# The ray import must come before the pyarrow import because ray modifies the -# python path so that the right version of pyarrow is found. -import ray -from ray.plasma.utils import (random_object_id, create_object_with_id, - create_object) -import ray.ray_constants as ray_constants -from ray import services -import pyarrow as pa -import pyarrow.plasma as plasma - -USE_VALGRIND = False -PLASMA_STORE_MEMORY = 1000000000 - - -def random_name(): - return str(random.randint(0, 99999999)) - - -def assert_get_object_equal(unit_test, - client1, - client2, - object_id, - memory_buffer=None, - metadata=None): - client1_buff = client1.get_buffers([object_id])[0] - client2_buff = client2.get_buffers([object_id])[0] - client1_metadata = client1.get_metadata([object_id])[0] - client2_metadata = client2.get_metadata([object_id])[0] - unit_test.assertEqual(len(client1_buff), len(client2_buff)) - unit_test.assertEqual(len(client1_metadata), len(client2_metadata)) - # Check that the buffers from the two clients are the same. - assert_equal( - np.frombuffer(client1_buff, dtype="uint8"), - np.frombuffer(client2_buff, dtype="uint8")) - # Check that the metadata buffers from the two clients are the same. - assert_equal( - np.frombuffer(client1_metadata, dtype="uint8"), - np.frombuffer(client2_metadata, dtype="uint8")) - # If a reference buffer was provided, check that it is the same as well. 
- if memory_buffer is not None: - assert_equal( - np.frombuffer(memory_buffer, dtype="uint8"), - np.frombuffer(client1_buff, dtype="uint8")) - # If reference metadata was provided, check that it is the same as well. - if metadata is not None: - assert_equal( - np.frombuffer(metadata, dtype="uint8"), - np.frombuffer(client1_metadata, dtype="uint8")) - - -DEFAULT_PLASMA_STORE_MEMORY = 10**9 - - -def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, - use_valgrind=False, - use_profiler=False, - stdout_file=None, - stderr_file=None): - """Start a plasma store process. - Args: - use_valgrind (bool): True if the plasma store should be started inside - of valgrind. If this is True, use_profiler must be False. - use_profiler (bool): True if the plasma store should be started inside - a profiler. If this is True, use_valgrind must be False. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - Return: - A tuple of the name of the plasma store socket and the process ID of - the plasma store process. - """ - if use_valgrind and use_profiler: - raise Exception("Cannot use valgrind and profiler at the same time.") - plasma_store_executable = os.path.join(pa.__path__[0], - "plasma_store_server") - plasma_store_name = "/tmp/plasma_store{}".format(random_name()) - command = [ - plasma_store_executable, "-s", plasma_store_name, "-m", - str(plasma_store_memory) - ] - if use_valgrind: - pid = subprocess.Popen( - [ - "valgrind", "--track-origins=yes", "--leak-check=full", - "--show-leak-kinds=all", "--leak-check-heuristics=stdstring", - "--error-exitcode=1" - ] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - elif use_profiler: - pid = subprocess.Popen( - ["valgrind", "--tool=callgrind"] + command, - stdout=stdout_file, - stderr=stderr_file) - time.sleep(1.0) - else: - pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) - time.sleep(0.1) - return plasma_store_name, pid - - -# Plasma client tests were moved into arrow - - -class TestPlasmaManager(unittest.TestCase): - def setUp(self): - # Start two PlasmaStores. - store_name1, self.p2 = start_plasma_store(use_valgrind=USE_VALGRIND) - store_name2, self.p3 = start_plasma_store(use_valgrind=USE_VALGRIND) - # Start a Redis server. - redis_address, _ = services.start_redis("127.0.0.1", use_raylet=False) - # Start two PlasmaManagers. - manager_name1, self.p4, self.port1 = ray.plasma.start_plasma_manager( - store_name1, redis_address, use_valgrind=USE_VALGRIND) - manager_name2, self.p5, self.port2 = ray.plasma.start_plasma_manager( - store_name2, redis_address, use_valgrind=USE_VALGRIND) - # Connect two PlasmaClients. - self.client1 = plasma.connect(store_name1, manager_name1, 64) - self.client2 = plasma.connect(store_name2, manager_name2, 64) - - # Store the processes that will be explicitly killed during tearDown so - # that a test case can remove ones that will be killed during the test. - # NOTE: If this specific order is changed, valgrind will fail. - self.processes_to_kill = [self.p4, self.p5, self.p2, self.p3] - - def tearDown(self): - # Check that the processes are still alive. - for process in self.processes_to_kill: - self.assertEqual(process.poll(), None) - - # Kill the Plasma store and Plasma manager processes. - if USE_VALGRIND: - # Give processes opportunity to finish work. 
- time.sleep(1) - for process in self.processes_to_kill: - process.send_signal(signal.SIGTERM) - process.wait() - if process.returncode != 0: - print("aborting due to valgrind error") - os._exit(-1) - else: - for process in self.processes_to_kill: - process.kill() - - # Clean up the Redis server. - services.cleanup() - - def test_fetch(self): - for _ in range(10): - # Create an object. - object_id1, memory_buffer1, metadata1 = create_object( - self.client1, 2000, 2000) - self.client1.fetch([object_id1]) - self.assertEqual(self.client1.contains(object_id1), True) - self.assertEqual(self.client2.contains(object_id1), False) - # Fetch the object from the other plasma manager. - # TODO(rkn): Right now we must wait for the object table to be - # updated. - while not self.client2.contains(object_id1): - self.client2.fetch([object_id1]) - # Compare the two buffers. - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id1, - memory_buffer=memory_buffer1, - metadata=metadata1) - - # Test that we can call fetch on object IDs that don't exist yet. - object_id2 = random_object_id() - self.client1.fetch([object_id2]) - self.assertEqual(self.client1.contains(object_id2), False) - memory_buffer2, metadata2 = create_object_with_id( - self.client2, object_id2, 2000, 2000) - # # Check that the object has been fetched. - # self.assertEqual(self.client1.contains(object_id2), True) - # Compare the two buffers. - # assert_get_object_equal(self, self.client1, self.client2, object_id2, - # memory_buffer=memory_buffer2, - # metadata=metadata2) - - # Test calling the same fetch request a bunch of times. - object_id3 = random_object_id() - self.assertEqual(self.client1.contains(object_id3), False) - self.assertEqual(self.client2.contains(object_id3), False) - for _ in range(10): - self.client1.fetch([object_id3]) - self.client2.fetch([object_id3]) - memory_buffer3, metadata3 = create_object_with_id( - self.client1, object_id3, 2000, 2000) - for _ in range(10): - self.client1.fetch([object_id3]) - self.client2.fetch([object_id3]) - # TODO(rkn): Right now we must wait for the object table to be updated. - while not self.client2.contains(object_id3): - self.client2.fetch([object_id3]) - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id3, - memory_buffer=memory_buffer3, - metadata=metadata3) - - def test_fetch_multiple(self): - for _ in range(20): - # Create two objects and a third fake one that doesn't exist. - object_id1, memory_buffer1, metadata1 = create_object( - self.client1, 2000, 2000) - missing_object_id = random_object_id() - object_id2, memory_buffer2, metadata2 = create_object( - self.client1, 2000, 2000) - object_ids = [object_id1, missing_object_id, object_id2] - # Fetch the objects from the other plasma store. The second object - # ID should timeout since it does not exist. - # TODO(rkn): Right now we must wait for the object table to be - # updated. - while ((not self.client2.contains(object_id1)) - or (not self.client2.contains(object_id2))): - self.client2.fetch(object_ids) - # Compare the buffers of the objects that do exist. - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id1, - memory_buffer=memory_buffer1, - metadata=metadata1) - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id2, - memory_buffer=memory_buffer2, - metadata=metadata2) - # Fetch in the other direction. The fake object still does not - # exist. 
- self.client1.fetch(object_ids) - assert_get_object_equal( - self, - self.client2, - self.client1, - object_id1, - memory_buffer=memory_buffer1, - metadata=metadata1) - assert_get_object_equal( - self, - self.client2, - self.client1, - object_id2, - memory_buffer=memory_buffer2, - metadata=metadata2) - - # Check that we can call fetch with duplicated object IDs. - object_id3 = random_object_id() - self.client1.fetch([object_id3, object_id3]) - object_id4, memory_buffer4, metadata4 = create_object( - self.client1, 2000, 2000) - time.sleep(0.1) - # TODO(rkn): Right now we must wait for the object table to be updated. - while not self.client2.contains(object_id4): - self.client2.fetch( - [object_id3, object_id3, object_id4, object_id4]) - assert_get_object_equal( - self, - self.client2, - self.client1, - object_id4, - memory_buffer=memory_buffer4, - metadata=metadata4) - - def test_wait(self): - # Test timeout. - obj_id0 = random_object_id() - self.client1.wait([obj_id0], timeout=100, num_returns=1) - # If we get here, the test worked. - - # Test wait if local objects available. - obj_id1 = random_object_id() - self.client1.create(obj_id1, 1000) - self.client1.seal(obj_id1) - ready, waiting = self.client1.wait( - [obj_id1], timeout=100, num_returns=1) - self.assertEqual(set(ready), {obj_id1}) - self.assertEqual(waiting, []) - - # Test wait if only one object available and only one object waited - # for. - obj_id2 = random_object_id() - self.client1.create(obj_id2, 1000) - # Don't seal. - ready, waiting = self.client1.wait( - [obj_id2, obj_id1], timeout=100, num_returns=1) - self.assertEqual(set(ready), {obj_id1}) - self.assertEqual(set(waiting), {obj_id2}) - - # Test wait if object is sealed later. - obj_id3 = random_object_id() - - def finish(): - self.client2.create(obj_id3, 1000) - self.client2.seal(obj_id3) - - t = threading.Timer(0.1, finish) - t.start() - ready, waiting = self.client1.wait( - [obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2) - self.assertEqual(set(ready), {obj_id1, obj_id3}) - self.assertEqual(set(waiting), {obj_id2}) - - # Test if the appropriate number of objects is shown if some objects - # are not ready. - ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1], 100, 3) - self.assertEqual(set(ready), {obj_id1, obj_id3}) - self.assertEqual(set(waiting), {obj_id2}) - - # Don't forget to seal obj_id2. - self.client1.seal(obj_id2) - - # Test calling wait a bunch of times. - object_ids = [] - # TODO(rkn): Increasing n to 100 (or larger) will cause failures. The - # problem appears to be that the number of timers added to the manager - # event loop slow down the manager so much that some of the - # asynchronous Redis commands timeout triggering fatal failure - # callbacks. - n = 40 - for i in range(n * (n + 1) // 2): - if i % 2 == 0: - object_id, _, _ = create_object(self.client1, 200, 200) - else: - object_id, _, _ = create_object(self.client2, 200, 200) - object_ids.append(object_id) - # Try waiting for all of the object IDs on the first client. - waiting = object_ids - retrieved = [] - for i in range(1, n + 1): - ready, waiting = self.client1.wait( - waiting, timeout=1000, num_returns=i) - self.assertEqual(len(ready), i) - retrieved += ready - self.assertEqual(set(retrieved), set(object_ids)) - ready, waiting = self.client1.wait( - object_ids, timeout=1000, num_returns=len(object_ids)) - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) - # Try waiting for all of the object IDs on the second client. 
- waiting = object_ids - retrieved = [] - for i in range(1, n + 1): - ready, waiting = self.client2.wait( - waiting, timeout=1000, num_returns=i) - self.assertEqual(len(ready), i) - retrieved += ready - self.assertEqual(set(retrieved), set(object_ids)) - ready, waiting = self.client2.wait( - object_ids, timeout=1000, num_returns=len(object_ids)) - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) - - # Make sure that wait returns when the requested number of object IDs - # are available and does not wait for all object IDs to be available. - object_ids = [random_object_id() for _ in range(9)] + \ - [plasma.ObjectID(ray_constants.ID_SIZE * b'\x00')] - object_ids_perm = object_ids[:] - random.shuffle(object_ids_perm) - for i in range(10): - if i % 2 == 0: - create_object_with_id(self.client1, object_ids_perm[i], 2000, - 2000) - else: - create_object_with_id(self.client2, object_ids_perm[i], 2000, - 2000) - ready, waiting = self.client1.wait(object_ids, num_returns=(i + 1)) - self.assertEqual(set(ready), set(object_ids_perm[:(i + 1)])) - self.assertEqual(set(waiting), set(object_ids_perm[(i + 1):])) - - def test_transfer(self): - num_attempts = 100 - for _ in range(100): - # Create an object. - object_id1, memory_buffer1, metadata1 = create_object( - self.client1, 2000, 2000) - # Transfer the buffer to the the other Plasma store. There is a - # race condition on the create and transfer of the object, so keep - # trying until the object appears on the second Plasma store. - for i in range(num_attempts): - self.client1.transfer("127.0.0.1", self.port2, object_id1) - buff = self.client2.get_buffers( - [object_id1], timeout_ms=100)[0] - if buff is not None: - break - self.assertNotEqual(buff, None) - del buff - - # Compare the two buffers. - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id1, - memory_buffer=memory_buffer1, - metadata=metadata1) - # # Transfer the buffer again. - # self.client1.transfer("127.0.0.1", self.port2, object_id1) - # # Compare the two buffers. - # assert_get_object_equal(self, self.client1, self.client2, - # object_id1, - # memory_buffer=memory_buffer1, - # metadata=metadata1) - - # Create an object. - object_id2, memory_buffer2, metadata2 = create_object( - self.client2, 20000, 20000) - # Transfer the buffer to the the other Plasma store. There is a - # race condition on the create and transfer of the object, so keep - # trying until the object appears on the second Plasma store. - for i in range(num_attempts): - self.client2.transfer("127.0.0.1", self.port1, object_id2) - buff = self.client1.get_buffers( - [object_id2], timeout_ms=100)[0] - if buff is not None: - break - self.assertNotEqual(buff, None) - del buff - - # Compare the two buffers. - assert_get_object_equal( - self, - self.client1, - self.client2, - object_id2, - memory_buffer=memory_buffer2, - metadata=metadata2) - - def test_illegal_functionality(self): - # Create an object id string. - # object_id = random_object_id() - # Create a new buffer. - # memory_buffer = self.client1.create(object_id, 20000) - # This test is commented out because it currently fails. - # # Transferring the buffer before sealing it should fail. - # self.assertRaises(Exception, - # lambda : self.manager1.transfer(1, object_id)) - pass - - def test_stresstest(self): - a = time.time() - object_ids = [] - for i in range(10000): # TODO(pcm): increase this to 100000. 
- object_id = random_object_id() - object_ids.append(object_id) - self.client1.create(object_id, 1) - self.client1.seal(object_id) - for object_id in object_ids: - self.client1.transfer("127.0.0.1", self.port2, object_id) - b = time.time() - a - - print("it took", b, "seconds to put and transfer the objects") - - -class TestPlasmaManagerRecovery(unittest.TestCase): - def setUp(self): - # Start a Plasma store. - self.store_name, self.p2 = start_plasma_store( - use_valgrind=USE_VALGRIND) - # Start a Redis server. - self.redis_address, _ = services.start_redis( - "127.0.0.1", use_raylet=False) - # Start a PlasmaManagers. - manager_name, self.p3, self.port1 = ray.plasma.start_plasma_manager( - self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) - # Connect a PlasmaClient. - self.client = plasma.connect(self.store_name, manager_name, 64) - - # Store the processes that will be explicitly killed during tearDown so - # that a test case can remove ones that will be killed during the test. - # NOTE: The plasma managers must be killed before the plasma store - # since plasma store death will bring down the managers. - self.processes_to_kill = [self.p3, self.p2] - - def tearDown(self): - # Check that the processes are still alive. - for process in self.processes_to_kill: - self.assertEqual(process.poll(), None) - - # Kill the Plasma store and Plasma manager processes. - if USE_VALGRIND: - # Give processes opportunity to finish work. - time.sleep(1) - for process in self.processes_to_kill: - process.send_signal(signal.SIGTERM) - process.wait() - if process.returncode != 0: - print("aborting due to valgrind error") - os._exit(-1) - else: - for process in self.processes_to_kill: - process.kill() - - # Clean up the Redis server. - services.cleanup() - - def test_delayed_start(self): - num_objects = 10 - # Create some objects using one client. - object_ids = [random_object_id() for _ in range(num_objects)] - for i in range(10): - create_object_with_id(self.client, object_ids[i], 2000, 2000) - - # Wait until the objects have been sealed in the store. - ready, waiting = self.client.wait(object_ids, num_returns=num_objects) - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) - - # Start a second plasma manager attached to the same store. - manager_name, self.p5, self.port2 = ray.plasma.start_plasma_manager( - self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) - self.processes_to_kill = [self.p5] + self.processes_to_kill - - # Check that the second manager knows about existing objects. - client2 = plasma.connect(self.store_name, manager_name, 64) - ready, waiting = [], object_ids - while True: - ready, waiting = client2.wait( - object_ids, num_returns=num_objects, timeout=0) - if len(ready) == len(object_ids): - break - - self.assertEqual(set(ready), set(object_ids)) - self.assertEqual(waiting, []) - - -if __name__ == "__main__": - if len(sys.argv) > 1: - # Pop the argument so we don't mess with unittest's own argument - # parser. 
- if sys.argv[-1] == "valgrind": - arg = sys.argv.pop() - USE_VALGRIND = True - print("Using valgrind for tests") - unittest.main(verbosity=2) diff --git a/python/ray/plasma/utils.py b/python/ray/plasma/utils.py deleted file mode 100644 index 45feb0b1db58..000000000000 --- a/python/ray/plasma/utils.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import random - -import pyarrow.plasma as plasma -import ray.ray_constants as ray_constants - - -def random_object_id(): - return plasma.ObjectID(np.random.bytes(ray_constants.ID_SIZE)) - - -def generate_metadata(length): - metadata_buffer = bytearray(length) - if length > 0: - metadata_buffer[0] = random.randint(0, 255) - metadata_buffer[-1] = random.randint(0, 255) - for _ in range(100): - metadata_buffer[random.randint(0, length - 1)] = (random.randint( - 0, 255)) - return metadata_buffer - - -def write_to_data_buffer(buff, length): - array = np.frombuffer(buff, dtype="uint8") - if length > 0: - array[0] = random.randint(0, 255) - array[-1] = random.randint(0, 255) - for _ in range(100): - array[random.randint(0, length - 1)] = random.randint(0, 255) - - -def create_object_with_id(client, - object_id, - data_size, - metadata_size, - seal=True): - metadata = generate_metadata(metadata_size) - memory_buffer = client.create(object_id, data_size, metadata) - write_to_data_buffer(memory_buffer, data_size) - if seal: - client.seal(object_id) - return memory_buffer, metadata - - -def create_object(client, data_size, metadata_size, seal=True): - object_id = random_object_id() - memory_buffer, metadata = create_object_with_id( - client, object_id, data_size, metadata_size, seal=seal) - return object_id, memory_buffer, metadata diff --git a/python/ray/profiling.py b/python/ray/profiling.py index a16dd9d7ad95..42b02f8926be 100644 --- a/python/ray/profiling.py +++ b/python/ray/profiling.py @@ -59,17 +59,7 @@ def profile(event_type, extra_data=None, worker=None): """ if worker is None: worker = ray.worker.global_worker - if not worker.use_raylet: - # Log the event if this is a worker and not a driver, since the - # driver's event log never gets flushed. 
- if worker.mode == ray.WORKER_MODE: - return RayLogSpanNonRaylet( - worker.profiler, event_type, contents=extra_data) - else: - return NULL_LOG_SPAN - else: - return RayLogSpanRaylet( - worker.profiler, event_type, extra_data=extra_data) + return RayLogSpanRaylet(worker.profiler, event_type, extra_data=extra_data) class Profiler(object): @@ -124,87 +114,20 @@ def flush_profile_data(self): events = self.events self.events = [] - if not self.worker.use_raylet: - event_log_key = b"event_log:" + self.worker.worker_id - event_log_value = json.dumps(events) - self.worker.local_scheduler_client.log_event( - event_log_key, event_log_value, time.time()) + if self.worker.mode == ray.WORKER_MODE: + component_type = "worker" else: - if self.worker.mode == ray.WORKER_MODE: - component_type = "worker" - else: - component_type = "driver" + component_type = "driver" - self.worker.local_scheduler_client.push_profile_events( - component_type, ray.ObjectID(self.worker.worker_id), - self.worker.node_ip_address, events) + self.worker.local_scheduler_client.push_profile_events( + component_type, ray.ObjectID(self.worker.worker_id), + self.worker.node_ip_address, events) def add_event(self, event): with self.lock: self.events.append(event) -class RayLogSpanNonRaylet(object): - """An object used to enable logging a span of events with a with statement. - - Attributes: - event_type (str): The type of the event being logged. - contents: Additional information to log. - """ - - def __init__(self, profiler, event_type, contents=None): - """Initialize a RayLogSpanNonRaylet object.""" - self.profiler = profiler - self.event_type = event_type - self.contents = contents - - def _log(self, event_type, kind, contents=None): - """Log an event to the global state store. - - This adds the event to a buffer of events locally. The buffer can be - flushed and written to the global state store by calling - flush_profile_data(). - - Args: - event_type (str): The type of the event. - contents: More general data to store with the event. - kind (int): Either LOG_POINT, LOG_SPAN_START, or LOG_SPAN_END. This - is LOG_POINT if the event being logged happens at a single - point in time. It is LOG_SPAN_START if we are starting to log a - span of time, and it is LOG_SPAN_END if we are finishing - logging a span of time. - """ - # TODO(rkn): This code currently takes around half a microsecond. Since - # we call it tens of times per task, this adds up. We will need to redo - # the logging code, perhaps in C. - contents = {} if contents is None else contents - assert isinstance(contents, dict) - # Make sure all of the keys and values in the dictionary are strings. - contents = {str(k): str(v) for k, v in contents.items()} - self.profiler.add_event((time.time(), event_type, kind, contents)) - - def __enter__(self): - """Log the beginning of a span event.""" - self._log( - event_type=self.event_type, - contents=self.contents, - kind=LOG_SPAN_START) - - def __exit__(self, type, value, tb): - """Log the end of a span event. Log any exception that occurred.""" - if type is None: - self._log(event_type=self.event_type, kind=LOG_SPAN_END) - else: - self._log( - event_type=self.event_type, - contents={ - "type": str(type), - "value": value, - "traceback": traceback.format_exc() - }, - kind=LOG_SPAN_END) - - class RayLogSpanRaylet(object): """An object used to enable logging a span of events with a with statement. 
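With the non-raylet branch gone, ``ray.profiling.profile`` always hands back a ``RayLogSpanRaylet``, which per its docstring is meant to be used with a ``with`` statement. A minimal usage sketch — the event name and ``extra_data`` payload are illustrative, not taken from this patch:

.. code-block:: python

    import ray
    from ray import profiling

    ray.init()

    @ray.remote
    def work():
        # The span is buffered by the worker's Profiler and later pushed
        # to the raylet via push_profile_events (see flush_profile_data).
        with profiling.profile("custom_event", extra_data={"step": "one"}):
            return 1

    ray.get(work.remote())
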
diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index d62b57b5c1cf..a1d5e1a76543 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -5,7 +5,7 @@ import os -from ray.local_scheduler import ObjectID +from ray.raylet import ObjectID def env_integer(key, default): @@ -41,7 +41,6 @@ def env_integer(key, default): WORKER_CRASH_PUSH_ERROR = "worker_crash" WORKER_DIED_PUSH_ERROR = "worker_died" PUT_RECONSTRUCTION_PUSH_ERROR = "put_reconstruction" -HASH_MISMATCH_PUSH_ERROR = "object_hash_mismatch" INFEASIBLE_TASK_ERROR = "infeasible_task" REMOVED_NODE_ERROR = "node_removed" MONITOR_DIED_ERROR = "monitor_died" diff --git a/python/ray/local_scheduler/__init__.py b/python/ray/raylet/__init__.py similarity index 76% rename from python/ray/local_scheduler/__init__.py rename to python/ray/raylet/__init__.py index a469776f133b..8757f5974156 100644 --- a/python/ray/local_scheduler/__init__.py +++ b/python/ray/raylet/__init__.py @@ -2,10 +2,9 @@ from __future__ import division from __future__ import print_function -from ray.core.src.local_scheduler.liblocal_scheduler_library_python import ( +from ray.core.src.ray.raylet.liblocal_scheduler_library_python import ( Task, LocalSchedulerClient, ObjectID, check_simple_value, compute_task_id, task_from_string, task_to_string, _config, common_error) -from .local_scheduler_services import start_local_scheduler __all__ = [ "Task", "LocalSchedulerClient", "ObjectID", "check_simple_value", diff --git a/python/ray/rllib/utils/actors.py b/python/ray/rllib/utils/actors.py index 487c3595eac5..1e19b703dba6 100644 --- a/python/ray/rllib/utils/actors.py +++ b/python/ray/rllib/utils/actors.py @@ -39,11 +39,8 @@ def completed_prefetch(self): for worker, obj_id in self.completed(): plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id()) - if not ray.global_state.use_raylet: - ray.worker.global_worker.plasma_client.fetch([plasma_id]) - else: - (ray.worker.global_worker.local_scheduler_client. - reconstruct_objects([obj_id], True)) + (ray.worker.global_worker.local_scheduler_client. + reconstruct_objects([obj_id], True)) self._fetching.append((worker, obj_id)) remaining = [] diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 654fba7208d7..fceaabcd2db8 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -5,7 +5,6 @@ import click import json import logging -import os import subprocess import ray.services as services @@ -20,7 +19,7 @@ def check_no_existing_redis_clients(node_ip_address, redis_client): # The client table prefix must be kept in sync with the file - # "src/common/redis_module/ray_redis_module.cc" where it is defined. + # "src/ray/gcs/redis_module/ray_redis_module.cc" where it is defined. REDIS_CLIENT_TABLE_PREFIX = "CL:" client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX)) # Filter to clients on the same node and do some basic checking. 
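The hunk above only updates a path comment, but it is a useful reminder that the client table lives in Redis under the ``CL:`` prefix. A standalone sketch of the same prefix scan — the host and port are placeholders for a running head node's primary Redis shard:

.. code-block:: python

    import redis

    # Placeholder address of the primary Redis shard started by `ray start`.
    redis_client = redis.StrictRedis(host="127.0.0.1", port=6379)

    # "CL:" must stay in sync with the prefix defined in
    # src/ray/gcs/redis_module/ray_redis_module.cc.
    REDIS_CLIENT_TABLE_PREFIX = "CL:"
    client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
    print("{} client table entries".format(len(client_keys)))
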
@@ -167,11 +166,6 @@ def cli(logging_level, logging_format): required=False, type=str, help="the file that contains the autoscaling config") -@click.option( - "--use-raylet", - default=None, - type=bool, - help="use the raylet code path, this defaults to false") @click.option( "--no-redirect-worker-output", is_flag=True, @@ -198,31 +192,15 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, redis_max_clients, redis_password, redis_shard_ports, object_manager_port, object_store_memory, num_workers, num_cpus, num_gpus, resources, head, no_ui, block, plasma_directory, - huge_pages, autoscaling_config, use_raylet, - no_redirect_worker_output, no_redirect_output, - plasma_store_socket_name, raylet_socket_name, temp_dir): + huge_pages, autoscaling_config, no_redirect_worker_output, + no_redirect_output, plasma_store_socket_name, raylet_socket_name, + temp_dir): # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) if redis_address is not None: redis_address = services.address_to_ip(redis_address) - if use_raylet is None: - if os.environ.get("RAY_USE_XRAY") == "0": - # This environment variable is used in our testing setup. - logger.info("Detected environment variable 'RAY_USE_XRAY' with " - "value {}. This turns OFF xray.".format( - os.environ.get("RAY_USE_XRAY"))) - use_raylet = False - else: - use_raylet = True - - if not use_raylet and redis_password is not None: - raise Exception("Setting the 'redis-password' argument is not " - "supported in legacy Ray. To run Ray with " - "password-protected Redis ports, pass " - "the '--use-raylet' flag.") - try: resources = json.loads(resources) except Exception: @@ -290,7 +268,6 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, - use_raylet=use_raylet, plasma_store_socket_name=plasma_store_socket_name, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir) @@ -369,7 +346,6 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, resources=resources, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet, plasma_store_socket_name=plasma_store_socket_name, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir) @@ -387,11 +363,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, @cli.command() def stop(): subprocess.call( - [ - "killall global_scheduler plasma_store_server plasma_manager " - "local_scheduler raylet raylet_monitor" - ], - shell=True) + ["killall plasma_store_server raylet raylet_monitor"], shell=True) # Find the PID of the monitor process and kill it. 
subprocess.call( diff --git a/python/ray/services.py b/python/ray/services.py index d57bc04291c1..d0ddd2b32f07 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -14,32 +14,25 @@ import sys import threading import time -from collections import OrderedDict, namedtuple +from collections import OrderedDict import redis import pyarrow # Ray modules import ray.ray_constants -import ray.global_scheduler as global_scheduler -import ray.local_scheduler import ray.plasma from ray.tempfile_services import ( get_ipython_notebook_path, get_logs_dir_path, get_raylet_socket_name, - get_temp_root, new_global_scheduler_log_file, new_local_scheduler_log_file, - new_log_monitor_log_file, new_monitor_log_file, - new_plasma_manager_log_file, new_plasma_store_log_file, - new_raylet_log_file, new_redis_log_file, new_webui_log_file, - new_worker_log_file, set_temp_root) + get_temp_root, new_log_monitor_log_file, new_monitor_log_file, + new_plasma_store_log_file, new_raylet_log_file, new_redis_log_file, + new_webui_log_file, set_temp_root) PROCESS_TYPE_MONITOR = "monitor" PROCESS_TYPE_LOG_MONITOR = "log_monitor" PROCESS_TYPE_WORKER = "worker" PROCESS_TYPE_RAYLET = "raylet" -PROCESS_TYPE_LOCAL_SCHEDULER = "local_scheduler" -PROCESS_TYPE_PLASMA_MANAGER = "plasma_manager" PROCESS_TYPE_PLASMA_STORE = "plasma_store" -PROCESS_TYPE_GLOBAL_SCHEDULER = "global_scheduler" PROCESS_TYPE_REDIS_SERVER = "redis_server" PROCESS_TYPE_WEB_UI = "web_ui" @@ -51,23 +44,20 @@ all_processes = OrderedDict( [(PROCESS_TYPE_MONITOR, []), (PROCESS_TYPE_LOG_MONITOR, []), (PROCESS_TYPE_WORKER, []), (PROCESS_TYPE_RAYLET, []), - (PROCESS_TYPE_LOCAL_SCHEDULER, []), (PROCESS_TYPE_PLASMA_MANAGER, []), - (PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_GLOBAL_SCHEDULER, []), - (PROCESS_TYPE_REDIS_SERVER, []), (PROCESS_TYPE_WEB_UI, [])], ) + (PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_REDIS_SERVER, []), + (PROCESS_TYPE_WEB_UI, [])], ) # True if processes are run in the valgrind profiler. RUN_RAYLET_PROFILER = False -RUN_LOCAL_SCHEDULER_PROFILER = False -RUN_PLASMA_MANAGER_PROFILER = False RUN_PLASMA_STORE_PROFILER = False # Location of the redis server and module. REDIS_EXECUTABLE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/common/thirdparty/redis/src/redis-server") + "core/src/ray/thirdparty/redis/src/redis-server") REDIS_MODULE = os.path.join( os.path.abspath(os.path.dirname(__file__)), - "core/src/common/redis_module/libray_redis_module.so") + "core/src/ray/gcs/redis_module/libray_redis_module.so") # Location of the credis server and modules. # credis will be enabled if the environment variable RAY_USE_NEW_GCS is set. @@ -88,14 +78,6 @@ RAYLET_EXECUTABLE = os.path.join( os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet") -# ObjectStoreAddress tuples contain all information necessary to connect to an -# object store. The fields are: -# - name: The socket name for the object store -# - manager_name: The socket name for the object store manager -# - manager_port: The Internet port that the object store manager listens on -ObjectStoreAddress = namedtuple("ObjectStoreAddress", - ["name", "manager_name", "manager_port"]) - # Logger for this module. It should be configured at the entry point # into the program using Ray. Ray configures it by default automatically # using logging.basicConfig in its entry/init points. @@ -136,10 +118,7 @@ def kill_process(p): if p.poll() is not None: # The process has already terminated. 
return True - if any([ - RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER, - RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER - ]): + if any([RUN_RAYLET_PROFILER, RUN_PLASMA_STORE_PROFILER]): # Give process signal to write profiler data. os.kill(p.pid, signal.SIGINT) # Wait for profiling data to be written. @@ -430,7 +409,6 @@ def start_redis(node_ip_address, redis_shard_ports=None, num_redis_shards=1, redis_max_clients=None, - use_raylet=True, redirect_output=False, redirect_worker_output=False, cleanup=True, @@ -450,7 +428,6 @@ def start_redis(node_ip_address, shard. redis_max_clients: If this is provided, Ray will attempt to configure Redis with this maxclients number. - use_raylet: True if the new raylet code path should be used. redirect_output (bool): True if output should be redirected to a file and false otherwise. redirect_worker_output (bool): True if worker output should be @@ -515,12 +492,6 @@ def start_redis(node_ip_address, port = assigned_port redis_address = address(node_ip_address, port) - redis_client = redis.StrictRedis( - host=node_ip_address, port=port, password=password) - - # Store whether we're using the raylet code path or not. - redis_client.set("UseRaylet", 1 if use_raylet else 0) - # Register the number of Redis shards in the primary shard, so that clients # know how many redis shards to expect under RedisShards. primary_redis_client = redis.StrictRedis( @@ -762,40 +733,6 @@ def start_log_monitor(redis_address, password=redis_password) -def start_global_scheduler(redis_address, - node_ip_address, - stdout_file=None, - stderr_file=None, - cleanup=True, - redis_password=None): - """Start a global scheduler process. - - Args: - redis_address (str): The address of the Redis instance. - node_ip_address: The IP address of the node that this scheduler will - run on. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by services.cleanup() when the - Python process that imported services exits. - redis_password (str): The password of the redis server. - """ - p = global_scheduler.start_global_scheduler( - redis_address, - node_ip_address, - stdout_file=stdout_file, - stderr_file=stderr_file) - if cleanup: - all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p) - record_log_files_in_redis( - redis_address, - node_ip_address, [stdout_file, stderr_file], - password=redis_password) - - def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): """Start a UI process. @@ -856,13 +793,11 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): return webui_url -def check_and_update_resources(resources, use_raylet): +def check_and_update_resources(resources): """Sanity check a resource dictionary and add sensible defaults. Args: resources: A dictionary mapping resource names to resource quantities. - use_raylet: True if we are using the raylet code path and false - otherwise. Returns: A new resource dictionary. 
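A usage sketch for the simplified single-argument signature — the exact defaults that get filled in (e.g. detecting the machine's CPU count) are an assumption based on the docstring's "sensible defaults" and are not spelled out in this hunk:

.. code-block:: python

    from ray.services import check_and_update_resources

    # Sketch: normalize a user-supplied resource dictionary before the
    # raylet is launched. Missing standard resources are given defaults;
    # the next hunk shows the whole-number and MAX_RESOURCE_QUANTITY
    # checks applied to every quantity.
    resources = check_and_update_resources({"CPU": 4, "Custom": 2})
    print(resources)
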
@@ -901,79 +836,13 @@ def check_and_update_resources(resources, use_raylet): and not resource_quantity.is_integer()): raise ValueError("Resource quantities must all be whole numbers.") - if (use_raylet and - resource_quantity > ray.ray_constants.MAX_RESOURCE_QUANTITY): + if resource_quantity > ray.ray_constants.MAX_RESOURCE_QUANTITY: raise ValueError("Resource quantities must be at most {}.".format( ray.ray_constants.MAX_RESOURCE_QUANTITY)) return resources -def start_local_scheduler(redis_address, - node_ip_address, - plasma_store_name, - plasma_manager_name, - worker_path, - plasma_address=None, - stdout_file=None, - stderr_file=None, - cleanup=True, - resources=None, - num_workers=0, - redis_password=None): - """Start a local scheduler process. - - Args: - redis_address (str): The address of the Redis instance. - node_ip_address (str): The IP address of the node that this local - scheduler is running on. - plasma_store_name (str): The name of the plasma store socket to connect - to. - plasma_manager_name (str): The name of the plasma manager socket to - connect to. - worker_path (str): The path of the script to use when the local - scheduler starts up new workers. - stdout_file: A file handle opened for writing to redirect stdout to. If - no redirection should happen, then this should be None. - stderr_file: A file handle opened for writing to redirect stderr to. If - no redirection should happen, then this should be None. - cleanup (bool): True if using Ray in local mode. If cleanup is true, - then this process will be killed by serices.cleanup() when the - Python process that imported services exits. - resources: A dictionary mapping the name of a resource to the available - quantity of that resource. - num_workers (int): The number of workers that the local scheduler - should start. - redis_password (str): The password of the redis server. - - Return: - The name of the local scheduler socket. - """ - resources = check_and_update_resources(resources, False) - - logger.info("Starting local scheduler with the following resources: {}." - .format(resources)) - local_scheduler_name, p = ray.local_scheduler.start_local_scheduler( - plasma_store_name, - plasma_manager_name, - worker_path=worker_path, - node_ip_address=node_ip_address, - redis_address=redis_address, - plasma_address=plasma_address, - use_profiler=RUN_LOCAL_SCHEDULER_PROFILER, - stdout_file=stdout_file, - stderr_file=stderr_file, - static_resources=resources, - num_workers=num_workers) - if cleanup: - all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p) - record_log_files_in_redis( - redis_address, - node_ip_address, [stdout_file, stderr_file], - password=redis_password) - return local_scheduler_name - - def start_raylet(redis_address, node_ip_address, raylet_name, @@ -1017,7 +886,7 @@ def start_raylet(redis_address, if use_valgrind and use_profiler: raise Exception("Cannot use valgrind and profiler at the same time.") - static_resources = check_and_update_resources(resources, True) + static_resources = check_and_update_resources(resources) # Limit the number of workers that can be started in parallel by the # raylet. However, make sure it is at least 1. 
@@ -1093,13 +962,10 @@ def start_plasma_store(node_ip_address, object_manager_port=None, store_stdout_file=None, store_stderr_file=None, - manager_stdout_file=None, - manager_stderr_file=None, objstore_memory=None, cleanup=True, plasma_directory=None, huge_pages=False, - use_raylet=True, plasma_store_socket_name=None, redis_password=None): """This method starts an object store process. @@ -1114,12 +980,6 @@ def start_plasma_store(node_ip_address, to. If no redirection should happen, then this should be None. store_stderr_file: A file handle opened for writing to redirect stderr to. If no redirection should happen, then this should be None. - manager_stdout_file: A file handle opened for writing to redirect - stdout to. If no redirection should happen, then this should be - None. - manager_stderr_file: A file handle opened for writing to redirect - stderr to. If no redirection should happen, then this should be - None. objstore_memory: The amount of memory (in bytes) to start the object store with. cleanup (bool): True if using Ray in local mode. If cleanup is true, @@ -1129,12 +989,10 @@ def start_plasma_store(node_ip_address, be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. - use_raylet: True if the new raylet code path should be used. redis_password (str): The password of the redis server. Return: - A tuple of the Plasma store socket name, the Plasma manager socket - name, and the plasma manager port. + The Plasma store socket name. """ if objstore_memory is None: # Compute a fraction of the system memory for the Plasma store to use. @@ -1177,32 +1035,6 @@ def start_plasma_store(node_ip_address, plasma_directory=plasma_directory, huge_pages=huge_pages, socket_name=plasma_store_socket_name) - # Start the plasma manager. - if not use_raylet: - if object_manager_port is not None: - (plasma_manager_name, p2, - plasma_manager_port) = ray.plasma.start_plasma_manager( - plasma_store_name, - redis_address, - plasma_manager_port=object_manager_port, - node_ip_address=node_ip_address, - num_retries=1, - run_profiler=RUN_PLASMA_MANAGER_PROFILER, - stdout_file=manager_stdout_file, - stderr_file=manager_stderr_file) - assert plasma_manager_port == object_manager_port - else: - (plasma_manager_name, p2, - plasma_manager_port) = ray.plasma.start_plasma_manager( - plasma_store_name, - redis_address, - node_ip_address=node_ip_address, - run_profiler=RUN_PLASMA_MANAGER_PROFILER, - stdout_file=manager_stdout_file, - stderr_file=manager_stderr_file) - else: - plasma_manager_port = None - plasma_manager_name = None if cleanup: all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1) @@ -1210,19 +1042,12 @@ def start_plasma_store(node_ip_address, redis_address, node_ip_address, [store_stdout_file, store_stderr_file], password=redis_password) - if not use_raylet: - if cleanup: - all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2) - record_log_files_in_redis(redis_address, node_ip_address, - [manager_stdout_file, manager_stderr_file]) - return ObjectStoreAddress(plasma_store_name, plasma_manager_name, - plasma_manager_port) + return plasma_store_name def start_worker(node_ip_address, object_store_name, - object_store_manager_name, local_scheduler_name, redis_address, worker_path, @@ -1235,7 +1060,6 @@ def start_worker(node_ip_address, node_ip_address (str): The IP address of the node that this worker is running on. object_store_name (str): The name of the object store. 
- object_store_manager_name (str): The name of the object store manager. local_scheduler_name (str): The name of the local scheduler. redis_address (str): The address that the Redis server is listening on. worker_path (str): The path of the source code which the worker process @@ -1253,7 +1077,6 @@ def start_worker(node_ip_address, sys.executable, "-u", worker_path, "--node-ip-address=" + node_ip_address, "--object-store-name=" + object_store_name, - "--object-store-manager-name=" + object_store_manager_name, "--local-scheduler-name=" + local_scheduler_name, "--redis-address=" + str(redis_address), "--temp-dir=" + get_temp_root() @@ -1349,7 +1172,6 @@ def start_ray_processes(address_info=None, cleanup=True, redirect_worker_output=False, redirect_output=False, - include_global_scheduler=False, include_log_monitor=False, include_webui=False, start_workers_from_local_scheduler=True, @@ -1357,7 +1179,6 @@ def start_ray_processes(address_info=None, plasma_directory=None, huge_pages=False, autoscaling_config=None, - use_raylet=True, plasma_store_socket_name=None, raylet_socket_name=None, temp_dir=None): @@ -1398,8 +1219,6 @@ def start_ray_processes(address_info=None, processes should be redirected to files. redirect_output (bool): True if stdout and stderr for non-worker processes should be redirected to files and false otherwise. - include_global_scheduler (bool): If include_global_scheduler is True, - then start a global scheduler process. include_log_monitor (bool): If True, then start a log monitor to monitor the log files for all processes on this node and push their contents to Redis. @@ -1415,7 +1234,6 @@ def start_ray_processes(address_info=None, huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. autoscaling_config: path to autoscaling config file. - use_raylet: True if the new raylet code path should be used. plasma_store_socket_name (str): If provided, it will specify the socket name used by the plasma store. raylet_socket_name (str): If provided, it will specify the socket path @@ -1469,7 +1287,6 @@ def start_ray_processes(address_info=None, redis_shard_ports=redis_shard_ports, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, - use_raylet=use_raylet, redirect_output=True, redirect_worker_output=redirect_worker_output, cleanup=cleanup, @@ -1488,13 +1305,12 @@ def start_ray_processes(address_info=None, cleanup=cleanup, autoscaling_config=autoscaling_config, redis_password=redis_password) - if use_raylet: - start_raylet_monitor( - redis_address, - stdout_file=monitor_stdout_file, - stderr_file=monitor_stderr_file, - cleanup=cleanup, - redis_password=redis_password) + start_raylet_monitor( + redis_address, + stdout_file=monitor_stdout_file, + stderr_file=monitor_stderr_file, + cleanup=cleanup, + redis_password=redis_password) if redis_shards == []: # Get redis shards from primary redis instance. redis_ip_address, redis_port = redis_address.split(":") @@ -1516,25 +1332,10 @@ def start_ray_processes(address_info=None, cleanup=cleanup, redis_password=redis_password) - # Start the global scheduler, if necessary. 
- if include_global_scheduler and not use_raylet: - global_scheduler_stdout_file, global_scheduler_stderr_file = ( - new_global_scheduler_log_file(redirect_output)) - start_global_scheduler( - redis_address, - node_ip_address, - stdout_file=global_scheduler_stdout_file, - stderr_file=global_scheduler_stderr_file, - cleanup=cleanup, - redis_password=redis_password) - # Initialize with existing services. if "object_store_addresses" not in address_info: address_info["object_store_addresses"] = [] object_store_addresses = address_info["object_store_addresses"] - if "local_scheduler_socket_names" not in address_info: - address_info["local_scheduler_socket_names"] = [] - local_scheduler_socket_names = address_info["local_scheduler_socket_names"] if "raylet_socket_names" not in address_info: address_info["raylet_socket_names"] = [] raylet_socket_names = address_info["raylet_socket_names"] @@ -1552,114 +1353,37 @@ def start_ray_processes(address_info=None, plasma_store_stdout_file, plasma_store_stderr_file = ( new_plasma_store_log_file(i, redirect_output)) - # If we use raylet, plasma manager won't be started and we don't need - # to create temp files for them. - plasma_manager_stdout_file, plasma_manager_stderr_file = ( - new_plasma_manager_log_file(i, redirect_output and not use_raylet)) - object_store_address = start_plasma_store( node_ip_address, redis_address, - object_manager_port=object_manager_ports[i], store_stdout_file=plasma_store_stdout_file, store_stderr_file=plasma_store_stderr_file, - manager_stdout_file=plasma_manager_stdout_file, - manager_stderr_file=plasma_manager_stderr_file, objstore_memory=object_store_memory, cleanup=cleanup, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet, plasma_store_socket_name=plasma_store_socket_name, redis_password=redis_password) object_store_addresses.append(object_store_address) time.sleep(0.1) - if not use_raylet: - # Start any local schedulers that do not yet exist. - for i in range( - len(local_scheduler_socket_names), num_local_schedulers): - # Connect the local scheduler to the object store at the same - # index. - object_store_address = object_store_addresses[i] - plasma_address = "{}:{}".format(node_ip_address, - object_store_address.manager_port) - # Determine how many workers this local scheduler should start. - if start_workers_from_local_scheduler: - num_local_scheduler_workers = workers_per_local_scheduler[i] - workers_per_local_scheduler[i] = 0 - else: - # If we're starting the workers from Python, the local - # scheduler should not start any workers. - num_local_scheduler_workers = 0 - # Start the local scheduler. Note that if we do not wish to - # redirect the worker output, then we cannot redirect the local - # scheduler output. - local_scheduler_stdout_file, local_scheduler_stderr_file = ( - new_local_scheduler_log_file( - i, redirect_output=redirect_worker_output)) - local_scheduler_name = start_local_scheduler( + # Start any raylets that do not exist yet. 
+ for i in range(len(raylet_socket_names), num_local_schedulers): + raylet_stdout_file, raylet_stderr_file = new_raylet_log_file( + i, redirect_output=redirect_worker_output) + address_info["raylet_socket_names"].append( + start_raylet( redis_address, node_ip_address, - object_store_address.name, - object_store_address.manager_name, + raylet_socket_name or get_raylet_socket_name(), + object_store_addresses[i], worker_path, - plasma_address=plasma_address, - stdout_file=local_scheduler_stdout_file, - stderr_file=local_scheduler_stderr_file, - cleanup=cleanup, resources=resources[i], - num_workers=num_local_scheduler_workers, - redis_password=redis_password) - local_scheduler_socket_names.append(local_scheduler_name) - - # Make sure that we have exactly num_local_schedulers instances of - # object stores and local schedulers. - assert len(object_store_addresses) == num_local_schedulers - assert len(local_scheduler_socket_names) == num_local_schedulers - - else: - # Start any raylets that do not exist yet. - for i in range(len(raylet_socket_names), num_local_schedulers): - raylet_stdout_file, raylet_stderr_file = new_raylet_log_file( - i, redirect_output=redirect_worker_output) - address_info["raylet_socket_names"].append( - start_raylet( - redis_address, - node_ip_address, - raylet_socket_name or get_raylet_socket_name(), - object_store_addresses[i].name, - worker_path, - resources=resources[i], - num_workers=workers_per_local_scheduler[i], - stdout_file=raylet_stdout_file, - stderr_file=raylet_stderr_file, - cleanup=cleanup, - redis_password=redis_password)) - - if not use_raylet: - # Start any workers that the local scheduler has not already started. - for i, num_local_scheduler_workers in enumerate( - workers_per_local_scheduler): - object_store_address = object_store_addresses[i] - local_scheduler_name = local_scheduler_socket_names[i] - for j in range(num_local_scheduler_workers): - worker_stdout_file, worker_stderr_file = new_worker_log_file( - i, j, redirect_output) - start_worker( - node_ip_address, - object_store_address.name, - object_store_address.manager_name, - local_scheduler_name, - redis_address, - worker_path, - stdout_file=worker_stdout_file, - stderr_file=worker_stderr_file, - cleanup=cleanup) - workers_per_local_scheduler[i] -= 1 - - # Make sure that we've started all the workers. - assert (sum(workers_per_local_scheduler) == 0) + num_workers=workers_per_local_scheduler[i], + stdout_file=raylet_stdout_file, + stderr_file=raylet_stderr_file, + cleanup=cleanup, + redis_password=redis_password)) # Try to start the web UI. if include_webui: @@ -1689,7 +1413,6 @@ def start_ray_node(node_ip_address, resources=None, plasma_directory=None, huge_pages=False, - use_raylet=True, plasma_store_socket_name=None, raylet_socket_name=None, temp_dir=None): @@ -1727,7 +1450,6 @@ def start_ray_node(node_ip_address, be created. huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. - use_raylet: True if the new raylet code path should be used. plasma_store_socket_name (str): If provided, it will specify the socket name used by the plasma store. 
raylet_socket_name (str): If provided, it will specify the socket path @@ -1758,7 +1480,6 @@ def start_ray_node(node_ip_address, resources=resources, plasma_directory=plasma_directory, huge_pages=huge_pages, - use_raylet=use_raylet, plasma_store_socket_name=plasma_store_socket_name, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir) @@ -1784,7 +1505,6 @@ def start_ray_head(address_info=None, plasma_directory=None, huge_pages=False, autoscaling_config=None, - use_raylet=True, plasma_store_socket_name=None, raylet_socket_name=None, temp_dir=None): @@ -1836,7 +1556,6 @@ def start_ray_head(address_info=None, huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. autoscaling_config: path to autoscaling config file. - use_raylet: True if the new raylet code path should be used. plasma_store_socket_name (str): If provided, it will specify the socket name used by the plasma store. raylet_socket_name (str): If provided, it will specify the socket path @@ -1861,7 +1580,6 @@ def start_ray_head(address_info=None, cleanup=cleanup, redirect_worker_output=redirect_worker_output, redirect_output=redirect_output, - include_global_scheduler=True, include_log_monitor=True, include_webui=include_webui, start_workers_from_local_scheduler=start_workers_from_local_scheduler, @@ -1872,7 +1590,6 @@ def start_ray_head(address_info=None, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, - use_raylet=use_raylet, plasma_store_socket_name=plasma_store_socket_name, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir) diff --git a/python/ray/tempfile_services.py b/python/ray/tempfile_services.py index 76ec7c1d7ddb..bf8b6c21958f 100644 --- a/python/ray/tempfile_services.py +++ b/python/ray/tempfile_services.py @@ -117,27 +117,6 @@ def get_object_store_socket_name(): return make_inc_temp(prefix="plasma_store", directory_name=sockets_dir) -def get_plasma_manager_socket_name(): - """Get a socket name for plasma manager.""" - sockets_dir = get_sockets_dir_path() - return make_inc_temp(prefix="plasma_manager", directory_name=sockets_dir) - - -def get_local_scheduler_socket_name(suffix=""): - """Get a socket name for local scheduler. - - This function could be unsafe. The socket name may - refer to a file that did not exist at some point, but by the time - you get around to creating it, someone else may have beaten you to - the punch. - """ - sockets_dir = get_sockets_dir_path() - raylet_socket_name = make_inc_temp( - prefix="scheduler", directory_name=sockets_dir, suffix=suffix) - - return raylet_socket_name - - def get_ipython_notebook_path(port): """Get a new ipython notebook path""" @@ -211,17 +190,6 @@ def new_raylet_log_file(local_scheduler_index, redirect_output): return raylet_stdout_file, raylet_stderr_file -def new_local_scheduler_log_file(local_scheduler_index, redirect_output): - """Create new logging files for local scheduler. - - It is only used in non-raylet versions. 
- """ - local_scheduler_stdout_file, local_scheduler_stderr_file = (new_log_files( - "local_scheduler_{}".format(local_scheduler_index), - redirect_output=redirect_output)) - return local_scheduler_stdout_file, local_scheduler_stderr_file - - def new_webui_log_file(): """Create new logging files for web ui.""" ui_stdout_file, ui_stderr_file = new_log_files( @@ -229,17 +197,6 @@ def new_webui_log_file(): return ui_stdout_file, ui_stderr_file -def new_worker_log_file(local_scheduler_index, worker_index, redirect_output): - """Create new logging files for workers with local scheduler index. - - It is only used in non-raylet versions. - """ - worker_stdout_file, worker_stderr_file = new_log_files( - "worker_{}_{}".format(local_scheduler_index, worker_index), - redirect_output) - return worker_stdout_file, worker_stderr_file - - def new_worker_redirected_log_file(worker_id): """Create new logging files for workers to redirect its output.""" worker_stdout_file, worker_stderr_file = (new_log_files( @@ -254,16 +211,6 @@ def new_log_monitor_log_file(): return log_monitor_stdout_file, log_monitor_stderr_file -def new_global_scheduler_log_file(redirect_output): - """Create new logging files for the new global scheduler. - - It is only used in non-raylet versions. - """ - global_scheduler_stdout_file, global_scheduler_stderr_file = ( - new_log_files("global_scheduler", redirect_output)) - return global_scheduler_stdout_file, global_scheduler_stderr_file - - def new_plasma_store_log_file(local_scheduler_index, redirect_output): """Create new logging files for the plasma store.""" plasma_store_stdout_file, plasma_store_stderr_file = new_log_files( @@ -271,13 +218,6 @@ def new_plasma_store_log_file(local_scheduler_index, redirect_output): return plasma_store_stdout_file, plasma_store_stderr_file -def new_plasma_manager_log_file(local_scheduler_index, redirect_output): - """Create new logging files for the plasma manager.""" - plasma_manager_stdout_file, plasma_manager_stderr_file = new_log_files( - "plasma_manager_{}".format(local_scheduler_index), redirect_output) - return plasma_manager_stdout_file, plasma_manager_stderr_file - - def new_monitor_log_file(redirect_output): """Create new logging files for the monitor.""" monitor_stdout_file, monitor_stderr_file = new_log_files( diff --git a/python/ray/test/cluster_utils.py b/python/ray/test/cluster_utils.py index 7e7a82d67c6e..c4cf2b801623 100644 --- a/python/ray/test/cluster_utils.py +++ b/python/ray/test/cluster_utils.py @@ -44,7 +44,6 @@ def add_node(self, **override_kwargs): All nodes are by default started with the following settings: cleanup=True, - use_raylet=True, resources={"CPU": 1}, object_store_memory=100 * (2**20) # 100 MB @@ -55,12 +54,13 @@ def add_node(self, **override_kwargs): Returns: Node object of the added Ray node. """ - node_kwargs = dict( - cleanup=True, - use_raylet=True, - resources={"CPU": 1}, - object_store_memory=100 * (2**20) # 100 MB - ) + node_kwargs = { + "cleanup": True, + "resources": { + "CPU": 1 + }, + "object_store_memory": 100 * (2**20) # 100 MB + } node_kwargs.update(override_kwargs) if self.head_node is None: @@ -179,7 +179,9 @@ def kill_all_processes(self): for process_name, process_list in self.process_dict.items(): logger.info("Killing all {}(s)".format(process_name)) for process in process_list: - process.kill() + # Kill the process if it is still alive. 
+                if process.poll() is None:
+                    process.kill()
 
         for process_name, process_list in self.process_dict.items():
             logger.info("Waiting for all {}(s)".format(process_name))
diff --git a/python/ray/test/test_ray_init.py b/python/ray/test/test_ray_init.py
index a64dd4a94256..3b2beaba1f65 100644
--- a/python/ray/test/test_ray_init.py
+++ b/python/ray/test/test_ray_init.py
@@ -28,9 +28,6 @@ class TestRedisPassword(object):
     @pytest.mark.skipif(
         os.environ.get("RAY_USE_NEW_GCS") == "on",
         reason="New GCS API doesn't support Redis authentication yet.")
-    @pytest.mark.skipif(
-        os.environ.get("RAY_USE_XRAY") == "0",
-        reason="Redis authentication is not supported in legacy Ray.")
     def test_redis_password(self, password, shutdown_only):
         # Workaround for https://github.com/ray-project/ray/issues/3045
         @ray.remote
diff --git a/python/ray/test/test_utils.py b/python/ray/test/test_utils.py
index 6e18cd4393f6..a3614650e97b 100644
--- a/python/ray/test/test_utils.py
+++ b/python/ray/test/test_utils.py
@@ -35,22 +35,11 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20):
         client_table = ray.global_state.client_table()
         num_ready_nodes = len(client_table)
         if num_ready_nodes == num_nodes:
-            ready = True
             # Check that for each node, a local scheduler and a plasma manager
             # are present.
-            if ray.global_state.use_raylet:
-                # In raylet mode, this is a list of map.
-                # The GCS info will appear as a whole instead of part by part.
-                return
-            else:
-                for ip_address, clients in client_table.items():
-                    client_types = [client["ClientType"] for client in clients]
-                    if "local_scheduler" not in client_types:
-                        ready = False
-                    if "plasma_manager" not in client_types:
-                        ready = False
-            if ready:
-                return
+            # In raylet mode, the client table is a list of maps and each
+            # node's GCS info appears as a whole, so there is nothing to check.
+            return
         if num_ready_nodes > num_nodes:
             # Too many nodes have joined. Something must be wrong.
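An aside on the ``kill_all_processes`` change above: checking ``poll()`` before ``kill()`` is the standard ``subprocess`` idiom for skipping children that have already exited. A self-contained sketch (the sleep command is only for illustration):

.. code-block:: python

    import subprocess

    proc = subprocess.Popen(["sleep", "60"])

    # poll() returns None while the child is still running; killing an
    # already-exited child can raise OSError on some platforms.
    if proc.poll() is None:
        proc.kill()
    proc.wait()  # reap the child so it does not linger as a zombie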
raise Exception("{} nodes have joined the cluster, but we were " diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index acbebb38b4ab..4a216a60d2be 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -213,20 +213,9 @@ def _return_resources(self, resources): assert self._committed_resources.gpu >= 0 def _update_avail_resources(self): - if ray.worker.global_worker.use_raylet: - # TODO(rliaw): Remove once raylet flag is swapped - resources = ray.global_state.cluster_resources() - num_cpus = resources["CPU"] - num_gpus = resources["GPU"] - else: - clients = ray.global_state.client_table() - local_schedulers = [ - entry for client in clients.values() for entry in client - if (entry['ClientType'] == 'local_scheduler' - and not entry['Deleted']) - ] - num_cpus = sum(ls['CPU'] for ls in local_schedulers) - num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers) + resources = ray.global_state.cluster_resources() + num_cpus = resources["CPU"] + num_gpus = resources["GPU"] self._avail_resources = Resources(int(num_cpus), int(num_gpus)) self._resources_initialized = True diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 450e96136fa2..3c9ae43e6a78 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -107,7 +107,7 @@ def default_resource_request(cls, config): return Resources(cpu=config["cpu"], gpu=config["gpu"]) def _train(self): - return dict(timesteps_this_iter=1, done=True) + return {"timesteps_this_iter": 1, "done": True} register_trainable("B", B) @@ -440,7 +440,7 @@ def _setup(self, config): self.state = {"hi": 1} def _train(self): - return dict(timesteps_this_iter=1, done=True) + return {"timesteps_this_iter": 1, "done": True} def _save(self, path): return self.state @@ -471,7 +471,7 @@ def _setup(self, config): def _train(self): self.state["iter"] += 1 - return dict(timesteps_this_iter=1, done=True) + return {"timesteps_this_iter": 1, "done": True} def _save(self, path): return self.state @@ -604,7 +604,7 @@ def train(config, reporter): class B(Trainable): def _train(self): - return dict(timesteps_this_iter=1, done=True) + return {"timesteps_this_iter": 1, "done": True} register_trainable("f1", train) trials = run_experiments({ @@ -624,7 +624,7 @@ def _train(self): def testCheckpointAtEnd(self): class train(Trainable): def _train(self): - return dict(timesteps_this_iter=1, done=True) + return {"timesteps_this_iter": 1, "done": True} def _save(self, path): return path @@ -887,7 +887,7 @@ def testExtraResources(self): self.assertEqual(trials[1].status, Trial.PENDING) def testFractionalGpus(self): - ray.init(num_cpus=4, num_gpus=1, use_raylet=True) + ray.init(num_cpus=4, num_gpus=1) runner = TrialRunner(BasicVariantGenerator()) kwargs = { "resources": Resources(cpu=1, gpu=0.5), diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index 691d25adbe97..9c047fd80043 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -28,7 +28,7 @@ def pin_in_object_store(obj): def get_pinned_object(pinned_id): """Retrieve a pinned object from the object store.""" - from ray.local_scheduler import ObjectID + from ray.raylet import ObjectID return _from_pinnable( ray.get( diff --git a/python/ray/utils.py b/python/ray/utils.py index 55f85c8ac519..83e2ae2f733f 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -15,11 +15,9 @@ import uuid import ray.gcs_utils -import ray.local_scheduler 
+import ray.raylet import ray.ray_constants as ray_constants -ERROR_KEY_PREFIX = b"Error:" - def _random_string(): id_hash = hashlib.sha1() @@ -70,22 +68,12 @@ def push_error_to_driver(worker, """ if driver_id is None: driver_id = ray_constants.NIL_JOB_ID.id() - error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string() data = {} if data is None else data - if not worker.use_raylet: - worker.redis_client.hmset(error_key, { - "type": error_type, - "message": message, - "data": data - }) - worker.redis_client.rpush("ErrorKeys", error_key) - else: - worker.local_scheduler_client.push_error( - ray.ObjectID(driver_id), error_type, message, time.time()) + worker.local_scheduler_client.push_error( + ray.ObjectID(driver_id), error_type, message, time.time()) def push_error_to_driver_through_redis(redis_client, - use_raylet, error_type, message, driver_id=None, @@ -99,8 +87,6 @@ def push_error_to_driver_through_redis(redis_client, Args: redis_client: The redis client to use. - use_raylet: True if we are using the Raylet code path and false - otherwise. error_type (str): The type of the error. message (str): The message that will be printed in the background on the driver. @@ -111,23 +97,14 @@ def push_error_to_driver_through_redis(redis_client, """ if driver_id is None: driver_id = ray_constants.NIL_JOB_ID.id() - error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string() data = {} if data is None else data - if not use_raylet: - redis_client.hmset(error_key, { - "type": error_type, - "message": message, - "data": data - }) - redis_client.rpush("ErrorKeys", error_key) - else: - # Do everything in Python and through the Python Redis client instead - # of through the raylet. - error_data = ray.gcs_utils.construct_error_message( - driver_id, error_type, message, time.time()) - redis_client.execute_command( - "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data) + # Do everything in Python and through the Python Redis client instead + # of through the raylet. + error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, + message, time.time()) + redis_client.execute_command( + "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, + ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data) def is_cython(obj): diff --git a/python/ray/worker.py b/python/ray/worker.py index 266b995f6130..0b042dc11472 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -27,14 +27,13 @@ import ray.services as services import ray.signature import ray.tempfile_services as tempfile_services -import ray.local_scheduler +import ray.raylet import ray.plasma import ray.ray_constants as ray_constants from ray import import_thread from ray import profiling from ray.function_manager import FunctionActorManager from ray.utils import ( - binary_to_hex, check_oversized_pickle, is_cython, random_string, @@ -56,14 +55,6 @@ NIL_ACTOR_HANDLE_ID = NIL_ID NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff" -# This must be kept in sync with the `error_types` array in -# common/state/error_table.h. -OBJECT_HASH_MISMATCH_ERROR_TYPE = b"object_hash_mismatch" -PUT_RECONSTRUCTION_ERROR_TYPE = b"put_reconstruction" - -# This must be kept in sync with the `scheduling_state` enum in common/task.h. -TASK_STATUS_RUNNING = 8 - # Default resource requirements for actors when no resource requirements are # specified. 
DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1 @@ -461,13 +452,9 @@ def get_object(self, object_ids): ] for i in range(0, len(object_ids), ray._config.worker_fetch_request_size()): - if not self.use_raylet: - self.plasma_client.fetch(plain_object_ids[i:( - i + ray._config.worker_fetch_request_size())]) - else: - self.local_scheduler_client.reconstruct_objects( - object_ids[i:( - i + ray._config.worker_fetch_request_size())], True) + self.local_scheduler_client.reconstruct_objects( + object_ids[i:(i + ray._config.worker_fetch_request_size())], + True) # Get the objects. We initially try to get the objects immediately. final_results = self.retrieve_and_deserialize(plain_object_ids, 0) @@ -497,25 +484,9 @@ def get_object(self, object_ids): ray._config.worker_fetch_request_size()) for i in range(0, len(object_ids_to_fetch), fetch_request_size): - if not self.use_raylet: - for unready_id in ray_object_ids_to_fetch[i:( - i + fetch_request_size)]: - (self.local_scheduler_client. - reconstruct_objects([unready_id], False)) - # Do another fetch for objects that aren't - # available locally yet, in case they were evicted - # since the last fetch. We divide the fetch into - # smaller fetches so as to not block the manager - # for a prolonged period of time in a single call. - # This is only necessary for legacy ray since - # reconstruction and fetch are implemented by - # different processes. - self.plasma_client.fetch(object_ids_to_fetch[i:( - i + fetch_request_size)]) - else: - self.local_scheduler_client.reconstruct_objects( - ray_object_ids_to_fetch[i:( - i + fetch_request_size)], False) + self.local_scheduler_client.reconstruct_objects( + ray_object_ids_to_fetch[i:( + i + fetch_request_size)], False) results = self.retrieve_and_deserialize( object_ids_to_fetch, max([ @@ -608,7 +579,7 @@ def submit_task(self, for arg in args: if isinstance(arg, ray.ObjectID): args_for_local_scheduler.append(arg) - elif ray.local_scheduler.check_simple_value(arg): + elif ray.raylet.check_simple_value(arg): args_for_local_scheduler.append(arg) else: args_for_local_scheduler.append(put(arg)) @@ -641,14 +612,13 @@ def submit_task(self, task_index = self.task_index self.task_index += 1 # Submit the task to local scheduler. - task = ray.local_scheduler.Task( + task = ray.raylet.Task( driver_id, ray.ObjectID( function_id.id()), args_for_local_scheduler, num_return_vals, self.current_task_id, task_index, actor_creation_id, actor_creation_dummy_object_id, actor_id, - actor_handle_id, actor_counter, is_actor_checkpoint_method, - execution_dependencies, resources, placement_resources, - self.use_raylet) + actor_handle_id, actor_counter, execution_dependencies, + resources, placement_resources) self.local_scheduler_client.submit(task) return task.returns() @@ -925,26 +895,13 @@ def _wait_for_and_process_task(self, task): # good to know where the system is hanging. with self.lock: function_name = execution_info.function_name - if not self.use_raylet: - extra_data = { - "function_name": function_name, - "task_id": task.task_id().hex(), - "worker_id": binary_to_hex(self.worker_id) - } - else: - extra_data = { - "name": function_name, - "task_id": task.task_id().hex() - } + extra_data = { + "name": function_name, + "task_id": task.task_id().hex() + } with profiling.profile("task", extra_data=extra_data, worker=self): self._process_task(task, execution_info) - # In the non-raylet code path, push all of the log events to the global - # state store. In the raylet code path, this is done periodically in a - # background thread. 
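Both branches of the fetch logic above (the deleted plasma path and the surviving ``reconstruct_objects`` path) slice their requests by ``ray._config.worker_fetch_request_size()`` so that no single call blocks on an arbitrarily large batch. The chunking pattern in isolation, with an invented ``fetch`` callable and chunk size standing in for the real ones:

.. code-block:: python

    def fetch_in_chunks(object_ids, fetch, chunk_size=10000):
        # Issue one bounded request per slice of the ID list.
        for i in range(0, len(object_ids), chunk_size):
            fetch(object_ids[i:i + chunk_size])

    def show(batch):
        print(len(batch))

    # 25000 IDs go out as batches of 10000, 10000, and 5000.
    fetch_in_chunks(list(range(25000)), show)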
- if not self.use_raylet: - self.profiler.flush_profile_data() - # Increase the task execution counter. self.function_actor_manager.increase_task_counter( driver_id, function_id.id()) @@ -998,13 +955,10 @@ def get_gpu_ids(): raise Exception("ray.get_gpu_ids() currently does not work in PYTHON " "MODE.") - if not global_worker.use_raylet: - assigned_ids = global_worker.local_scheduler_client.gpu_ids() - else: - all_resource_ids = global_worker.local_scheduler_client.resource_ids() - assigned_ids = [ - resource_id for resource_id, _ in all_resource_ids.get("GPU", []) - ] + all_resource_ids = global_worker.local_scheduler_client.resource_ids() + assigned_ids = [ + resource_id for resource_id, _ in all_resource_ids.get("GPU", []) + ] # If the user had already set CUDA_VISIBLE_DEVICES, then respect that (in # the sense that only GPU IDs that appear in CUDA_VISIBLE_DEVICES should be # returned). @@ -1019,17 +973,11 @@ def get_gpu_ids(): def get_resource_ids(): """Get the IDs of the resources that are available to the worker. - This function is only supported in the raylet code path. - Returns: A dictionary mapping the name of a resource to a list of pairs, where each pair consists of the ID of a resource and the fraction of that resource reserved for this worker. """ - if not global_worker.use_raylet: - raise Exception("ray.get_resource_ids() is only supported in the " - "raylet code path.") - if _mode() == LOCAL_MODE: raise Exception( "ray.get_resource_ids() currently does not work in PYTHON " @@ -1112,22 +1060,8 @@ def error_applies_to_driver(error_key, worker=global_worker): def error_info(worker=global_worker): """Return information about failed tasks.""" worker.check_connected() - if worker.use_raylet: - return (global_state.error_messages(job_id=worker.task_driver_id) + - global_state.error_messages(job_id=ray_constants.NIL_JOB_ID)) - error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) - errors = [] - for error_key in error_keys: - if error_applies_to_driver(error_key, worker=worker): - error_contents = worker.redis_client.hgetall(error_key) - error_contents = { - "type": ray.utils.decode(error_contents[b"type"]), - "message": ray.utils.decode(error_contents[b"message"]), - "data": ray.utils.decode(error_contents[b"data"]) - } - errors.append(error_contents) - - return errors + return (global_state.error_messages(job_id=worker.task_driver_id) + + global_state.error_messages(job_id=ray_constants.NIL_JOB_ID)) def _initialize_serialization(driver_id, worker=global_worker): @@ -1223,7 +1157,6 @@ def actor_handle_deserializer(serialized_obj): def get_address_info_from_redis_helper(redis_address, node_ip_address, - use_raylet=True, redis_password=None): redis_ip_address, redis_port = redis_address.split(":") # For this command to work, some other client (on the same machine as @@ -1231,118 +1164,50 @@ def get_address_info_from_redis_helper(redis_address, redis_client = redis.StrictRedis( host=redis_ip_address, port=int(redis_port), password=redis_password) - if not use_raylet: - # The client table prefix must be kept in sync with the file - # "src/common/redis_module/ray_redis_module.cc" where it is defined. - client_keys = redis_client.keys("{}*".format( - ray.gcs_utils.DB_CLIENT_PREFIX)) - # Filter to live clients on the same node and do some basic checking. - plasma_managers = [] - local_schedulers = [] - for key in client_keys: - info = redis_client.hgetall(key) - - # Ignore clients that were deleted. 
- deleted = info[b"deleted"] - deleted = bool(int(deleted)) - if deleted: - continue - - assert b"ray_client_id" in info - assert b"node_ip_address" in info - assert b"client_type" in info - client_node_ip_address = ray.utils.decode(info[b"node_ip_address"]) - if (client_node_ip_address == node_ip_address or - (client_node_ip_address == "127.0.0.1" - and redis_ip_address == ray.services.get_node_ip_address())): - if ray.utils.decode(info[b"client_type"]) == "plasma_manager": - plasma_managers.append(info) - elif (ray.utils.decode( - info[b"client_type"]) == "local_scheduler"): - local_schedulers.append(info) - # Make sure that we got at least one plasma manager and local - # scheduler. - assert len(plasma_managers) >= 1 - assert len(local_schedulers) >= 1 - # Build the address information. - object_store_addresses = [] - for manager in plasma_managers: - address = ray.utils.decode(manager[b"manager_address"]) - port = services.get_port(address) - object_store_addresses.append( - services.ObjectStoreAddress( - name=ray.utils.decode(manager[b"store_socket_name"]), - manager_name=ray.utils.decode( - manager[b"manager_socket_name"]), - manager_port=port)) - scheduler_names = [ - ray.utils.decode(scheduler[b"local_scheduler_socket_name"]) - for scheduler in local_schedulers - ] - client_info = { - "node_ip_address": node_ip_address, - "redis_address": redis_address, - "object_store_addresses": object_store_addresses, - "local_scheduler_socket_names": scheduler_names, - # Web UI should be running. - "webui_url": _webui_url_helper(redis_client) - } - return client_info - - # Handle the raylet case. - else: - # In the raylet code path, all client data is stored in a zset at the - # key for the nil client. - client_key = b"CLIENT" + NIL_CLIENT_ID - clients = redis_client.zrange(client_key, 0, -1) - raylets = [] - for client_message in clients: - client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - client_message, 0) - client_node_ip_address = ray.utils.decode( - client.NodeManagerAddress()) - if (client_node_ip_address == node_ip_address or - (client_node_ip_address == "127.0.0.1" - and redis_ip_address == ray.services.get_node_ip_address())): - raylets.append(client) - # Make sure that at least one raylet has started locally. - # This handles a race condition where Redis has started but - # the raylet has not connected. - if len(raylets) == 0: - raise Exception( - "Redis has started but no raylets have registered yet.") - object_store_addresses = [ - services.ObjectStoreAddress( - name=ray.utils.decode(raylet.ObjectStoreSocketName()), - manager_name=None, - manager_port=None) for raylet in raylets - ] - raylet_socket_names = [ - ray.utils.decode(raylet.RayletSocketName()) for raylet in raylets - ] - return { - "node_ip_address": node_ip_address, - "redis_address": redis_address, - "object_store_addresses": object_store_addresses, - "raylet_socket_names": raylet_socket_names, - # Web UI should be running. - "webui_url": _webui_url_helper(redis_client) - } + # In the raylet code path, all client data is stored in a zset at the + # key for the nil client. 
+ client_key = b"CLIENT" + NIL_CLIENT_ID + clients = redis_client.zrange(client_key, 0, -1) + raylets = [] + for client_message in clients: + client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData( + client_message, 0) + client_node_ip_address = ray.utils.decode(client.NodeManagerAddress()) + if (client_node_ip_address == node_ip_address or + (client_node_ip_address == "127.0.0.1" + and redis_ip_address == ray.services.get_node_ip_address())): + raylets.append(client) + # Make sure that at least one raylet has started locally. + # This handles a race condition where Redis has started but + # the raylet has not connected. + if len(raylets) == 0: + raise Exception( + "Redis has started but no raylets have registered yet.") + object_store_addresses = [ + ray.utils.decode(raylet.ObjectStoreSocketName()) for raylet in raylets + ] + raylet_socket_names = [ + ray.utils.decode(raylet.RayletSocketName()) for raylet in raylets + ] + return { + "node_ip_address": node_ip_address, + "redis_address": redis_address, + "object_store_addresses": object_store_addresses, + "raylet_socket_names": raylet_socket_names, + # Web UI should be running. + "webui_url": _webui_url_helper(redis_client) + } def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5, - use_raylet=True, redis_password=None): counter = 0 while True: try: return get_address_info_from_redis_helper( - redis_address, - node_ip_address, - use_raylet=use_raylet, - redis_password=redis_password) + redis_address, node_ip_address, redis_password=redis_password) except Exception: if counter == num_retries: raise @@ -1414,7 +1279,6 @@ def _init(address_info=None, plasma_directory=None, huge_pages=False, include_webui=True, - use_raylet=None, plasma_store_socket_name=None, raylet_socket_name=None, temp_dir=None): @@ -1474,7 +1338,6 @@ def _init(address_info=None, Store with hugetlbfs support. Requires plasma_directory. include_webui: Boolean flag indicating whether to start the web UI, which is a Jupyter notebook. - use_raylet: True if the new raylet code path should be used. plasma_store_socket_name (str): If provided, it will specify the socket name used by the plasma store. raylet_socket_name (str): If provided, it will specify the socket path @@ -1497,16 +1360,6 @@ def _init(address_info=None, else: driver_mode = SCRIPT_MODE - if use_raylet is None: - if os.environ.get("RAY_USE_XRAY") == "0": - # This environment variable is used in our testing setup. - logger.info("Detected environment variable 'RAY_USE_XRAY' with " - "value {}. This turns OFF xray.".format( - os.environ.get("RAY_USE_XRAY"))) - use_raylet = False - else: - use_raylet = True - # Get addresses of existing services. if address_info is None: address_info = {} @@ -1561,7 +1414,6 @@ def _init(address_info=None, plasma_directory=plasma_directory, huge_pages=huge_pages, include_webui=include_webui, - use_raylet=use_raylet, plasma_store_socket_name=plasma_store_socket_name, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir) @@ -1610,10 +1462,7 @@ def _init(address_info=None, node_ip_address = services.get_node_ip_address(redis_address) # Get the address info of the processes to connect to from Redis. address_info = get_address_info_from_redis( - redis_address, - node_ip_address, - use_raylet=use_raylet, - redis_password=redis_password) + redis_address, node_ip_address, redis_password=redis_password) # Connect this driver to Redis, the object store, and the local scheduler. # Choose the first object store and local scheduler if there are multiple. 
@@ -1625,18 +1474,11 @@ def _init(address_info=None,
     driver_address_info = {
         "node_ip_address": node_ip_address,
         "redis_address": address_info["redis_address"],
-        "store_socket_name": (
-            address_info["object_store_addresses"][0].name),
+        "store_socket_name": address_info["object_store_addresses"][0],
         "webui_url": address_info["webui_url"]
     }
-    if not use_raylet:
-        driver_address_info["manager_socket_name"] = (
-            address_info["object_store_addresses"][0].manager_name)
-        driver_address_info["local_scheduler_socket_name"] = (
-            address_info["local_scheduler_socket_names"][0])
-    else:
-        driver_address_info["raylet_socket_name"] = (
-            address_info["raylet_socket_names"][0])
+    driver_address_info["raylet_socket_name"] = (
+        address_info["raylet_socket_names"][0])
 
     # We only pass `temp_dir` to a worker (WORKER_MODE).
     # It can't be a worker here.
@@ -1645,7 +1487,6 @@ def _init(address_info=None,
         object_id_seed=object_id_seed,
         mode=driver_mode,
         worker=global_worker,
-        use_raylet=use_raylet,
         redis_password=redis_password)
     return address_info
 
@@ -1669,7 +1510,6 @@ def init(redis_address=None,
          plasma_directory=None,
          huge_pages=False,
          include_webui=True,
-         use_raylet=None,
          configure_logging=True,
          logging_level=logging.INFO,
          logging_format=ray_constants.LOGGER_FORMAT,
@@ -1736,7 +1576,6 @@ def init(redis_address=None,
             Store with hugetlbfs support. Requires plasma_directory.
         include_webui: Boolean flag indicating whether to start the web UI,
             which is a Jupyter notebook.
-        use_raylet: True if the new raylet code path should be used.
         configure_logging: True if logging should be configured here.
             Otherwise, users may want to configure it on their own.
         logging_level: Logging level; defaults to logging.INFO.
@@ -1767,22 +1606,6 @@ def init(redis_address=None,
     else:
         raise Exception("Perhaps you called ray.init twice by accident?")
 
-    if use_raylet is None:
-        if os.environ.get("RAY_USE_XRAY") == "0":
-            # This environment variable is used in our testing setup.
-            logger.info("Detected environment variable 'RAY_USE_XRAY' with "
-                        "value {}. This turns OFF xray.".format(
-                            os.environ.get("RAY_USE_XRAY")))
-            use_raylet = False
-        else:
-            use_raylet = True
-
-    if not use_raylet and redis_password is not None:
-        raise Exception("Setting the 'redis_password' argument is not "
-                        "supported in legacy Ray. To run Ray with "
-                        "password-protected Redis ports, set "
-                        "'use_raylet=True'.")
-
     # Convert hostnames to numerical IP address.
     if node_ip_address is not None:
         node_ip_address = services.address_to_ip(node_ip_address)
@@ -1809,7 +1632,6 @@ def init(redis_address=None,
         huge_pages=huge_pages,
         include_webui=include_webui,
         object_store_memory=object_store_memory,
-        use_raylet=use_raylet,
         plasma_store_socket_name=plasma_store_socket_name,
         raylet_socket_name=raylet_socket_name,
         temp_dir=temp_dir)
@@ -1887,9 +1709,6 @@ def print_error_messages_raylet(worker):
     This runs in a separate thread on the driver and prints error messages in
     the background.
     """
-    if not worker.use_raylet:
-        raise Exception("This function is specific to the raylet code path.")
-
     worker.error_message_pubsub_client = worker.redis_client.pubsub(
         ignore_subscribe_messages=True)
     # Exports that are published after the call to
@@ -2004,7 +1823,6 @@ def connect(info,
             object_id_seed=None,
             mode=WORKER_MODE,
             worker=global_worker,
-            use_raylet=True,
             redis_password=None):
     """Connect this worker to the local scheduler, to Plasma, and to Redis.
 
@@ -2015,7 +1833,6 @@ def connect(info,
             deterministic.
         mode: The mode of the worker.
One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. - use_raylet: True if the new raylet code path should be used. redis_password (str): Prevents external clients without the password from connecting to Redis if provided. """ @@ -2038,7 +1855,6 @@ def connect(info, worker.actor_id = NIL_ACTOR_ID worker.connected = True worker.set_mode(mode) - worker.use_raylet = use_raylet # If running Ray in LOCAL_MODE, there is no need to create call # create_worker or to start the worker service. @@ -2067,7 +1883,6 @@ def connect(info, traceback_str = traceback.format_exc() ray.utils.push_error_to_driver_through_redis( worker.redis_client, - worker.use_raylet, ray_constants.VERSION_MISMATCH_PUSH_ERROR, traceback_str, driver_id=None) @@ -2108,7 +1923,6 @@ def connect(info, "driver_id": worker.worker_id, "start_time": time.time(), "plasma_store_socket": info["store_socket_name"], - "plasma_manager_socket": info.get("manager_socket_name"), "local_scheduler_socket": info.get("local_scheduler_socket_name"), "raylet_socket": info.get("raylet_socket_name") } @@ -2123,7 +1937,6 @@ def connect(info, worker_dict = { "node_ip_address": worker.node_ip_address, "plasma_store_socket": info["store_socket_name"], - "plasma_manager_socket": info["manager_socket_name"], "local_scheduler_socket": info["local_scheduler_socket_name"] } if redirect_worker_output: @@ -2135,18 +1948,10 @@ def connect(info, raise Exception("This code should be unreachable.") # Create an object store client. - if not worker.use_raylet: - worker.plasma_client = thread_safe_client( - plasma.connect(info["store_socket_name"], - info["manager_socket_name"], 64)) - else: - worker.plasma_client = thread_safe_client( - plasma.connect(info["store_socket_name"], "", 64)) + worker.plasma_client = thread_safe_client( + plasma.connect(info["store_socket_name"], "", 64)) - if not worker.use_raylet: - local_scheduler_socket = info["local_scheduler_socket_name"] - else: - local_scheduler_socket = info["raylet_socket_name"] + local_scheduler_socket = info["raylet_socket_name"] # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -2177,28 +1982,22 @@ def connect(info, # rerun the driver. nil_actor_counter = 0 - driver_task = ray.local_scheduler.Task( - worker.task_driver_id, ray.ObjectID(NIL_FUNCTION_ID), [], 0, - worker.current_task_id, worker.task_index, - ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), - ray.ObjectID(NIL_ACTOR_ID), ray.ObjectID(NIL_ACTOR_ID), - nil_actor_counter, False, [], {"CPU": 0}, {}, worker.use_raylet) + driver_task = ray.raylet.Task(worker.task_driver_id, + ray.ObjectID(NIL_FUNCTION_ID), [], 0, + worker.current_task_id, + worker.task_index, + ray.ObjectID(NIL_ACTOR_ID), + ray.ObjectID(NIL_ACTOR_ID), + ray.ObjectID(NIL_ACTOR_ID), + ray.ObjectID(NIL_ACTOR_ID), + nil_actor_counter, [], {"CPU": 0}, {}) # Add the driver task to the task table. 
- if not worker.use_raylet: - global_state._execute_command( - driver_task.task_id(), "RAY.TASK_TABLE_ADD", - driver_task.task_id().id(), TASK_STATUS_RUNNING, - NIL_LOCAL_SCHEDULER_ID, - driver_task.execution_dependencies_string(), 0, - ray.local_scheduler.task_to_string(driver_task)) - else: - global_state._execute_command( - driver_task.task_id(), "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - ray.gcs_utils.TablePubsub.RAYLET_TASK, - driver_task.task_id().id(), - driver_task._serialized_raylet_task()) + global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.RAYLET_TASK, + ray.gcs_utils.TablePubsub.RAYLET_TASK, + driver_task.task_id().id(), + driver_task._serialized_raylet_task()) # Set the driver's current task ID to the task ID assigned to the # driver task. @@ -2207,9 +2006,9 @@ def connect(info, # A non-driver worker begins without an assigned task. worker.current_task_id = ray.ObjectID(NIL_ID) - worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient( + worker.local_scheduler_client = ray.raylet.LocalSchedulerClient( local_scheduler_socket, worker.worker_id, is_worker, - worker.current_task_id, worker.use_raylet) + worker.current_task_id) # Start the import thread import_thread.ImportThread(worker, mode).start() @@ -2221,16 +2020,10 @@ def connect(info, # temporarily using this implementation which constantly queries the # scheduler for new error messages. if mode == SCRIPT_MODE: - if not worker.use_raylet: - t = threading.Thread( - target=print_error_messages, - name="ray_print_error_messages", - args=(worker, )) - else: - t = threading.Thread( - target=print_error_messages_raylet, - name="ray_print_error_messages", - args=(worker, )) + t = threading.Thread( + target=print_error_messages_raylet, + name="ray_print_error_messages", + args=(worker, )) # Making the thread a daemon causes it to exit when the main thread # exits. t.daemon = True @@ -2238,7 +2031,7 @@ def connect(info, # If we are using the raylet code path and we are not in local mode, start # a background thread to periodically flush profiling data to the GCS. - if mode != LOCAL_MODE and worker.use_raylet: + if mode != LOCAL_MODE: worker.profiler.start_flush_thread() if mode == SCRIPT_MODE: @@ -2395,6 +2188,9 @@ def register_custom_serializer(cls, # worker and not across workers. class_id = random_string() + # Make sure class_id is a string. + class_id = ray.utils.binary_to_hex(class_id) + if driver_id is None: driver_id_bytes = worker.task_driver_id.id() else: @@ -2481,7 +2277,7 @@ def put(value, worker=global_worker): # In LOCAL_MODE, ray.put is the identity operation. 
return value object_id = worker.local_scheduler_client.compute_put_id( - worker.current_task_id, worker.put_index, worker.use_raylet) + worker.current_task_id, worker.put_index) worker.put_object(object_id, value) worker.put_index += 1 return object_id @@ -2554,21 +2350,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): raise Exception("num_returns cannot be greater than the number " "of objects provided to ray.wait.") timeout = timeout if timeout is not None else 2**30 - if worker.use_raylet: - ready_ids, remaining_ids = worker.local_scheduler_client.wait( - object_ids, num_returns, timeout, False) - else: - object_id_strs = [ - plasma.ObjectID(object_id.id()) for object_id in object_ids - ] - ready_ids, remaining_ids = worker.plasma_client.wait( - object_id_strs, timeout, num_returns) - ready_ids = [ - ray.ObjectID(object_id.binary()) for object_id in ready_ids - ] - remaining_ids = [ - ray.ObjectID(object_id.binary()) for object_id in remaining_ids - ] + ready_ids, remaining_ids = worker.local_scheduler_client.wait( + object_ids, num_returns, timeout, False) return ready_ids, remaining_ids diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 7fe46218f653..4ec9e4d14e56 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -88,10 +88,7 @@ tempfile_services.set_temp_root(args.temp_dir) ray.worker.connect( - info, - mode=ray.WORKER_MODE, - use_raylet=(args.raylet_name is not None), - redis_password=args.redis_password) + info, mode=ray.WORKER_MODE, redis_password=args.redis_password) error_explanation = """ This error is unexpected and should not have happened. Somehow a worker diff --git a/python/setup.py b/python/setup.py index 29e296a13d90..1636ead058d6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -19,13 +19,10 @@ # NOTE: The lists below must be kept in sync with ray/CMakeLists.txt. 
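Looking back at the ``ray.wait`` simplification earlier in this patch: the raylet path returns the ready and not-yet-ready ID lists directly, with no plasma round trip. A hedged usage sketch (assumes a running ``ray.init()``; the remote function is illustrative, and the timeout is in milliseconds in this era of the API, consistent with the ``2**30`` default above):

.. code-block:: python

    import time

    import ray

    ray.init()

    @ray.remote
    def slow(i):
        time.sleep(i)
        return i

    ids = [slow.remote(i) for i in range(4)]
    # Block until two results are ready or 10 seconds elapse.
    ready_ids, remaining_ids = ray.wait(ids, num_returns=2, timeout=10000)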
ray_files = [ - "ray/core/src/common/thirdparty/redis/src/redis-server", - "ray/core/src/common/redis_module/libray_redis_module.so", + "ray/core/src/ray/thirdparty/redis/src/redis-server", + "ray/core/src/ray/gcs/redis_module/libray_redis_module.so", "ray/core/src/plasma/plasma_store_server", - "ray/core/src/plasma/plasma_manager", - "ray/core/src/local_scheduler/local_scheduler", - "ray/core/src/local_scheduler/liblocal_scheduler_library_python.so", - "ray/core/src/global_scheduler/global_scheduler", + "ray/core/src/ray/raylet/liblocal_scheduler_library_python.so", "ray/core/src/ray/raylet/raylet_monitor", "ray/core/src/ray/raylet/raylet", "ray/WebUI.ipynb" ] diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt deleted file mode 100644 index b024b4a0419f..000000000000 --- a/src/common/CMakeLists.txt +++ /dev/null @@ -1,131 +0,0 @@ -cmake_minimum_required(VERSION 3.4) - -project(common) - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - include_directories("${CMAKE_CURRENT_LIST_DIR}/lib/python") -endif () - -add_subdirectory(redis_module) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -g") - -include_directories(thirdparty/ae) - -# Compile flatbuffers - -set(COMMON_FBS_SRC "${CMAKE_CURRENT_LIST_DIR}/format/common.fbs") -set(OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/format/) - -set(COMMON_FBS_OUTPUT_FILES - "${OUTPUT_DIR}/common_generated.h") - -add_custom_target(gen_common_fbs DEPENDS ${COMMON_FBS_OUTPUT_FILES}) - -add_custom_command( - OUTPUT ${COMMON_FBS_OUTPUT_FILES} - # The --gen-object-api flag generates a C++ class MessageT for each - # flatbuffers message Message, which can be used to store deserialized - # messages in data structures. This is currently used for ObjectInfo for - # example. - COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} ${COMMON_FBS_SRC} --gen-object-api --scoped-enums - DEPENDS ${FBS_DEPENDS} - COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}" - VERBATIM) - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - add_custom_target(gen_common_python_fbs DEPENDS ${COMMON_FBS_OUTPUT_FILES}) - - # Generate Python bindings for the flatbuffers objects. - set(PYTHON_OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/../../python/ray/core/generated/) - add_custom_command( - TARGET gen_common_python_fbs - COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} ${COMMON_FBS_SRC} - DEPENDS ${FBS_DEPENDS} - COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}" - VERBATIM) - - # Encode the fact that the ray redis module requires the autogenerated - # flatbuffer files to compile. - add_dependencies(ray_redis_module gen_common_python_fbs) - - add_dependencies(gen_common_python_fbs flatbuffers_ep) -endif() - -if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") - add_custom_target(gen_common_java_fbs DEPENDS ${COMMON_FBS_OUTPUT_FILES}) - - # Generate Java bindings for the flatbuffers objects. - set(JAVA_OUTPUT_DIR ${CMAKE_BINARY_DIR}/generated/java) - add_custom_command( - TARGET gen_common_java_fbs - COMMAND ${FLATBUFFERS_COMPILER} -j -o ${JAVA_OUTPUT_DIR} ${COMMON_FBS_SRC} - DEPENDS ${FBS_DEPENDS} - COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}" - VERBATIM) - - # Encode the fact that the ray redis module requires the autogenerated - # flatbuffer files to compile. 
- add_dependencies(ray_redis_module gen_common_java_fbs) - - add_dependencies(gen_common_java_fbs flatbuffers_ep) -endif() - -add_custom_target( - hiredis - COMMAND make - WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/thirdparty/hiredis) - -add_library(common STATIC - event_loop.cc - common.cc - common_protocol.cc - task.cc - io.cc - net.cc - logging.cc - state/redis.cc - state/table.cc - state/object_table.cc - state/task_table.cc - state/db_client_table.cc - state/driver_table.cc - state/actor_notification_table.cc - state/local_scheduler_table.cc - state/error_table.cc - thirdparty/ae/ae.c - thirdparty/sha256.c) - -add_dependencies(common arrow) - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - add_dependencies(common gen_common_python_fbs) -endif() - -if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") - add_dependencies(common gen_common_java_fbs) -endif() - -target_link_libraries(common "${CMAKE_CURRENT_LIST_DIR}/thirdparty/hiredis/libhiredis.a") - -function(define_test test_name library) - add_executable(${test_name} test/${test_name}.cc ${ARGN}) - add_dependencies(${test_name} hiredis flatbuffers_ep) - target_link_libraries(${test_name} common ${FLATBUFFERS_STATIC_LIB} ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${library} -lpthread) - target_compile_options(${test_name} PUBLIC "-DPLASMA_TEST -DLOCAL_SCHEDULER_TEST -DCOMMON_TEST -DRAY_COMMON_LOG_LEVEL=4") -endfunction() - -define_test(db_tests "") -define_test(io_tests "") -define_test(task_tests "") -define_test(redis_tests "") -define_test(task_table_tests "") -define_test(object_table_tests "") - -add_custom_target(copy_redis ALL) -foreach(file "redis-cli" "redis-server") -add_custom_command(TARGET copy_redis POST_BUILD - COMMAND ${CMAKE_COMMAND} -E - copy ${CMAKE_CURRENT_LIST_DIR}/../../thirdparty/pkg/redis/src/${file} - ${CMAKE_BINARY_DIR}/src/common/thirdparty/redis/src/${file}) -endforeach() diff --git a/src/common/common.cc b/src/common/common.cc deleted file mode 100644 index 0a6da6a2936e..000000000000 --- a/src/common/common.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "common.h" - -#include -#include -#include -#include -#include -#include - -#include "io.h" -#include - -const unsigned char NIL_DIGEST[DIGEST_SIZE] = {0}; - -int64_t current_time_ms() { - std::chrono::milliseconds ms_since_epoch = - std::chrono::duration_cast( - std::chrono::steady_clock::now().time_since_epoch()); - return ms_since_epoch.count(); -} diff --git a/src/common/common.h b/src/common/common.h deleted file mode 100644 index f95bfcca5d26..000000000000 --- a/src/common/common.h +++ /dev/null @@ -1,75 +0,0 @@ -#ifndef COMMON_H -#define COMMON_H - -#include -#include -#include -#ifndef __STDC_FORMAT_MACROS -#define __STDC_FORMAT_MACROS -#endif -#include -#include -#ifndef _WIN32 -#include -#endif - -#ifdef __cplusplus -#include -extern "C" { -#endif -#include "sha256.h" -#ifdef __cplusplus -} -#endif - -#include "arrow/util/macros.h" -#include "plasma/common.h" -#include "ray/id.h" -#include "ray/util/logging.h" - -#include "state/ray_config.h" - -/** Definitions for Ray logging levels. */ -#define RAY_COMMON_DEBUG 0 -#define RAY_COMMON_INFO 1 -#define RAY_COMMON_WARNING 2 -#define RAY_COMMON_ERROR 3 -#define RAY_COMMON_FATAL 4 - -/** - * RAY_COMMON_LOG_LEVEL should be defined to one of the above logging level - * integer values. Any logging statement in the code with a logging level - * greater than or equal to RAY_COMMON_LOG_LEVEL will be outputted to stderr. - * The default logging level is INFO. 
*/ -#ifndef RAY_COMMON_LOG_LEVEL -#define RAY_COMMON_LOG_LEVEL RAY_COMMON_INFO -#endif - -/* These are exit codes for common errors that can occur in Ray components. */ -#define EXIT_COULD_NOT_BIND_PORT -2 - -/** This macro indicates that this pointer owns the data it is pointing to - * and is responsible for freeing it. */ -#define OWNER - -/** The worker ID is the ID of a worker or driver. */ -typedef ray::UniqueID WorkerID; - -typedef ray::UniqueID DBClientID; - -#define MAX(x, y) ((x) >= (y) ? (x) : (y)) -#define MIN(x, y) ((x) <= (y) ? (x) : (y)) - -/** Definitions for computing hash digests. */ -#define DIGEST_SIZE SHA256_BLOCK_SIZE - -extern const unsigned char NIL_DIGEST[DIGEST_SIZE]; - -/** - * Return the current time in milliseconds since the Unix epoch. - * - * @return The number of milliseconds since the Unix epoch. - */ -int64_t current_time_ms(); - -#endif diff --git a/src/common/doc/tasks.md b/src/common/doc/tasks.md deleted file mode 100644 index 4431afae2ee9..000000000000 --- a/src/common/doc/tasks.md +++ /dev/null @@ -1,32 +0,0 @@ -# Task specifications, task instances and task logs - -A *task specification* contains all information that is needed for computing -the results of a task: - -- The ID of the task -- The function ID of the function that executes the task -- The arguments (either object IDs for pass by reference -or values for pass by value) -- The IDs of the result objects - -From these, a task ID can be computed which is also stored in the task -specification. - -A *task* represents the execution of a task specification. -It consists of: - -- A scheduling state (WAITING, SCHEDULED, RUNNING, DONE) -- The target node where the task is scheduled or executed -- The task specification - -The task data structures are defined in `common/task.h`. - -The *task table* is a mapping from the task ID to the *task* information. It is -updated by various parts of the system: - -1. The local scheduler writes it with status WAITING when submits a task to the global scheduler -2. The global scheduler appends an update WAITING -> SCHEDULED together with the node ID when assigning the task to a local scheduler -3. The local scheduler appends an update SCHEDULED -> RUNNING when it assigns a task to a worker -4. The local scheduler appends an update RUNNING -> DONE when the task finishes execution - -The task table is defined in `common/state/task_table.h`. diff --git a/src/common/event_loop.cc b/src/common/event_loop.cc deleted file mode 100644 index e3d9cc4a2dc6..000000000000 --- a/src/common/event_loop.cc +++ /dev/null @@ -1,63 +0,0 @@ -#include "event_loop.h" - -#include "common.h" -#include - -#define INITIAL_EVENT_LOOP_SIZE 1024 - -event_loop *event_loop_create(void) { - return aeCreateEventLoop(INITIAL_EVENT_LOOP_SIZE); -} - -void event_loop_destroy(event_loop *loop) { - /* Clean up timer events. This is to make valgrind happy. */ - aeTimeEvent *te = loop->timeEventHead; - while (te) { - aeTimeEvent *next = te->next; - free(te); - te = next; - } - aeDeleteEventLoop(loop); -} - -bool event_loop_add_file(event_loop *loop, - int fd, - int events, - event_loop_file_handler handler, - void *context) { - /* Try to add the file descriptor. */ - int err = aeCreateFileEvent(loop, fd, events, handler, context); - /* If it cannot be added, increase the size of the event loop. 
*/ - if (err == AE_ERR && errno == ERANGE) { - err = aeResizeSetSize(loop, 3 * aeGetSetSize(loop) / 2); - if (err != AE_OK) { - return false; - } - err = aeCreateFileEvent(loop, fd, events, handler, context); - } - /* In any case, test if there were errors. */ - return (err == AE_OK); -} - -void event_loop_remove_file(event_loop *loop, int fd) { - aeDeleteFileEvent(loop, fd, EVENT_LOOP_READ | EVENT_LOOP_WRITE); -} - -int64_t event_loop_add_timer(event_loop *loop, - int64_t timeout, - event_loop_timer_handler handler, - void *context) { - return aeCreateTimeEvent(loop, timeout, handler, context, NULL); -} - -int event_loop_remove_timer(event_loop *loop, int64_t id) { - return aeDeleteTimeEvent(loop, id); -} - -void event_loop_run(event_loop *loop) { - aeMain(loop); -} - -void event_loop_stop(event_loop *loop) { - aeStop(loop); -} diff --git a/src/common/event_loop.h b/src/common/event_loop.h deleted file mode 100644 index e489ab4fb672..000000000000 --- a/src/common/event_loop.h +++ /dev/null @@ -1,103 +0,0 @@ -#ifndef EVENT_LOOP_H -#define EVENT_LOOP_H - -#include - -extern "C" { -#ifdef _WIN32 -/* Quirks mean that Windows version needs to be included differently */ -#include -#include -#else -#include "ae/ae.h" -#endif -} - -/* Unique timer ID that will be generated when the timer is added to the - * event loop. Will not be reused later on in another call - * to event_loop_add_timer. */ -typedef long long timer_id; - -typedef aeEventLoop event_loop; - -/* File descriptor is readable. */ -#define EVENT_LOOP_READ AE_READABLE - -/* File descriptor is writable. */ -#define EVENT_LOOP_WRITE AE_WRITABLE - -/* Constant specifying that the timer is done and it will be removed. */ -#define EVENT_LOOP_TIMER_DONE AE_NOMORE - -/* Signature of the handler that will be called when there is a new event - * on the file descriptor that this handler has been registered for. The - * context is the one that was passed into add_file by the user. The - * events parameter indicates which event is available on the file, - * it can be EVENT_LOOP_READ or EVENT_LOOP_WRITE. */ -typedef void (*event_loop_file_handler)(event_loop *loop, - int fd, - void *context, - int events); - -/* This handler will be called when a timer times out. The id of the timer - * as well as the context that was specified when registering this handler - * are passed as arguments. The return is the number of milliseconds the - * timer shall be reset to or EVENT_LOOP_TIMER_DONE if the timer shall - * not be triggered again. */ -typedef int (*event_loop_timer_handler)(event_loop *loop, - timer_id timer_id, - void *context); - -/* Create and return a new event loop. */ -event_loop *event_loop_create(void); - -/* Deallocate space associated with the event loop that was created - * with the "create" function. */ -void event_loop_destroy(event_loop *loop); - -/* Register a handler that will be called any time a new event happens on - * a file descriptor. Can specify a context that will be passed as an - * argument to the handler. Currently there can only be one handler per file. - * The events parameter specifies which events we listen to: EVENT_LOOP_READ - * or EVENT_LOOP_WRITE. */ -bool event_loop_add_file(event_loop *loop, - int fd, - int events, - event_loop_file_handler handler, - void *context); - -/* Remove a registered file event handler from the event loop. */ -void event_loop_remove_file(event_loop *loop, int fd); - -/** Register a handler that will be called after a time slice of - * "timeout" milliseconds. 
- * - * @param loop The event loop. - * @param timeout The timeout in milliseconds. - * @param handler The handler for the timeout. - * @param context User context that can be passed in and will be passed in - * as an argument for the timer handler. - * @return The ID of the timer. - */ -int64_t event_loop_add_timer(event_loop *loop, - int64_t timeout, - event_loop_timer_handler handler, - void *context); - -/** - * Remove a registered time event handler from the event loop. Can be called - * multiple times on the same timer. - * - * @param loop The event loop. - * @param timer_id The ID of the timer to be removed. - * @return Returns 0 if the removal was successful. - */ -int event_loop_remove_timer(event_loop *loop, int64_t timer_id); - -/* Run the event loop. */ -void event_loop_run(event_loop *loop); - -/* Stop the event loop. */ -void event_loop_stop(event_loop *loop); - -#endif diff --git a/src/common/format/common.fbs b/src/common/format/common.fbs deleted file mode 100644 index a5b2177f1c30..000000000000 --- a/src/common/format/common.fbs +++ /dev/null @@ -1,203 +0,0 @@ - -// Indices into resource vectors. -// A resource vector maps a resource index to the number -// of units of that resource required. - -table Arg { - // Object ID for pass-by-reference arguments. Normally there is only one - // object ID in this list which represents the object that is being passed. - // However to support reducers in a MapReduce workload, we also support - // passing multiple object IDs for each argument. - object_ids: [string]; - // Data for pass-by-value arguments. - data: string; -} - -table ResourcePair { - // The name of the resource. - key: string; - // The quantity of the resource. - value: double; -} - -// NOTE: This enum is duplicate with the `Language` enum in `gcs.fbs`, -// because we cannot include this file in `gcs.fbs` due to cyclic dependency. -// TODO(raulchen): remove it once we get rid of legacy ray. -enum TaskLanguage:int { - PYTHON = 0, - JAVA = 1 -} - -table TaskInfo { - // ID of the driver that created this task. - driver_id: string; - // Task ID of the task. - task_id: string; - // Task ID of the parent task. - parent_task_id: string; - // A count of the number of tasks submitted by the parent task before this one. - parent_counter: int; - // The ID of the actor to create if this is an actor creation task. - actor_creation_id: string; - // The dummy object ID of the actor creation task if this is an actor method. - actor_creation_dummy_object_id: string; - // Actor ID of the task. This is the actor that this task is executed on - // or NIL_ACTOR_ID if the task is just a normal task. - actor_id: string; - // The ID of the handle that was used to submit the task. This should be - // unique across handles with the same actor_id. - actor_handle_id: string; - // Number of tasks that have been submitted to this actor so far. - actor_counter: int; - // True if this task is an actor checkpoint task and false otherwise. - is_actor_checkpoint_method: bool; - // Function ID of the task. - function_id: string; - // Task arguments. - args: [Arg]; - // Object IDs of return values. - returns: [string]; - // The required_resources vector indicates the quantities of the different - // resources required by this task. - required_resources: [ResourcePair]; - // The resources required for placing this task on a node. If this is empty, - // then the placement resources are equal to the required_resources. 
- required_placement_resources: [ResourcePair]; - // The language that this task belongs to - language: TaskLanguage; - // Function descriptor, which is a list of strings that can - // uniquely describe a function. - // For a Python function, it should be: [module_name, class_name, function_name] - // For a Java function, it should be: [class_name, method_name, type_descriptor] - // TODO(hchen): after changing Python worker to use function_descriptor, - // function_id can be removed. - function_descriptor: [string]; -} - -// Object information data structure. -// NOTE(pcm): This structure is replicated in -// https://github.com/apache/arrow/blob/master/cpp/src/plasma/format/common.fbs, -// so if you modify it, you should also modify that one. -table ObjectInfo { - // Object ID of this object. - object_id: string; - // Number of bytes the content of this object occupies in memory. - data_size: long; - // Number of bytes the metadata of this object occupies in memory. - metadata_size: long; - // Number of clients using the objects. - ref_count: int; - // Unix epoch of when this object was created. - create_time: long; - // How long creation of this object took. - construct_duration: long; - // Hash of the object content. If the object is not sealed yet this is - // an empty string. - digest: string; - // Specifies if this object was deleted or added. - is_deletion: bool; -} - -root_type TaskInfo; - -table TaskExecutionDependencies { - // A list of object IDs representing this task's dependencies at execution - // time. - execution_dependencies: [string]; -} - -root_type TaskExecutionDependencies; - -table SubscribeToNotificationsReply { - // The object ID of the object that the notification is about. - object_id: string; - // The size of the object. - object_size: long; - // The IDs of the managers that contain this object. - manager_ids: [string]; -} - -root_type SubscribeToNotificationsReply; - -table TaskReply { - // The task ID of the task that the message is about. - task_id: string; - // The state of the task. This is encoded as a bit mask of scheduling_state - // enum values in task.h. - state: long; - // A local scheduler ID. - local_scheduler_id: string; - // A string of bytes representing the task's TaskExecutionDependencies. - execution_dependencies: string; - // A string of bytes representing the task specification. - task_spec: string; - // The number of times the task was spilled back by local schedulers. - spillback_count: long; - // A boolean representing whether the update was successful. This field - // should only be used for test-and-set operations. - updated: bool; -} - -root_type TaskReply; - -table SubscribeToDBClientTableReply { - // The db client ID of the client that the message is about. - db_client_id: string; - // The type of the client. - client_type: string; - // If the client is a local scheduler, this is the address of the plasma - // manager that the local scheduler is connected to. Otherwise, it is empty. - manager_address: string; - // True if the message is about the addition of a client and false if it is - // about the deletion of a client. - is_insertion: bool; -} - -root_type SubscribeToDBClientTableReply; - -table LocalSchedulerInfoMessage { - // The db client ID of the client that the message is about. - db_client_id: string; - // The total number of workers that are connected to this local scheduler. - total_num_workers: long; - // The number of tasks queued in this local scheduler. 
- task_queue_length: long; - // The number of workers that are available and waiting for tasks. - available_workers: long; - // The resources generally available to this local scheduler. - static_resources: [ResourcePair]; - // The resources currently available to this local scheduler. - dynamic_resources: [ResourcePair]; - // Whether the local scheduler is dead. If true, then all other fields - // besides `db_client_id` will not be set. - is_dead: bool; -} - -root_type LocalSchedulerInfoMessage; - -table ResultTableReply { - // The task ID of the task that created the object. - task_id: string; - // Whether the task created the object through a ray.put. - is_put: bool; - // The size of the object created. - data_size: long; - // The hash of the object created. - hash: string; -} - -root_type ResultTableReply; - -table DriverTableMessage { - // The driver ID of the driver that died. - driver_id: string; -} - -table ActorCreationNotification { - // The ID of the actor that was created. - actor_id: string; - // The ID of the driver that created the actor. - driver_id: string; - // The ID of the local scheduler that created the actor. - local_scheduler_id: string; -} diff --git a/src/common/io.cc b/src/common/io.cc deleted file mode 100644 index 1999b7054669..000000000000 --- a/src/common/io.cc +++ /dev/null @@ -1,416 +0,0 @@ -#include "io.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common.h" -#include "event_loop.h" - -#ifndef _WIN32 -/* This function is actually not declared in standard POSIX, so declare it. */ -extern int usleep(useconds_t usec); -#endif - -int bind_inet_sock(const int port, bool shall_listen) { - struct sockaddr_in name; - int socket_fd = socket(PF_INET, SOCK_STREAM, 0); - if (socket_fd < 0) { - RAY_LOG(ERROR) << "socket() failed for port " << port; - return -1; - } - name.sin_family = AF_INET; - name.sin_port = htons(port); - name.sin_addr.s_addr = htonl(INADDR_ANY); - int on = 1; - /* TODO(pcm): http://stackoverflow.com/q/1150635 */ - if (ioctl(socket_fd, FIONBIO, (char *) &on) < 0) { - RAY_LOG(ERROR) << "ioctl failed"; - close(socket_fd); - return -1; - } - int *const pon = (int *const) & on; - if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, pon, sizeof(on)) < 0) { - RAY_LOG(ERROR) << "setsockopt failed for port " << port; - close(socket_fd); - return -1; - } - if (bind(socket_fd, (struct sockaddr *) &name, sizeof(name)) < 0) { - RAY_LOG(ERROR) << "Bind failed for port " << port; - close(socket_fd); - return -1; - } - if (shall_listen && listen(socket_fd, 128) == -1) { - RAY_LOG(ERROR) << "Could not listen to socket " << port; - close(socket_fd); - return -1; - } - return socket_fd; -} - -int bind_ipc_sock(const char *socket_pathname, bool shall_listen) { - struct sockaddr_un socket_address; - int socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (socket_fd < 0) { - RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname; - return -1; - } - /* Tell the system to allow the port to be reused. 
*/ - int on = 1; - if (setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, (char *) &on, - sizeof(on)) < 0) { - RAY_LOG(ERROR) << "setsockopt failed for pathname " << socket_pathname; - close(socket_fd); - return -1; - } - - unlink(socket_pathname); - memset(&socket_address, 0, sizeof(socket_address)); - socket_address.sun_family = AF_UNIX; - if (strlen(socket_pathname) + 1 > sizeof(socket_address.sun_path)) { - RAY_LOG(ERROR) << "Socket pathname is too long."; - close(socket_fd); - return -1; - } - strncpy(socket_address.sun_path, socket_pathname, - strlen(socket_pathname) + 1); - - if (bind(socket_fd, (struct sockaddr *) &socket_address, - sizeof(socket_address)) != 0) { - RAY_LOG(ERROR) << "Bind failed for pathname " << socket_pathname; - close(socket_fd); - return -1; - } - if (shall_listen && listen(socket_fd, 128) == -1) { - RAY_LOG(ERROR) << "Could not listen to socket " << socket_pathname; - close(socket_fd); - return -1; - } - return socket_fd; -} - -int connect_ipc_sock_retry(const char *socket_pathname, - int num_retries, - int64_t timeout) { - /* Pick the default values if the user did not specify. */ - if (num_retries < 0) { - num_retries = RayConfig::instance().num_connect_attempts(); - } - if (timeout < 0) { - timeout = RayConfig::instance().connect_timeout_milliseconds(); - } - - RAY_CHECK(socket_pathname); - int fd = -1; - for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) { - fd = connect_ipc_sock(socket_pathname); - if (fd >= 0) { - break; - } - if (num_attempts == 0) { - RAY_LOG(ERROR) << "Connection to socket failed for pathname " - << socket_pathname; - } - /* Sleep for timeout milliseconds. */ - usleep(timeout * 1000); - } - /* If we could not connect to the socket, exit. */ - if (fd == -1) { - RAY_LOG(FATAL) << "Could not connect to socket " << socket_pathname; - } - return fd; -} - -int connect_ipc_sock(const char *socket_pathname) { - struct sockaddr_un socket_address; - int socket_fd; - - socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (socket_fd < 0) { - RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname; - return -1; - } - - memset(&socket_address, 0, sizeof(socket_address)); - socket_address.sun_family = AF_UNIX; - if (strlen(socket_pathname) + 1 > sizeof(socket_address.sun_path)) { - RAY_LOG(ERROR) << "Socket pathname is too long."; - return -1; - } - strncpy(socket_address.sun_path, socket_pathname, - strlen(socket_pathname) + 1); - - if (connect(socket_fd, (struct sockaddr *) &socket_address, - sizeof(socket_address)) != 0) { - close(socket_fd); - return -1; - } - - return socket_fd; -} - -int connect_inet_sock_retry(const char *ip_addr, - int port, - int num_retries, - int64_t timeout) { - /* Pick the default values if the user did not specify. */ - if (num_retries < 0) { - num_retries = RayConfig::instance().num_connect_attempts(); - } - if (timeout < 0) { - timeout = RayConfig::instance().connect_timeout_milliseconds(); - } - - RAY_CHECK(ip_addr); - int fd = -1; - for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) { - fd = connect_inet_sock(ip_addr, port); - if (fd >= 0) { - break; - } - if (num_attempts == 0) { - RAY_LOG(ERROR) << "Connection to socket failed for address " << ip_addr - << ":" << port; - } - /* Sleep for timeout milliseconds. */ - usleep(timeout * 1000); - } - /* If we could not connect to the socket, exit. 
*/
-  if (fd == -1) {
-    RAY_LOG(FATAL) << "Could not connect to address " << ip_addr << ":" << port;
-  }
-  return fd;
-}
-
-int connect_inet_sock(const char *ip_addr, int port) {
-  int fd = socket(PF_INET, SOCK_STREAM, 0);
-  if (fd < 0) {
-    RAY_LOG(ERROR) << "socket() failed for address " << ip_addr << ":" << port;
-    return -1;
-  }
-
-  struct hostent *manager = gethostbyname(ip_addr); /* TODO(pcm): cache this */
-  if (!manager) {
-    RAY_LOG(ERROR) << "Failed to get hostname from address " << ip_addr << ":"
-                   << port;
-    close(fd);
-    return -1;
-  }
-
-  struct sockaddr_in addr;
-  addr.sin_family = AF_INET;
-  memcpy(&addr.sin_addr.s_addr, manager->h_addr_list[0], manager->h_length);
-  addr.sin_port = htons(port);
-
-  if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) != 0) {
-    close(fd);
-    return -1;
-  }
-  return fd;
-}
-
-int accept_client(int socket_fd) {
-  int client_fd = accept(socket_fd, NULL, NULL);
-  if (client_fd < 0) {
-    RAY_LOG(ERROR) << "Error reading from socket.";
-    return -1;
-  }
-  return client_fd;
-}
-
-int write_bytes(int fd, uint8_t *cursor, size_t length) {
-  ssize_t nbytes = 0;
-  size_t bytesleft = length;
-  size_t offset = 0;
-  while (bytesleft > 0) {
-    /* While we haven't written the whole message, write to the file
-     * descriptor, advance the cursor, and decrease the amount left to write. */
-    nbytes = write(fd, cursor + offset, bytesleft);
-    if (nbytes < 0) {
-      if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
-        continue;
-      }
-      return -1; /* Errno will be set. */
-    } else if (0 == nbytes) {
-      /* Encountered early EOF. */
-      return -1;
-    }
-    RAY_CHECK(nbytes > 0);
-    bytesleft -= nbytes;
-    offset += nbytes;
-  }
-
-  return 0;
-}
-
-int do_write_message(int fd, int64_t type, int64_t length, uint8_t *bytes) {
-  int64_t version = RayConfig::instance().ray_protocol_version();
-  int closed;
-  closed = write_bytes(fd, (uint8_t *) &version, sizeof(version));
-  if (closed) {
-    return closed;
-  }
-  closed = write_bytes(fd, (uint8_t *) &type, sizeof(type));
-  if (closed) {
-    return closed;
-  }
-  closed = write_bytes(fd, (uint8_t *) &length, sizeof(length));
-  if (closed) {
-    return closed;
-  }
-  closed = write_bytes(fd, bytes, length * sizeof(char));
-  if (closed) {
-    return closed;
-  }
-  return 0;
-}
-
-int write_message(int fd,
-                  int64_t type,
-                  int64_t length,
-                  uint8_t *bytes,
-                  std::mutex *mutex) {
-  if (mutex != NULL) {
-    std::unique_lock<std::mutex> guard(*mutex);
-    return do_write_message(fd, type, length, bytes);
-  } else {
-    return do_write_message(fd, type, length, bytes);
-  }
-}
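The wire format implemented by do_write_message and read_message is simply four
fields in sequence: the int64 protocol version, the int64 message type, the
int64 payload length, and the payload bytes. A minimal round-trip sketch
(illustrative only, not code from this patch; it assumes io.h is included and
uses a made-up message type over a POSIX socketpair):

#include <cstdlib>
#include <sys/socket.h>

void framing_example() {
  int fds[2];
  if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) {
    return;
  }
  /* The message type 42 is hypothetical; real callers use values such as
   * CommonMessageType::SUBMIT_TASK. */
  char payload[] = "hello";
  write_message(fds[0], 42, sizeof(payload), (uint8_t *) payload);

  int64_t type;
  int64_t length;
  uint8_t *bytes;
  read_message(fds[1], &type, &length, &bytes);
  /* Here type == 42 and length == sizeof(payload). read_message allocated
   * the buffer, so the caller frees it. */
  free(bytes);
}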
-
-int read_bytes(int fd, uint8_t *cursor, size_t length) {
-  ssize_t nbytes = 0;
-  /* Termination condition: EOF or read 'length' bytes total. */
-  size_t bytesleft = length;
-  size_t offset = 0;
-  while (bytesleft > 0) {
-    nbytes = read(fd, cursor + offset, bytesleft);
-    if (nbytes < 0) {
-      if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
-        continue;
-      }
-      return -1; /* Errno will be set. */
-    } else if (0 == nbytes) {
-      /* Encountered early EOF. */
-      return -1;
-    }
-    RAY_CHECK(nbytes > 0);
-    bytesleft -= nbytes;
-    offset += nbytes;
-  }
-
-  return 0;
-}
-
-void read_message(int fd, int64_t *type, int64_t *length, uint8_t **bytes) {
-  int64_t version;
-  int closed = read_bytes(fd, (uint8_t *) &version, sizeof(version));
-  if (closed) {
-    goto disconnected;
-  }
-  RAY_CHECK(version == RayConfig::instance().ray_protocol_version());
-  closed = read_bytes(fd, (uint8_t *) type, sizeof(*type));
-  if (closed) {
-    goto disconnected;
-  }
-  closed = read_bytes(fd, (uint8_t *) length, sizeof(*length));
-  if (closed) {
-    goto disconnected;
-  }
-  *bytes = (uint8_t *) malloc(*length * sizeof(uint8_t));
-  closed = read_bytes(fd, *bytes, *length);
-  if (closed) {
-    free(*bytes);
-    goto disconnected;
-  }
-  return;
-
-disconnected:
-  /* Handle the case in which the socket is closed. */
-  *type = static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT);
-  *length = 0;
-  *bytes = NULL;
-  return;
-}
-
-uint8_t *read_message_async(event_loop *loop, int sock) {
-  int64_t size;
-  int error = read_bytes(sock, (uint8_t *) &size, sizeof(int64_t));
-  if (error < 0) {
-    /* The other side has closed the socket. */
-    RAY_LOG(DEBUG) << "Socket has been closed, or some other error has "
-                   << "occurred.";
-    if (loop != NULL) {
-      event_loop_remove_file(loop, sock);
-    }
-    close(sock);
-    return NULL;
-  }
-  uint8_t *message = (uint8_t *) malloc(size);
-  error = read_bytes(sock, message, size);
-  if (error < 0) {
-    /* The other side has closed the socket. */
-    RAY_LOG(DEBUG) << "Socket has been closed, or some other error has "
-                   << "occurred.";
-    if (loop != NULL) {
-      event_loop_remove_file(loop, sock);
-    }
-    close(sock);
-    return NULL;
-  }
-  return message;
-}
-
-int64_t read_vector(int fd, int64_t *type, std::vector<uint8_t> &buffer) {
-  int64_t version;
-  int closed = read_bytes(fd, (uint8_t *) &version, sizeof(version));
-  if (closed) {
-    goto disconnected;
-  }
-  RAY_CHECK(version == RayConfig::instance().ray_protocol_version());
-  int64_t length;
-  closed = read_bytes(fd, (uint8_t *) type, sizeof(*type));
-  if (closed) {
-    goto disconnected;
-  }
-  closed = read_bytes(fd, (uint8_t *) &length, sizeof(length));
-  if (closed) {
-    goto disconnected;
-  }
-  if (static_cast<size_t>(length) > buffer.size()) {
-    buffer.resize(length);
-  }
-  closed = read_bytes(fd, buffer.data(), length);
-  if (closed) {
-    goto disconnected;
-  }
-  return length;
-disconnected:
-  /* Handle the case in which the socket is closed. */
-  *type = static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT);
-  return 0;
-}
-
-void write_log_message(int fd, const char *message) {
-  /* Account for the \0 at the end of the string. */
-  do_write_message(fd, static_cast<int64_t>(CommonMessageType::LOG_MESSAGE),
-                   strlen(message) + 1, (uint8_t *) message);
-}
-
-char *read_log_message(int fd) {
-  uint8_t *bytes;
-  int64_t type;
-  int64_t length;
-  read_message(fd, &type, &length, &bytes);
-  RAY_CHECK(static_cast<CommonMessageType>(type) ==
-            CommonMessageType::LOG_MESSAGE);
-  return (char *) bytes;
-}
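The log-message helpers above pair up symmetrically: the writer frames a
null-terminated string, and the reader owns the returned buffer. A short sketch
(illustrative only; both descriptors are assumed to be connected sockets, and
free comes from <cstdlib>):

void log_message_example(int write_fd, int read_fd) {
  /* write_log_message frames the string as CommonMessageType::LOG_MESSAGE;
   * read_log_message checks that type and hands back a malloc'd copy. */
  write_log_message(write_fd, "{\"event\":\"started\"}");
  char *message = read_log_message(read_fd);
  /* ... consume the message ... */
  free(message); /* The caller owns the buffer. */
}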
diff --git a/src/common/io.h b/src/common/io.h
deleted file mode 100644
index 3f976445aeb0..000000000000
--- a/src/common/io.h
+++ /dev/null
@@ -1,228 +0,0 @@
-#ifndef IO_H
-#define IO_H
-
-#include <stdbool.h>
-#include <stdint.h>
-
-#include <mutex>
-#include <vector>
-
-struct aeEventLoop;
-typedef aeEventLoop event_loop;
-
-enum class CommonMessageType : int32_t {
-  /** Disconnect a client. */
-  DISCONNECT_CLIENT,
-  /** Log a message from a client. */
-  LOG_MESSAGE,
-  /** Submit a task to the local scheduler. */
-  SUBMIT_TASK,
-};
-
-/* Helper functions for socket communication. */
-
-/**
- * Binds to an Internet socket at the given port. Returns a non-blocking file
- * descriptor for the socket, or -1 if an error occurred.
- *
- * @note Since the returned file descriptor is non-blocking, it is not
- * recommended to use the Linux read and write calls directly, since these
- * might read or write a partial message. Instead, use the provided
- * write_message and read_message methods.
- *
- * @param port The port to bind to.
- * @param shall_listen Are we also starting to listen on the socket?
- * @return A non-blocking file descriptor for the socket, or -1 if an error
- *         occurs.
- */
-int bind_inet_sock(const int port, bool shall_listen);
-
-/**
- * Binds to a Unix domain streaming socket at the given
- * pathname. Removes any existing file at the pathname.
- *
- * @param socket_pathname The pathname for the socket.
- * @param shall_listen Are we also starting to listen on the socket?
- * @return A blocking file descriptor for the socket, or -1 if an error
- *         occurs.
- */
-int bind_ipc_sock(const char *socket_pathname, bool shall_listen);
-
-/**
- * Connect to a Unix domain streaming socket at the given
- * pathname.
- *
- * @param socket_pathname The pathname for the socket.
- * @return A file descriptor for the socket, or -1 if an error occurred.
- */
-int connect_ipc_sock(const char *socket_pathname);
-
-/**
- * Connect to a Unix domain streaming socket at the given
- * pathname, or fail after some number of retries.
- *
- * @param socket_pathname The pathname for the socket.
- * @param num_retries The number of times to retry the connection
- *        before exiting. If -1 is provided, then this defaults to
- *        num_connect_attempts.
- * @param timeout The number of milliseconds to wait in between
- *        retries. If -1 is provided, then this defaults to
- *        connect_timeout_milliseconds.
- * @return A file descriptor for the socket, or -1 if an error occurred.
- */
-int connect_ipc_sock_retry(const char *socket_pathname,
-                           int num_retries,
-                           int64_t timeout);
-
-/**
- * Connect to an Internet socket at the given address and port.
- *
- * @param ip_addr The IP address to connect to.
- * @param port The port number to connect to.
- * @return A file descriptor for the socket, or -1 if an error occurred.
- */
-int connect_inet_sock(const char *ip_addr, int port);
-
-/**
- * Connect to an Internet socket at the given address and port, or fail after
- * some number of retries.
- *
- * @param ip_addr The IP address to connect to.
- * @param port The port number to connect to.
- * @param num_retries The number of times to retry the connection
- *        before exiting. If -1 is provided, then this defaults to
- *        num_connect_attempts.
- * @param timeout The number of milliseconds to wait in between
- *        retries. If -1 is provided, then this defaults to
- *        connect_timeout_milliseconds.
- * @return A file descriptor for the socket, or -1 if an error occurred.
- */
-int connect_inet_sock_retry(const char *ip_addr,
-                            int port,
-                            int num_retries,
-                            int64_t timeout);
-
-/**
- * Accept a new client connection on the given socket
- * descriptor. Returns a descriptor for the new socket.
- */
-int accept_client(int socket_fd);
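A usage sketch for the connect helpers declared above (illustrative only, not
code from this patch; the socket path is made up, and -1/-1 selects the
RayConfig defaults documented in the parameters):

#include <unistd.h>

void connect_example() {
  int fd = connect_ipc_sock_retry("/tmp/ray_test_socket", -1, -1);
  /* connect_ipc_sock_retry either returns a connected descriptor or calls
   * RAY_LOG(FATAL) after exhausting its retries, so fd is valid here. */
  close(fd);
}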
-
-/* Reading and writing data. */
-
-/**
- * Write a sequence of bytes on a file descriptor. The bytes should then be
- * read by read_message. (The protocol version is prepended automatically.)
- *
- * @param fd The file descriptor to write to. It can be non-blocking.
- * @param type The type of the message to send.
- * @param length The size in bytes of the bytes parameter.
- * @param bytes The address of the message to send.
- * @param mutex If not NULL, the whole write operation will be locked
- *        with this mutex, otherwise do nothing.
- * @return int Whether there was an error while writing. 0 corresponds to
- *         success and -1 corresponds to an error (errno will be set).
- */
-int write_message(int fd,
-                  int64_t type,
-                  int64_t length,
-                  uint8_t *bytes,
-                  std::mutex *mutex = NULL);
-
-/**
- * Read a sequence of bytes written by write_message from a file descriptor.
- * This allocates space for the message.
- *
- * @note The caller must free the memory.
- *
- * @param fd The file descriptor to read from. It can be non-blocking.
- * @param type The type of the message that is read will be written at this
- *        address. If there was an error while reading, this will be
- *        DISCONNECT_CLIENT.
- * @param length The size in bytes of the message that is read will be written
- *        at this address. This size does not include the bytes used to encode
- *        the type and length. If there was an error while reading, this will
- *        be 0.
- * @param bytes The address at which to write the pointer to the bytes that are
- *        read and allocated by this function. If there was an error while
- *        reading, this will be NULL.
- * @return Void.
- */
-void read_message(int fd, int64_t *type, int64_t *length, uint8_t **bytes);
-
-/**
- * Read a message from a file descriptor and remove the file descriptor from
- * the event loop if there is an error. This will actually do two reads. The
- * first read reads sizeof(int64_t) bytes to determine the number of bytes to
- * read in the next read.
- *
- * @param loop: The event loop.
- * @param sock: The file descriptor to read from.
- * @return A byte buffer containing the message or NULL if there was an
- *         error. The buffer needs to be freed by the user.
- */
-uint8_t *read_message_async(event_loop *loop, int sock);
-
-/**
- * Read a sequence of bytes written by write_message from a file descriptor.
- * This does not allocate space for the message if the provided buffer is
- * large enough and can therefore often avoid allocations.
- *
- * @param fd The file descriptor to read from. It can be non-blocking.
- * @param type The type of the message that is read will be written at this
- *        address. If there was an error while reading, this will be
- *        DISCONNECT_CLIENT.
- * @param buffer The array the message will be written to. If it is not
- *        large enough to hold the message, it will be enlarged by read_vector.
- * @return Number of bytes of the message that were read. This size does not
- *         include the bytes used to encode the type and length. If there was
- *         an error while reading, this will be 0.
- */
-int64_t read_vector(int fd, int64_t *type, std::vector<uint8_t> &buffer);
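A sketch of the allocation-avoiding receive loop that read_vector enables
(illustrative only, not code from this patch; assumes a connected descriptor
and the declarations above):

void receive_loop_example(int fd) {
  std::vector<uint8_t> buffer;
  int64_t type = 0;
  while (true) {
    int64_t length = read_vector(fd, &type, buffer);
    if (static_cast<CommonMessageType>(type) ==
        CommonMessageType::DISCONNECT_CLIENT) {
      break; /* The peer disconnected; length is 0. */
    }
    /* Process buffer[0] .. buffer[length - 1]. The vector is only ever
     * enlarged, so steady-state iterations do not allocate. */
  }
}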
-
-/**
- * Write a null-terminated string to a file descriptor.
- */
-void write_log_message(int fd, const char *message);
-
-/**
- * Reads a null-terminated string from the file descriptor that has been
- * written by write_log_message. Allocates and returns a pointer to the string.
- * NOTE: Caller must free the memory!
- */
-char *read_log_message(int fd);
-
-/**
- * Read a sequence of bytes from a file descriptor into a buffer. This will
- * block until one of the following happens: (1) there is an error (2) end of
- * file, or (3) all length bytes have been read.
- *
- * @note The buffer pointed to by cursor must already have length number of
- * bytes allocated before calling this method.
- *
- * @param fd The file descriptor to read from. It can be non-blocking.
- * @param cursor The cursor pointing to the beginning of the buffer.
- * @param length The size of the byte sequence to read.
- * @return int Whether there was an error while reading. 0 corresponds to
- *         success and -1 corresponds to an error (errno will be set).
- */
-int read_bytes(int fd, uint8_t *cursor, size_t length);
-
-/**
- * Write a sequence of bytes into a file descriptor. This will block until one
- * of the following happens: (1) there is an error (2) end of file, or (3) all
- * length bytes have been written.
- *
- * @param fd The file descriptor to write to. It can be non-blocking.
- * @param cursor The cursor pointing to the beginning of the bytes to send.
- * @param length The size of the byte sequence to write.
- * @return int Whether there was an error while writing. 0 corresponds to
- *         success and -1 corresponds to an error (errno will be set).
- */
-int write_bytes(int fd, uint8_t *cursor, size_t length);
-
-#endif /* IO_H */
diff --git a/src/common/logging.cc b/src/common/logging.cc
deleted file mode 100644
index 9802dd3d03f3..000000000000
--- a/src/common/logging.cc
+++ /dev/null
@@ -1,107 +0,0 @@
-#include "logging.h"
-
-#include 
-#include 
-#include 
-
-#include 
-
-#include "state/redis.h"
-#include "io.h"
-#include 
-#include 
-
-static const char *log_levels[5] = {"DEBUG", "INFO", "WARN", "ERROR", "FATAL"};
-static const char *log_fmt =
-    "HMSET log:%s:%s log_level %s event_type %s message %s timestamp %s";
-
-struct RayLoggerImpl {
-  /* String that identifies this client type. */
-  const char *client_type;
-  /* Suppress all log messages below this level. */
-  int log_level;
-  /* Whether or not we have a direct connection to Redis. */
-  int is_direct;
-  /* Either a db_handle or a socket to a process with a db_handle,
-   * depending on the is_direct flag. */
-  void *conn;
-};
-
-RayLogger *RayLogger_init(const char *client_type,
-                          int log_level,
-                          int is_direct,
-                          void *conn) {
-  RayLogger *logger = (RayLogger *) malloc(sizeof(RayLogger));
-  logger->client_type = client_type;
-  logger->log_level = log_level;
-  logger->is_direct = is_direct;
-  logger->conn = conn;
-  return logger;
-}
-
-void RayLogger_free(RayLogger *logger) {
-  free(logger);
-}
-
-void RayLogger_log(RayLogger *logger,
-                   int log_level,
-                   const char *event_type,
-                   const char *message) {
-  if (log_level < logger->log_level) {
-    return;
-  }
-  if (log_level < RAY_LOG_DEBUG || log_level > RAY_LOG_FATAL) {
-    return;
-  }
-  struct timeval tv;
-  gettimeofday(&tv, NULL);
-  std::string timestamp =
-      std::to_string(tv.tv_sec) + "." + std::to_string(tv.tv_usec);
-
-  /* Find number of bytes that would have been written for formatted_message
-   * size */
-  size_t formatted_message_size =
-      std::snprintf(nullptr, 0, log_fmt, timestamp.c_str(), "%b",
-                    log_levels[log_level], event_type, message,
-                    timestamp.c_str()) +
-      1;
-  /* Fill out everything except the client ID, which is binary data. */
-  char formatted_message[formatted_message_size];
-  std::snprintf(formatted_message, formatted_message_size, log_fmt,
-                timestamp.c_str(), "%b", log_levels[log_level], event_type,
-                message, timestamp.c_str());
-
-  if (logger->is_direct) {
-    DBHandle *db = (DBHandle *) logger->conn;
-    /* Fill in the client ID and send the message to Redis. 
*/ - - redisAsyncContext *context = get_redis_context(db, db->client); - - int status = - redisAsyncCommand(context, NULL, NULL, formatted_message, - (char *) db->client.data(), sizeof(db->client)); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error while logging message to log table"); - } - } else { - /* If we don't own a Redis connection, we leave our client - * ID to be filled in by someone else. */ - int *socket_fd = (int *) logger->conn; - write_log_message(*socket_fd, formatted_message); - } -} - -void RayLogger_log_event(DBHandle *db, - uint8_t *key, - int64_t key_length, - uint8_t *value, - int64_t value_length, - double timestamp) { - std::string timestamp_string = std::to_string(timestamp); - int status = redisAsyncCommand(db->context, NULL, NULL, "ZADD %b %s %b", key, - key_length, timestamp_string.c_str(), value, - value_length); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error while logging message to event log"); - } -} diff --git a/src/common/logging.h b/src/common/logging.h deleted file mode 100644 index 1fa57a60c712..000000000000 --- a/src/common/logging.h +++ /dev/null @@ -1,58 +0,0 @@ -#ifndef LOGGING_H -#define LOGGING_H - -#define RAY_LOG_VERBOSE -1 -#define RAY_LOG_DEBUG 0 -#define RAY_LOG_INFO 1 -#define RAY_LOG_WARNING 2 -#define RAY_LOG_ERROR 3 -#define RAY_LOG_FATAL 4 - -/* Entity types. */ -#define RAY_FUNCTION "FUNCTION" -#define RAY_OBJECT "OBJECT" -#define RAY_TASK "TASK" - -#include "state/db.h" - -typedef struct RayLoggerImpl RayLogger; - -/* Initialize a Ray logger for the given client type and logging level. If the - * is_direct flag is set, the logger will treat the given connection as a - * direct connection to the log. Otherwise, it will treat it as a socket to - * another process with a connection to the log. - * NOTE: User is responsible for freeing the returned logger. */ -RayLogger *RayLogger_init(const char *client_type, - int log_level, - int is_direct, - void *conn); - -/* Free the logger. This does not free the connection to the log. */ -void RayLogger_free(RayLogger *logger); - -/* Log an event at the given log level with the given event_type. - * NOTE: message cannot contain spaces! JSON format is recommended. - * TODO: Support spaces in messages. */ -void RayLogger_log(RayLogger *logger, - int log_level, - const char *event_type, - const char *message); - -/** - * Log an event to the event log. - * - * @param db The database handle. - * @param key The key in Redis to store the event in. - * @param key_length The length of the key. - * @param value The value to log. - * @param value_length The length of the value. - * @return Void. - */ -void RayLogger_log_event(DBHandle *db, - uint8_t *key, - int64_t key_length, - uint8_t *value, - int64_t value_length, - double time); - -#endif /* LOGGING_H */ diff --git a/src/common/net.cc b/src/common/net.cc deleted file mode 100644 index 3f2aaf6fa94e..000000000000 --- a/src/common/net.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "net.h" - -#include - -#include - -#include "common.h" - -int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port) { - char port_str[6]; - int parsed = sscanf(ip_addr_port, "%15[0-9.]:%5[0-9]", ip_addr, port_str); - if (parsed != 2) { - return -1; - } - *port = atoi(port_str); - return 0; -} - -/* Return true if the ip address is valid. 
*/
-bool valid_ip_address(const std::string &ip_address) {
-  struct sockaddr_in sa;
-  int result = inet_pton(AF_INET, ip_address.c_str(), &sa.sin_addr);
-  return result == 1;
-}
diff --git a/src/common/net.h b/src/common/net.h
deleted file mode 100644
index 109cdf3fa1f3..000000000000
--- a/src/common/net.h
+++ /dev/null
@@ -1,9 +0,0 @@
-#ifndef NET_H
-#define NET_H
-
-/* Helper function to parse a string of the form <IP address>:<port> into the
- * given ip_addr and port pointers. The ip_addr buffer must already be
- * allocated. Return 0 upon success and -1 upon failure. */
-int parse_ip_addr_port(const char *ip_addr_port, char *ip_addr, int *port);
-
-#endif /* NET_H */
diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc
deleted file mode 100644
index d594f74effef..000000000000
--- a/src/common/redis_module/ray_redis_module.cc
+++ /dev/null
@@ -1,1886 +0,0 @@
-#include 
-
-#include "common_protocol.h"
-#include "format/common_generated.h"
-#include "ray/gcs/format/gcs_generated.h"
-#include "ray/id.h"
-#include "redis_string.h"
-#include "redismodule.h"
-#include "task.h"
-
-#if RAY_USE_NEW_GCS
-// Under this flag, ray-project/credis will be loaded. Specifically, via
-// "path/redis-server --loadmodule <credis module> --loadmodule <this module>"
-// (dlopen() under the hood) will a definition of "module" be supplied.
-//
-// All commands in this file that depend on "module" must be wrapped by "#if
-// RAY_USE_NEW_GCS", until we switch to this launch configuration as the
-// default.
-#include "chain_module.h"
-extern RedisChainModule module;
-#endif
-
-// Various tables are maintained in redis:
-//
-// == OBJECT TABLE ==
-//
-// This consists of two parts:
-// - The object location table, indexed by OL:object_id, which is the set of
-//   plasma manager indices that have access to the object.
-//   (In redis this is represented by a zset (sorted set).)
-//
-// - The object info table, indexed by OI:object_id, which is a hashmap of:
-//     "hash" -> the hash of the object,
-//     "data_size" -> the size of the object in bytes,
-//     "task" -> the task ID that generated this object.
-//     "is_put" -> 0 or 1.
-//
-// == TASK TABLE ==
-//
-// It maps each TT:task_id to a hash:
-//   "state" -> the state of the task, encoded as a bit mask of scheduling_state
-//              enum values in task.h,
-//   "local_scheduler_id" -> the ID of the local scheduler the task is assigned
-//                           to,
-//   "TaskSpec" -> serialized bytes of a TaskInfo (defined in common.fbs), which
-//                 describes the details of this task.
-//
-// See also the definition of TaskReply in common.fbs.
-
-#define OBJECT_INFO_PREFIX "OI:"
-#define OBJECT_LOCATION_PREFIX "OL:"
-#define OBJECT_NOTIFICATION_PREFIX "ON:"
-#define TASK_PREFIX "TT:"
-#define OBJECT_BCAST "BCAST"
-
-#define OBJECT_CHANNEL_PREFIX "OC:"
-
-#define CHECK_ERROR(STATUS, MESSAGE)                   \
-  if ((STATUS) == REDISMODULE_ERR) {                   \
-    return RedisModule_ReplyWithError(ctx, (MESSAGE)); \
-  }
-
-/// Parse a Redis string into a TablePubsub channel.
-TablePubsub ParseTablePubsub(const RedisModuleString *pubsub_channel_str) {
-  long long pubsub_channel_long;
-  RAY_CHECK(RedisModule_StringToLongLong(
-                pubsub_channel_str, &pubsub_channel_long) == REDISMODULE_OK)
-      << "Pubsub channel must be a valid TablePubsub";
-  auto pubsub_channel = static_cast<TablePubsub>(pubsub_channel_long);
-  RAY_CHECK(pubsub_channel >= TablePubsub::MIN &&
-            pubsub_channel <= TablePubsub::MAX)
-      << "Pubsub channel must be a valid TablePubsub";
-  return pubsub_channel;
-}
-
-/// Format a pubsub channel for a specific key. pubsub_channel_str should
-/// contain a valid TablePubsub.
-RedisModuleString *FormatPubsubChannel(
-    RedisModuleCtx *ctx,
-    const RedisModuleString *pubsub_channel_str,
-    const RedisModuleString *id) {
-  // Format the pubsub channel enum to a string. TablePubsub_MAX should be more
-  // than enough digits, but add 1 just in case for the null terminator.
-  char pubsub_channel[static_cast<int>(TablePubsub::MAX) + 1];
-  sprintf(pubsub_channel, "%d",
-          static_cast<int>(ParseTablePubsub(pubsub_channel_str)));
-  return RedisString_Format(ctx, "%s:%S", pubsub_channel, id);
-}
-
-// TODO(swang): This helper function should be deprecated by the version below,
-// which uses enums for table prefixes.
-RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
-                                const char *prefix,
-                                RedisModuleString *keyname,
-                                int mode,
-                                RedisModuleString **mutated_key_str) {
-  RedisModuleString *prefixed_keyname =
-      RedisString_Format(ctx, "%s%S", prefix, keyname);
-  // Pass out the key being mutated, should the caller request so.
-  if (mutated_key_str != nullptr) {
-    *mutated_key_str = prefixed_keyname;
-  }
-  RedisModuleKey *key =
-      (RedisModuleKey *) RedisModule_OpenKey(ctx, prefixed_keyname, mode);
-  return key;
-}
-
-RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
-                                RedisModuleString *prefix_enum,
-                                RedisModuleString *keyname,
-                                int mode,
-                                RedisModuleString **mutated_key_str) {
-  long long prefix_long;
-  RAY_CHECK(RedisModule_StringToLongLong(prefix_enum, &prefix_long) ==
-            REDISMODULE_OK)
-      << "Prefix must be a valid TablePrefix";
-  auto prefix = static_cast<TablePrefix>(prefix_long);
-  RAY_CHECK(prefix != TablePrefix::UNUSED)
-      << "This table has no prefix registered";
-  RAY_CHECK(prefix >= TablePrefix::MIN && prefix <= TablePrefix::MAX)
-      << "Prefix must be a valid TablePrefix";
-  return OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode,
-                         mutated_key_str);
-}
-
-RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
-                                const char *prefix,
-                                RedisModuleString *keyname,
-                                int mode) {
-  return OpenPrefixedKey(ctx, prefix, keyname, mode,
-                         /*mutated_key_str=*/nullptr);
-}
-
-RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
-                                RedisModuleString *prefix_enum,
-                                RedisModuleString *keyname,
-                                int mode) {
-  return OpenPrefixedKey(ctx, prefix_enum, keyname, mode,
-                         /*mutated_key_str=*/nullptr);
-}
-
-/// Open the key used to store the channels that should be published to when an
-/// update happens at the given keyname.
-RedisModuleKey *OpenBroadcastKey(RedisModuleCtx *ctx,
-                                 RedisModuleString *pubsub_channel_str,
-                                 RedisModuleString *keyname,
-                                 int mode) {
-  RedisModuleString *channel =
-      FormatPubsubChannel(ctx, pubsub_channel_str, keyname);
-  RedisModuleString *prefixed_keyname =
-      RedisString_Format(ctx, "BCAST:%S", channel);
-  RedisModuleKey *key =
-      (RedisModuleKey *) RedisModule_OpenKey(ctx, prefixed_keyname, mode);
-  return key;
-}
-
-/**
- * This is a helper method to convert a redis module string to a flatbuffer
- * string.
- *
- * @param fbb The flatbuffer builder.
- * @param redis_string The redis string.
- * @return The flatbuffer string.
- */
-flatbuffers::Offset<flatbuffers::String> RedisStringToFlatbuf(
-    flatbuffers::FlatBufferBuilder &fbb,
-    RedisModuleString *redis_string) {
-  size_t redis_string_size;
-  const char *redis_string_str =
-      RedisModule_StringPtrLen(redis_string, &redis_string_size);
-  return fbb.CreateString(redis_string_str, redis_string_size);
-}
-
-/**
- * Publish a notification to a client's notification channel about an insertion
- * or deletion to the db client table.
- *
- * TODO(swang): Use flatbuffers for the notification message.
- * The format for the published notification is:
- *     <ray_client_id>:<client_type> <manager_address> <is_insertion>
- * If no manager address is provided, manager_address will be set to ":". If
- * is_insertion is true, then the last field will be "1", else "0".
- *
- * @param ctx The Redis context.
- * @param ray_client_id The ID of the database client that was inserted or
- *        deleted.
- * @param client_type The type of client that was inserted or deleted.
- * @param manager_address An optional secondary address for the object manager
- *        associated with this database client.
- * @param is_insertion A boolean that's true if the update was an insertion and
- *        false if deletion.
- * @return True if the publish was successful and false otherwise.
- */
-bool PublishDBClientNotification(RedisModuleCtx *ctx,
-                                 RedisModuleString *ray_client_id,
-                                 RedisModuleString *client_type,
-                                 RedisModuleString *manager_address,
-                                 bool is_insertion) {
-  /* Construct strings to publish on the db client channel. */
-  RedisModuleString *channel_name =
-      RedisModule_CreateString(ctx, "db_clients", strlen("db_clients"));
-  /* Construct the flatbuffers object to publish over the channel. */
-  flatbuffers::FlatBufferBuilder fbb;
-  /* Use an empty aux address if one is not passed in. */
-  flatbuffers::Offset<flatbuffers::String> manager_address_str;
-  if (manager_address != NULL) {
-    manager_address_str = RedisStringToFlatbuf(fbb, manager_address);
-  } else {
-    manager_address_str = fbb.CreateString("", strlen(""));
-  }
-  /* Create the flatbuffers message. */
-  auto message = CreateSubscribeToDBClientTableReply(
-      fbb, RedisStringToFlatbuf(fbb, ray_client_id),
-      RedisStringToFlatbuf(fbb, client_type), manager_address_str,
-      is_insertion);
-  fbb.Finish(message);
-  /* Create a Redis string to publish by serializing the flatbuffers object. */
-  RedisModuleString *client_info = RedisModule_CreateString(
-      ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
-
-  /* Publish the client info on the db client channel. */
-  RedisModuleCallReply *reply;
-  reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, client_info);
-  return (reply != NULL);
-}
-
-/**
- * Register a client with Redis. This is called from a client with the command:
- *
- *     RAY.CONNECT <ray_client_id> <node_ip_address> <client_type>
- *         <field 1> <value 1> ... <field N> <value N>
- *
- * The command can take an arbitrary number of pairs of field names and values,
- * and these will be stored in a hashmap associated with this client. Several
- * fields are singled out for special treatment:
- *
- *     manager_address: This is provided by local schedulers and plasma
- *         managers and should be the address of the plasma manager that the
- *         client is associated with. This is published to the "db_clients"
- *         channel by the RAY.CONNECT command.
- *
- * @param ray_client_id The db client ID of the client.
- * @param node_ip_address The IP address of the node the client is on.
- * @param client_type The type of the client (e.g., plasma_manager).
- * @return OK if the operation was successful.
- */
-int Connect_RedisCommand(RedisModuleCtx *ctx,
-                         RedisModuleString **argv,
-                         int argc) {
-  RedisModule_AutoMemory(ctx);
-
-  if (argc < 4) {
-    return RedisModule_WrongArity(ctx);
-  }
-  if (argc % 2 != 0) {
-    return RedisModule_WrongArity(ctx);
-  }
-
-  RedisModuleString *ray_client_id = argv[1];
-  RedisModuleString *node_ip_address = argv[2];
-  RedisModuleString *client_type = argv[3];
-
-  /* Add this client to the Ray db client table. */
-  RedisModuleKey *db_client_table_key =
-      OpenPrefixedKey(ctx, DB_CLIENT_PREFIX, ray_client_id, REDISMODULE_WRITE);
-
-  if (RedisModule_KeyType(db_client_table_key) != REDISMODULE_KEYTYPE_EMPTY) {
-    return RedisModule_ReplyWithError(ctx, "Client already exists");
-  }
-
-  /* This will be used to construct a publish message. */
-  RedisModuleString *manager_address = NULL;
-  RedisModuleString *manager_address_key = RedisModule_CreateString(
-      ctx, "manager_address", strlen("manager_address"));
-  RedisModuleString *deleted = RedisModule_CreateString(ctx, "0", strlen("0"));
-
-  RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_CFIELDS,
-                      "ray_client_id", ray_client_id, "node_ip_address",
-                      node_ip_address, "client_type", client_type, "deleted",
-                      deleted, NULL);
-
-  for (int i = 4; i < argc; i += 2) {
-    RedisModuleString *key = argv[i];
-    RedisModuleString *value = argv[i + 1];
-    RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_NONE, key, value,
-                        NULL);
-    if (RedisModule_StringCompare(key, manager_address_key) == 0) {
-      manager_address = value;
-    }
-  }
-  /* Clean up. */
-  if (!PublishDBClientNotification(ctx, ray_client_id, client_type,
-                                   manager_address, true)) {
-    return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful");
-  }
-
-  RedisModule_ReplyWithSimpleString(ctx, "OK");
-  return REDISMODULE_OK;
-}
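For orientation, here is how a client might issue RAY.CONNECT through hiredis
(an illustrative sketch, not code from this patch; the client ID and addresses
are placeholders):

#include <hiredis/hiredis.h>

void connect_command_example(redisContext *c) {
  const char client_id[20] = {0}; /* Placeholder binary db client ID. */
  redisReply *reply = (redisReply *) redisCommand(
      c, "RAY.CONNECT %b %s %s manager_address %s", client_id,
      (size_t) sizeof(client_id), "127.0.0.1", "plasma_manager",
      "127.0.0.1:23894");
  /* The module replies +OK and publishes on the "db_clients" channel. */
  freeReplyObject(reply);
}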
-
-/**
- * Remove a client from Redis. This is called from a client with the command:
- *
- *     RAY.DISCONNECT <ray_client_id>
- *
- * This method also publishes a notification to all subscribers to the
- * db_clients channel. The notification consists of a message of the form
- * "<ray_client_id>:<client_type>".
- *
- * @param ray_client_id The db client ID of the client.
- * @return OK if the operation was successful.
- */
-int Disconnect_RedisCommand(RedisModuleCtx *ctx,
-                            RedisModuleString **argv,
-                            int argc) {
-  RedisModule_AutoMemory(ctx);
-
-  if (argc != 2) {
-    return RedisModule_WrongArity(ctx);
-  }
-
-  RedisModuleString *ray_client_id = argv[1];
-
-  /* Get the client type. */
-  RedisModuleKey *db_client_table_key =
-      OpenPrefixedKey(ctx, DB_CLIENT_PREFIX, ray_client_id, REDISMODULE_WRITE);
-
-  RedisModuleString *deleted_string;
-  RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS, "deleted",
-                      &deleted_string, NULL);
-  long long deleted;
-  int parsed = RedisModule_StringToLongLong(deleted_string, &deleted);
-  if (parsed != REDISMODULE_OK) {
-    return RedisModule_ReplyWithError(ctx, "Unable to parse deleted field");
-  }
-
-  bool published = true;
-  if (deleted == 0) {
-    /* Remove the client from the client table. */
-    RedisModuleString *deleted =
-        RedisModule_CreateString(ctx, "1", strlen("1"));
-    RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_CFIELDS,
-                        "deleted", deleted, NULL);
-
-    RedisModuleString *client_type;
-    RedisModuleString *manager_address;
-    RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS,
-                        "client_type", &client_type, "manager_address",
-                        &manager_address, NULL);
-
-    /* Publish the deletion notification on the db client channel. */
-    published = PublishDBClientNotification(ctx, ray_client_id, client_type,
-                                            manager_address, false);
-  }
-
-  if (!published) {
-    /* Return an error message if we weren't able to publish the deletion
-     * notification. */
-    return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful");
-  }
-
-  RedisModule_ReplyWithSimpleString(ctx, "OK");
-  return REDISMODULE_OK;
-}
-
-/**
- * Lookup an entry in the object table.
- *
- * This is called from a client with the command:
- *
- *     RAY.OBJECT_TABLE_LOOKUP <object id>
- *
- * @param object_id A string representing the object ID.
- * @return A list, possibly empty, of plasma manager IDs that are listed in the
- *         object table as having the object. If there was no entry found in
- *         the object table, returns nil.
- */
-int ObjectTableLookup_RedisCommand(RedisModuleCtx *ctx,
-                                   RedisModuleString **argv,
-                                   int argc) {
-  RedisModule_AutoMemory(ctx);
-
-  if (argc != 2) {
-    return RedisModule_WrongArity(ctx);
-  }
-
-  RedisModuleKey *key =
-      OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, argv[1], REDISMODULE_READ);
-
-  if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) {
-    /* Return nil if no entry was found. */
-    return RedisModule_ReplyWithNull(ctx);
-  }
-  if (RedisModule_ValueLength(key) == 0) {
-    /* Return empty list if there are no managers. */
-    return RedisModule_ReplyWithArray(ctx, 0);
-  }
-
-  CHECK_ERROR(
-      RedisModule_ZsetFirstInScoreRange(key, REDISMODULE_NEGATIVE_INFINITE,
-                                        REDISMODULE_POSITIVE_INFINITE, 1, 1),
-      "Unable to initialize zset iterator");
-
-  RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_ARRAY_LEN);
-  int num_results = 0;
-  do {
-    RedisModuleString *curr = RedisModule_ZsetRangeCurrentElement(key, NULL);
-    RedisModule_ReplyWithString(ctx, curr);
-    num_results += 1;
-  } while (RedisModule_ZsetRangeNext(key));
-  RedisModule_ReplySetArrayLength(ctx, num_results);
-
-  return REDISMODULE_OK;
-}
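Reading the reply on the client side (an illustrative hiredis sketch under the
same assumptions as the earlier examples; not code from this patch):

void object_lookup_example(redisContext *c, const char *object_id,
                           size_t object_id_len) {
  redisReply *reply = (redisReply *) redisCommand(
      c, "RAY.OBJECT_TABLE_LOOKUP %b", object_id, object_id_len);
  if (reply->type == REDIS_REPLY_ARRAY) {
    for (size_t i = 0; i < reply->elements; ++i) {
      /* reply->element[i]->str is one plasma manager ID. */
    }
  } /* REDIS_REPLY_NIL means there was no entry for this object ID. */
  freeReplyObject(reply);
}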
-
-/**
- * Publish a notification to a client's object notification channel if at least
- * one manager is listed as having the object in the object table.
- *
- * @param ctx The Redis context.
- * @param client_id The ID of the client that is being notified.
- * @param object_id The object ID of interest.
- * @param key The opened key for the entry in the object table corresponding to
- *        the object ID of interest.
- * @return True if the publish was successful and false otherwise.
- */
-bool PublishObjectNotification(RedisModuleCtx *ctx,
-                               RedisModuleString *client_id,
-                               RedisModuleString *object_id,
-                               RedisModuleString *data_size,
-                               RedisModuleKey *key) {
-  flatbuffers::FlatBufferBuilder fbb;
-
-  long long data_size_value;
-  if (RedisModule_StringToLongLong(data_size, &data_size_value) !=
-      REDISMODULE_OK) {
-    return RedisModule_ReplyWithError(ctx, "data_size must be integer");
-  }
-
-  std::vector<flatbuffers::Offset<flatbuffers::String>> manager_ids;
-  CHECK_ERROR(
-      RedisModule_ZsetFirstInScoreRange(key, REDISMODULE_NEGATIVE_INFINITE,
-                                        REDISMODULE_POSITIVE_INFINITE, 1, 1),
-      "Unable to initialize zset iterator");
-  /* Loop over the managers in the object table for this object ID. */
-  do {
-    RedisModuleString *curr = RedisModule_ZsetRangeCurrentElement(key, NULL);
-    manager_ids.push_back(RedisStringToFlatbuf(fbb, curr));
-  } while (RedisModule_ZsetRangeNext(key));
-
-  auto message = CreateSubscribeToNotificationsReply(
-      fbb, RedisStringToFlatbuf(fbb, object_id), data_size_value,
-      fbb.CreateVector(manager_ids));
-  fbb.Finish(message);
-
-  /* Publish the notification to the client's notification channel.
-   * TODO(rkn): These notifications could be batched together. */
-  RedisModuleString *channel_name =
-      RedisString_Format(ctx, "%s%S", OBJECT_CHANNEL_PREFIX, client_id);
-
-  RedisModuleString *payload = RedisModule_CreateString(
-      ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
-
-  RedisModuleCallReply *reply;
-  reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, payload);
-  if (reply == NULL) {
-    return false;
-  }
-  return true;
-}
-
-// NOTE(pcmoritz): This is a temporary redis command that will be removed once
-// the GCS uses https://github.com/pcmoritz/credis.
-int PublishTaskTableAdd(RedisModuleCtx *ctx,
-                        RedisModuleString *id,
-                        RedisModuleString *data) {
-  const char *buf = RedisModule_StringPtrLen(data, NULL);
-  auto message = flatbuffers::GetRoot<TaskReply>(buf);
-  RAY_CHECK(message != nullptr);
-
-  if (message->scheduling_state() == SchedulingState::WAITING ||
-      message->scheduling_state() == SchedulingState::SCHEDULED) {
-    /* Build the PUBLISH topic and message for task table subscribers. The
-     * topic is a string in the format
-     * "TASK_PREFIX:<local scheduler ID>:<state>". The message is a serialized
-     * SubscribeToTasksReply flatbuffer object. */
-    std::string state =
-        std::to_string(static_cast<int64_t>(message->scheduling_state()));
-    RedisModuleString *publish_topic = RedisString_Format(
-        ctx, "%s%b:%s", TASK_PREFIX, message->scheduler_id()->str().data(),
-        sizeof(DBClientID), state.c_str());
-
-    /* Construct the flatbuffers object for the payload. */
-    flatbuffers::FlatBufferBuilder fbb;
-    /* Create the flatbuffers message. */
-    auto msg =
-        CreateTaskReply(fbb, RedisStringToFlatbuf(fbb, id),
-                        static_cast<int64_t>(message->scheduling_state()),
-                        fbb.CreateString(message->scheduler_id()),
-                        fbb.CreateString(message->execution_dependencies()),
-                        fbb.CreateString(message->task_info()),
-                        message->spillback_count(), true /* not used */);
-    fbb.Finish(msg);
-
-    RedisModuleString *publish_message = RedisModule_CreateString(
-        ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
-
-    RedisModuleCallReply *reply =
-        RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message);
-
-    /* See how many clients received this publish. */
-    long long num_clients = RedisModule_CallReplyInteger(reply);
-    RAY_CHECK(num_clients <= 1) << "Published to " << num_clients
-                                << " clients.";
-  }
-  return RedisModule_ReplyWithSimpleString(ctx, "OK");
-}
-
-/// Publish a notification for a new entry at a key. This publishes a
-/// notification to all subscribers of the table, as well as every client that
-/// has requested notifications for this key.
-///
-/// \param pubsub_channel_str The pubsub channel name that notifications for
-///        this key should be published to. When publishing to a specific
-///        client, the channel name should be <pubsub_channel>:<client_id>.
-/// \param id The ID of the key that the notification is about.
-/// \param data The data to publish.
-/// \return OK if there is no error during a publish.
-int PublishTableAdd(RedisModuleCtx *ctx,
-                    RedisModuleString *pubsub_channel_str,
-                    RedisModuleString *id,
-                    RedisModuleString *data) {
-  // Serialize the notification to send.
-  flatbuffers::FlatBufferBuilder fbb;
-  auto data_flatbuf = RedisStringToFlatbuf(fbb, data);
-  auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, id),
-                                     fbb.CreateVector(&data_flatbuf, 1));
-  fbb.Finish(message);
-
-  // Write the data back to any subscribers that are listening to all table
-  // notifications.
-  RedisModuleCallReply *reply =
-      RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str,
-                       fbb.GetBufferPointer(), fbb.GetSize());
-  if (reply == NULL) {
-    return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
-  }
-
-  // Publish the data to any clients who requested notifications on this key.
-  RedisModuleKey *notification_key = OpenBroadcastKey(
-      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
-  if (RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY) {
-    // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
-    CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
-                    notification_key, REDISMODULE_NEGATIVE_INFINITE,
-                    REDISMODULE_POSITIVE_INFINITE, 1, 1),
-                "Unable to initialize zset iterator");
-    for (; !RedisModule_ZsetRangeEndReached(notification_key);
-         RedisModule_ZsetRangeNext(notification_key)) {
-      RedisModuleString *client_channel =
-          RedisModule_ZsetRangeCurrentElement(notification_key, NULL);
-      RedisModuleCallReply *reply =
-          RedisModule_Call(ctx, "PUBLISH", "sb", client_channel,
-                           fbb.GetBufferPointer(), fbb.GetSize());
-      if (reply == NULL) {
-        return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
-      }
-    }
-  }
-  return RedisModule_ReplyWithSimpleString(ctx, "OK");
-}
-
-// RAY.TABLE_ADD:
-//   TableAdd_RedisCommand: the actual command handler.
-//   (helper) TableAdd_DoWrite: performs the write to redis state.
-//   (helper) TableAdd_DoPublish: performs a publish after the write.
-//   ChainTableAdd_RedisCommand: the same command, chain-enabled.
-
-int TableAdd_DoWrite(RedisModuleCtx *ctx,
-                     RedisModuleString **argv,
-                     int argc,
-                     RedisModuleString **mutated_key_str) {
-  if (argc != 5) {
-    return RedisModule_WrongArity(ctx);
-  }
-  RedisModuleString *prefix_str = argv[1];
-  RedisModuleString *id = argv[3];
-  RedisModuleString *data = argv[4];
-
-  RedisModuleKey *key =
-      OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE,
-                      mutated_key_str);
-  RedisModule_StringSet(key, data);
-  return REDISMODULE_OK;
-}
-
-int TableAdd_DoPublish(RedisModuleCtx *ctx,
-                       RedisModuleString **argv,
-                       int argc) {
-  if (argc != 5) {
-    return RedisModule_WrongArity(ctx);
-  }
-  RedisModuleString *pubsub_channel_str = argv[2];
-  RedisModuleString *id = argv[3];
-  RedisModuleString *data = argv[4];
-
-  TablePubsub pubsub_channel = ParseTablePubsub(pubsub_channel_str);
-
-  if (pubsub_channel == TablePubsub::TASK) {
-    // Publish the task to its subscribers.
-    // TODO(swang): This is only necessary for legacy Ray and should be removed
-    // once we switch to using the new GCS API for the task table.
-    return PublishTaskTableAdd(ctx, id, data);
-  } else if (pubsub_channel != TablePubsub::NO_PUBLISH) {
-    // All other pubsub channels write the data back directly onto the channel.
-    return PublishTableAdd(ctx, pubsub_channel_str, id, data);
-  } else {
-    return RedisModule_ReplyWithSimpleString(ctx, "OK");
-  }
-}
-
-/// Add an entry at a key. This overwrites any existing data at the key.
-/// Publishes a notification about the update to all subscribers, if a pubsub
-/// channel is provided.
-///
-/// This is called from a client with the command:
-///
-///    RAY.TABLE_ADD <table_prefix> <pubsub_channel> <id> <data>
-///
-/// \param table_prefix The prefix string for keys in this table.
-/// \param pubsub_channel The pubsub channel name that notifications for
-///        this key should be published to. When publishing to a specific
-///        client, the channel name should be <pubsub_channel>:<client_id>.
-/// \param id The ID of the key to set.
-/// \param data The data to insert at the key.
-/// \return The current value at the key, or OK if there is no value.
-int TableAdd_RedisCommand(RedisModuleCtx *ctx,
-                          RedisModuleString **argv,
-                          int argc) {
-  RedisModule_AutoMemory(ctx);
-  TableAdd_DoWrite(ctx, argv, argc, /*mutated_key_str=*/nullptr);
-  return TableAdd_DoPublish(ctx, argv, argc);
-}
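An illustrative client-side call (hiredis, not code from this patch; the
enum values are passed as their integer encodings, matching ParseTablePubsub
and OpenPrefixedKey above):

void table_add_example(redisContext *c, int prefix, int pubsub, const char *id,
                       size_t id_len, const char *data, size_t data_len) {
  redisReply *reply = (redisReply *) redisCommand(
      c, "RAY.TABLE_ADD %d %d %b %b", prefix, pubsub, id, id_len, data,
      data_len);
  freeReplyObject(reply); /* +OK after the write and any publishes. */
}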
-
-#if RAY_USE_NEW_GCS
-int ChainTableAdd_RedisCommand(RedisModuleCtx *ctx,
-                               RedisModuleString **argv,
-                               int argc) {
-  RedisModule_AutoMemory(ctx);
-  return module.ChainReplicate(ctx, argv, argc, /*node_func=*/TableAdd_DoWrite,
-                               /*tail_func=*/TableAdd_DoPublish);
-}
-#endif
-
-int TableAppend_DoWrite(RedisModuleCtx *ctx,
-                        RedisModuleString **argv,
-                        int argc,
-                        RedisModuleString **mutated_key_str) {
-  if (argc < 5 || argc > 6) {
-    return RedisModule_WrongArity(ctx);
-  }
-
-  RedisModuleString *prefix_str = argv[1];
-  RedisModuleString *id = argv[3];
-  RedisModuleString *data = argv[4];
-  RedisModuleString *index_str = nullptr;
-  if (argc == 6) {
-    index_str = argv[5];
-  }
-
-  // Set the keys in the table.
-  RedisModuleKey *key =
-      OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE,
-                      mutated_key_str);
-  // Determine the index at which the data should be appended. If no index is
-  // requested, then it is the current length of the log.
-  size_t index = RedisModule_ValueLength(key);
-  if (index_str != nullptr) {
-    // Parse the requested index.
-    long long requested_index;
-    RAY_CHECK(RedisModule_StringToLongLong(index_str, &requested_index) ==
-              REDISMODULE_OK);
-    RAY_CHECK(requested_index >= 0);
-    index = static_cast<size_t>(requested_index);
-  }
-  // Only perform the append if the requested index matches the current length
-  // of the log, or if no index was requested.
-  if (index == RedisModule_ValueLength(key)) {
-    // The requested index matches the current length of the log or no index
-    // was requested. Perform the append.
-    int flags = REDISMODULE_ZADD_NX;
-    RedisModule_ZsetAdd(key, index, data, &flags);
-    // Check that we actually add a new entry during the append. This is only
-    // necessary since we implement the log with a sorted set, so all entries
-    // must be unique, or else we will have gaps in the log.
-    // TODO(rkn): We need to get rid of this uniqueness requirement. We can
-    // easily have multiple log events with the same message.
-    RAY_CHECK(flags == REDISMODULE_ZADD_ADDED) << "Appended a duplicate entry";
-    return REDISMODULE_OK;
-  } else {
-    // The requested index did not match the current length of the log. Return
-    // an error message as a string.
-    static const char *reply = "ERR entry exists";
-    RedisModule_ReplyWithStringBuffer(ctx, reply, strlen(reply));
-    return REDISMODULE_ERR;
-  }
-}
-
-int TableAppend_DoPublish(RedisModuleCtx *ctx,
-                          RedisModuleString **argv,
-                          int /*argc*/) {
-  RedisModuleString *pubsub_channel_str = argv[2];
-  RedisModuleString *id = argv[3];
-  RedisModuleString *data = argv[4];
-  // Publish a message on the requested pubsub channel if necessary.
-  TablePubsub pubsub_channel = ParseTablePubsub(pubsub_channel_str);
-  if (pubsub_channel != TablePubsub::NO_PUBLISH) {
-    // All other pubsub channels write the data back directly onto the
-    // channel.
-    return PublishTableAdd(ctx, pubsub_channel_str, id, data);
-  } else {
-    return RedisModule_ReplyWithSimpleString(ctx, "OK");
-  }
-}
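The optional index argument turns the append into a compare-and-swap on the
log length, which is how racing appenders detect a conflict. A sketch
(illustrative hiredis call, not code from this patch):

void table_append_example(redisContext *c, int prefix, int pubsub,
                          const char *id, size_t id_len, const char *data,
                          size_t data_len) {
  /* Ask for slot 0 explicitly; if some other writer appended first, the
   * module replies with the string "ERR entry exists" instead of +OK. */
  redisReply *reply = (redisReply *) redisCommand(
      c, "RAY.TABLE_APPEND %d %d %b %b %d", prefix, pubsub, id, id_len, data,
      data_len, 0);
  freeReplyObject(reply);
}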
-/// -/// This is called from a client with the command: -// -/// RAY.TABLE_APPEND -/// -/// -/// \param table_prefix The prefix string for keys in this table. -/// \param pubsub_channel The pubsub channel name that notifications for -/// this key should be published to. When publishing to a specific -/// client, the channel name should be :. -/// \param id The ID of the key to append to. -/// \param data The data to append to the key. -/// \param index If this is set, then the data must be appended at this index. -/// If the current log is shorter or longer than the requested index, -/// then the append will fail and an error message will be returned as a -/// string. -/// \return OK if the append succeeds, or an error message string if the append -/// fails. -int TableAppend_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - const int status = TableAppend_DoWrite(ctx, argv, argc, - /*mutated_key_str=*/nullptr); - if (status) { - return status; - } - return TableAppend_DoPublish(ctx, argv, argc); -} - -#if RAY_USE_NEW_GCS -int ChainTableAppend_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - return module.ChainReplicate(ctx, argv, argc, - /*node_func=*/TableAppend_DoWrite, - /*tail_func=*/TableAppend_DoPublish); -} -#endif - -/// A helper function to create and finish a GcsTableEntry, based on the -/// current value or values at the given key. -void TableEntryToFlatbuf(RedisModuleKey *table_key, - RedisModuleString *entry_id, - flatbuffers::FlatBufferBuilder &fbb) { - auto key_type = RedisModule_KeyType(table_key); - switch (key_type) { - case REDISMODULE_KEYTYPE_STRING: { - // Build the flatbuffer from the string data. - size_t data_len = 0; - char *data_buf = - RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); - auto data = fbb.CreateString(data_buf, data_len); - auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(&data, 1)); - fbb.Finish(message); - } break; - case REDISMODULE_KEYTYPE_ZSET: { - // Build the flatbuffer from the set of log entries. - RAY_CHECK(RedisModule_ZsetFirstInScoreRange( - table_key, REDISMODULE_NEGATIVE_INFINITE, - REDISMODULE_POSITIVE_INFINITE, 1, 1) == REDISMODULE_OK); - std::vector> data; - for (; !RedisModule_ZsetRangeEndReached(table_key); - RedisModule_ZsetRangeNext(table_key)) { - data.push_back(RedisStringToFlatbuf( - fbb, RedisModule_ZsetRangeCurrentElement(table_key, NULL))); - } - auto message = CreateGcsTableEntry(fbb, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(data)); - fbb.Finish(message); - } break; - case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsTableEntry( - fbb, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector( - std::vector>())); - fbb.Finish(message); - } break; - default: - RAY_LOG(FATAL) << "Invalid Redis type during lookup: " << key_type; - } -} - -/// Lookup the current value or values at a key. Returns the current value or -/// values at the key. -/// -/// This is called from a client with the command: -// -/// RAY.TABLE_LOOKUP -/// -/// \param table_prefix The prefix string for keys in this table. -/// \param pubsub_channel The pubsub channel name that notifications for -/// this key should be published to. This field is unused for lookups. -/// \param id The ID of the key to lookup. -/// \return nil if the key is empty, the current value if the key type is a -/// string, or an array of the current values if the key type is a set. 
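-/// -/// The reply payload is a GcsTableEntry flatbuffer. As a sketch (assuming the -/// schema names the fields `id` and `entries`, matching the -/// CreateGcsTableEntry calls above), a C++ client could decode a lookup reply -/// roughly as: -/// -///     auto *root = flatbuffers::GetRoot<GcsTableEntry>(reply_buf); -///     for (const flatbuffers::String *value : *root->entries()) { -///       // Each `value` is one entry currently stored at the key. -///     }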
-int TableLookup_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc < 4) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - - // Lookup the data at the key. - RedisModuleKey *table_key = - OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ); - if (table_key == nullptr) { - RedisModule_ReplyWithNull(ctx); - } else { - // Serialize the data to a flatbuffer to return to the client. - flatbuffers::FlatBufferBuilder fbb; - TableEntryToFlatbuf(table_key, id, fbb); - RedisModule_ReplyWithStringBuffer( - ctx, reinterpret_cast(fbb.GetBufferPointer()), - fbb.GetSize()); - } - return REDISMODULE_OK; -} - -/// Request notifications for changes to a key. Returns the current value or -/// values at the key. Notifications will be sent to the requesting client for -/// every subsequent TABLE_ADD to the key. -/// -/// This is called from a client with the command: -// -/// RAY.TABLE_REQUEST_NOTIFICATIONS -/// -/// -/// \param table_prefix The prefix string for keys in this table. -/// \param pubsub_channel The pubsub channel name that notifications for -/// this key should be published to. When publishing to a specific -/// client, the channel name should be :. -/// \param id The ID of the key to publish notifications for. -/// \param client_id The ID of the client that is being notified. -/// \return nil if the key is empty, the current value if the key type is a -/// string, or an array of the current values if the key type is a set. -int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *pubsub_channel_str = argv[2]; - RedisModuleString *id = argv[3]; - RedisModuleString *client_id = argv[4]; - RedisModuleString *client_channel = - FormatPubsubChannel(ctx, pubsub_channel_str, client_id); - - // Add this client to the set of clients that should be notified when there - // are changes to the key. - RedisModuleKey *notification_key = OpenBroadcastKey( - ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE); - CHECK_ERROR(RedisModule_ZsetAdd(notification_key, 0.0, client_channel, NULL), - "ZsetAdd failed."); - - // Lookup the current value at the key. - RedisModuleKey *table_key = - OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ); - // Publish the current value at the key to the client that is requesting - // notifications. An empty notification will be published if the key is - // empty. - flatbuffers::FlatBufferBuilder fbb; - TableEntryToFlatbuf(table_key, id, fbb); - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, - reinterpret_cast(fbb.GetBufferPointer()), - fbb.GetSize()); - - return RedisModule_ReplyWithNull(ctx); -} - -/// Cancel notifications for changes to a key. The client will no longer -/// receive notifications for this key. This does not check if the client -/// first requested notifications before canceling them. -/// -/// This is called from a client with the command: -// -/// RAY.TABLE_CANCEL_NOTIFICATIONS -/// -/// -/// \param table_prefix The prefix string for keys in this table. -/// \param pubsub_channel The pubsub channel name that notifications for -/// this key should be published to. If publishing to a specific client, -/// then the channel name should be :. 
-/// \param id The ID of the key to publish notifications for. -/// \param client_id The ID of the client to cancel notifications for. -/// \return OK. -int TableCancelNotifications_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc < 5) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *pubsub_channel_str = argv[2]; - RedisModuleString *id = argv[3]; - RedisModuleString *client_id = argv[4]; - RedisModuleString *client_channel = - FormatPubsubChannel(ctx, pubsub_channel_str, client_id); - - // Remove this client from the set of clients that should be notified when - // there are changes to the key. - RedisModuleKey *notification_key = OpenBroadcastKey( - ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE); - if (RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY) { - RAY_CHECK(RedisModule_ZsetRem(notification_key, client_channel, NULL) == - REDISMODULE_OK); - } - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; -} - -bool is_nil(const std::string &data) { - RAY_CHECK(data.size() == kUniqueIDSize); - const uint8_t *d = reinterpret_cast(data.data()); - for (int i = 0; i < kUniqueIDSize; ++i) { - if (d[i] != 255) { - return false; - } - } - return true; -} - -// This is a temporary redis command that will be removed once -// the GCS uses https://github.com/pcmoritz/credis. -// Be careful, this only supports Task Table payloads. -int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - RedisModuleString *update_data = argv[4]; - - RedisModuleKey *key = OpenPrefixedKey(ctx, prefix_str, id, - REDISMODULE_READ | REDISMODULE_WRITE); - - size_t value_len = 0; - char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ); - - size_t update_len = 0; - const char *update_buf = RedisModule_StringPtrLen(update_data, &update_len); - - auto data = flatbuffers::GetMutableRoot( - reinterpret_cast(value_buf)); - - auto update = flatbuffers::GetRoot(update_buf); - - bool do_update = static_cast(data->scheduling_state()) & - static_cast(update->test_state_bitmask()); - - if (!is_nil(update->test_scheduler_id()->str())) { - do_update = - do_update && - update->test_scheduler_id()->str() == data->scheduler_id()->str(); - } - - if (do_update) { - RAY_CHECK(data->mutate_scheduling_state(update->update_state())); - } - RAY_CHECK(data->mutate_updated(do_update)); - - int result = RedisModule_ReplyWithStringBuffer(ctx, value_buf, value_len); - - return result; -} - -/** - * Add a new entry to the object table or update an existing one. - * - * This is called from a client with the command: - * - * RAY.OBJECT_TABLE_ADD - * - * @param object_id A string representing the object ID. - * @param data_size An integer which is the object size in bytes. - * @param hash_string A string which is a hash of the object. - * @param manager A string which represents the manager ID of the plasma manager - * that has the object. - * @return OK if the operation was successful. If the same object_id is already - * present with a different hash value, the entry is still added, but - * an error with string "hash mismatch" is returned. 
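- * - * As an illustration with hypothetical, abbreviated values, a client might - * issue: - * - *     RAY.OBJECT_TABLE_ADD <object_id> 1024 <hash_string> <manager_id> - * - * where 1024 is the object's size in bytes.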
- */ -int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *object_id = argv[1]; - RedisModuleString *data_size = argv[2]; - RedisModuleString *new_hash = argv[3]; - RedisModuleString *manager = argv[4]; - - long long data_size_value; - if (RedisModule_StringToLongLong(data_size, &data_size_value) != - REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "data_size must be integer"); - } - - /* Set the fields in the object info table. */ - RedisModuleKey *key; - key = OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - - /* Check if this object was already registered and if the hashes agree. */ - bool hash_mismatch = false; - if (RedisModule_KeyType(key) != REDISMODULE_KEYTYPE_EMPTY) { - RedisModuleString *existing_hash; - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "hash", &existing_hash, - NULL); - /* The existing hash may be NULL even if the key is present because a call - * to RAY.RESULT_TABLE_ADD may have already created the key. */ - if (existing_hash != NULL) { - /* Check whether the new hash value matches the old one. If not, we will - * later return the "hash mismatch" error. */ - hash_mismatch = (RedisModule_StringCompare(existing_hash, new_hash) != 0); - } - } - - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "hash", new_hash, NULL); - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "data_size", data_size, - NULL); - - /* Add the location in the object location table. */ - RedisModuleKey *table_key; - table_key = OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - - /* Sets are not implemented yet, so we use ZSETs instead. */ - RedisModule_ZsetAdd(table_key, 0.0, manager, NULL); - - RedisModuleString *bcast_client_str = - RedisModule_CreateString(ctx, OBJECT_BCAST, strlen(OBJECT_BCAST)); - bool success = PublishObjectNotification(ctx, bcast_client_str, object_id, - data_size, table_key); - if (!success) { - /* The publish failed somehow. */ - return RedisModule_ReplyWithError(ctx, "PUBLISH BCAST unsuccessful"); - } - - /* Get the zset of clients that requested a notification about the - * availability of this object. */ - RedisModuleKey *object_notification_key = - OpenPrefixedKey(ctx, OBJECT_NOTIFICATION_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - /* If the zset exists, initialize the key to iterate over the zset. */ - if (RedisModule_KeyType(object_notification_key) != - REDISMODULE_KEYTYPE_EMPTY) { - CHECK_ERROR(RedisModule_ZsetFirstInScoreRange( - object_notification_key, REDISMODULE_NEGATIVE_INFINITE, - REDISMODULE_POSITIVE_INFINITE, 1, 1), - "Unable to initialize zset iterator"); - /* Iterate over the list of clients that requested notifiations about the - * availability of this object, and publish notifications to their object - * notification channels. */ - - do { - RedisModuleString *client_id = - RedisModule_ZsetRangeCurrentElement(object_notification_key, NULL); - /* TODO(rkn): Some computation could be saved by batching the string - * constructions in the multiple calls to PublishObjectNotification - * together. */ - bool success = PublishObjectNotification(ctx, client_id, object_id, - data_size, table_key); - if (!success) { - /* The publish failed somehow. 
*/ - return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); - } - } while (RedisModule_ZsetRangeNext(object_notification_key)); - /* Now that the clients have been notified, remove the zset of clients - * waiting for notifications. */ - CHECK_ERROR(RedisModule_DeleteKey(object_notification_key), - "Unable to delete zset key."); - } - - if (hash_mismatch) { - return RedisModule_ReplyWithError(ctx, "hash mismatch"); - } else { - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; - } -} - -/** - * Remove a manager from a location entry in the object table. - * - * This is called from a client with the command: - * - * RAY.OBJECT_TABLE_REMOVE - * - * @param object_id A string representing the object ID. - * @param manager A string which represents the manager ID of the plasma manager - * to remove. - * @return OK if the operation was successful or an error with string - * "object not found" if the entry for the object_id doesn't exist. The - * operation is counted as a success if the manager was already not in - * the entry. - */ -int ObjectTableRemove_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 3) { - return RedisModule_WrongArity(ctx); - } - - RedisModuleString *object_id = argv[1]; - RedisModuleString *manager = argv[2]; - - /* Remove the location from the object location table. */ - RedisModuleKey *table_key; - table_key = OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - if (RedisModule_KeyType(table_key) == REDISMODULE_KEYTYPE_EMPTY) { - return RedisModule_ReplyWithError(ctx, "object not found"); - } - - RedisModule_ZsetRem(table_key, manager, NULL); - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; -} - -/** - * Request notifications about the presence of some object IDs. This command - * takes a list of object IDs. For each object ID, the reply will be the list - * of plasma managers that contain the object. If the list of plasma managers - * is currently nonempty, then the reply will happen immediately. Else, the - * reply will come later, on the first invocation of `RAY.OBJECT_TABLE_ADD` - * following this call. - * - * This is called from a client with the command: - * - * RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS - * ... - * - * @param client_id The ID of the client that is requesting the notifications. - * @param object_id(n) The ID of the nth object ID that is passed to this - * command. This command can take any number of object IDs. - * @return OK if the operation was successful. - */ -int ObjectTableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc < 3) { - return RedisModule_WrongArity(ctx); - } - - /* The first argument is the client ID. The other arguments are object IDs. */ - RedisModuleString *client_id = argv[1]; - - /* Loop over the object ID arguments to this command. */ - for (int i = 2; i < argc; ++i) { - RedisModuleString *object_id = argv[i]; - RedisModuleKey *key = OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, - object_id, REDISMODULE_READ); - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY || - RedisModule_ValueLength(key) == 0) { - /* This object ID is currently not present, so make a note that this - * client should be notified when this object ID becomes available. 
*/ - RedisModuleKey *object_notification_key = - OpenPrefixedKey(ctx, OBJECT_NOTIFICATION_PREFIX, object_id, - REDISMODULE_READ | REDISMODULE_WRITE); - /* Add this client to the list of clients that will be notified when this - * object becomes available. */ - CHECK_ERROR( - RedisModule_ZsetAdd(object_notification_key, 0.0, client_id, NULL), - "ZsetAdd failed."); - } else { - /* Publish a notification to the client's object notification channel. */ - /* Extract the data_size first. */ - RedisModuleKey *object_info_key; - object_info_key = - OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, REDISMODULE_READ); - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { - return RedisModule_ReplyWithError(ctx, "requested object not found"); - } - RedisModuleString *existing_data_size; - RedisModule_HashGet(object_info_key, REDISMODULE_HASH_CFIELDS, - "data_size", &existing_data_size, NULL); - if (existing_data_size == NULL) { - return RedisModule_ReplyWithError(ctx, - "no data_size field in object info"); - } - - bool success = PublishObjectNotification(ctx, client_id, object_id, - existing_data_size, key); - if (!success) { - /* The publish failed somehow. */ - return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); - } - } - } - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; -} - -int ObjectInfoSubscribe_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - REDISMODULE_NOT_USED(argv); - REDISMODULE_NOT_USED(argc); - return REDISMODULE_OK; -} - -/** - * Add a new entry to the result table or update an existing one. - * - * This is called from a client with the command: - * - * RAY.RESULT_TABLE_ADD - * - * @param object_id A string representing the object ID. - * @param task_id A string representing the task ID of the task that produced - * the object. - * @param is_put An integer that is 1 if the object was created through ray.put - * and 0 if created by return value. - * @return OK if the operation was successful. - */ -int ResultTableAdd_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 4) { - return RedisModule_WrongArity(ctx); - } - - /* Set the task ID under field "task" in the object info table. */ - RedisModuleString *object_id = argv[1]; - RedisModuleString *task_id = argv[2]; - RedisModuleString *is_put = argv[3]; - - /* Check to make sure the is_put field was a 0 or a 1. */ - long long is_put_integer; - if ((RedisModule_StringToLongLong(is_put, &is_put_integer) != - REDISMODULE_OK) || - (is_put_integer != 0 && is_put_integer != 1)) { - return RedisModule_ReplyWithError( - ctx, "The is_put field must be either a 0 or a 1."); - } - - RedisModuleKey *key; - key = OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, REDISMODULE_WRITE); - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "task", task_id, "is_put", - is_put, NULL); - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - - return REDISMODULE_OK; -} - -/** - * Reply with information about a task ID. This is used by - * RAY.RESULT_TABLE_LOOKUP and RAY.TASK_TABLE_GET. - * - * @param ctx The Redis context. - * @param task_id The task ID of the task to reply about. - * @param updated A boolean representing whether the task was updated during - * this operation. This field is only used for - * RAY.TASK_TABLE_TEST_AND_UPDATE operations. - * @return NIL if the task ID is not in the task table. 
An error if the task ID - * is in the task table but the appropriate fields are not there, and - * an array of the task scheduling state, the local scheduler ID, and - * the task spec for the task otherwise. - */ -int ReplyWithTask(RedisModuleCtx *ctx, - RedisModuleString *task_id, - bool updated) { - RedisModuleKey *key = - OpenPrefixedKey(ctx, TASK_PREFIX, task_id, REDISMODULE_READ); - - if (RedisModule_KeyType(key) != REDISMODULE_KEYTYPE_EMPTY) { - /* If the key exists, look up the fields and return them in an array. */ - RedisModuleString *state = NULL; - RedisModuleString *local_scheduler_id = NULL; - RedisModuleString *execution_dependencies = NULL; - RedisModuleString *task_spec = NULL; - RedisModuleString *spillback_count = NULL; - RedisModule_HashGet( - key, REDISMODULE_HASH_CFIELDS, "state", &state, "local_scheduler_id", - &local_scheduler_id, "execution_dependencies", &execution_dependencies, - "TaskSpec", &task_spec, "spillback_count", &spillback_count, NULL); - if (state == NULL || local_scheduler_id == NULL || - execution_dependencies == NULL || task_spec == NULL || - spillback_count == NULL) { - /* We must have either all fields or no fields. */ - return RedisModule_ReplyWithError( - ctx, "Missing fields in the task table entry"); - } - - long long state_integer; - long long spillback_count_val; - if ((RedisModule_StringToLongLong(state, &state_integer) != - REDISMODULE_OK) || - (state_integer < 0) || - (RedisModule_StringToLongLong(spillback_count, &spillback_count_val) != - REDISMODULE_OK) || - (spillback_count_val < 0)) { - return RedisModule_ReplyWithError( - ctx, "Found invalid scheduling state or spillback count."); - } - - flatbuffers::FlatBufferBuilder fbb; - auto message = CreateTaskReply( - fbb, RedisStringToFlatbuf(fbb, task_id), state_integer, - RedisStringToFlatbuf(fbb, local_scheduler_id), - RedisStringToFlatbuf(fbb, execution_dependencies), - RedisStringToFlatbuf(fbb, task_spec), spillback_count_val, updated); - fbb.Finish(message); - - RedisModuleString *reply = RedisModule_CreateString( - ctx, (char *) fbb.GetBufferPointer(), fbb.GetSize()); - RedisModule_ReplyWithString(ctx, reply); - } else { - /* If the key does not exist, return nil. */ - RedisModule_ReplyWithNull(ctx); - } - - return REDISMODULE_OK; -} - -/** - * Lookup an entry in the result table. - * - * This is called from a client with the command: - * - * RAY.RESULT_TABLE_LOOKUP - * - * @param object_id A string representing the object ID. - * @return NIL if the object ID is not in the result table. Otherwise, this - * returns a ResultTableReply flatbuffer. - */ -int ResultTableLookup_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 2) { - return RedisModule_WrongArity(ctx); - } - - /* Get the task ID under field "task" in the object info table. */ - RedisModuleString *object_id = argv[1]; - - RedisModuleKey *key; - key = OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, REDISMODULE_READ); - - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { - return RedisModule_ReplyWithNull(ctx); - } - - RedisModuleString *task_id; - RedisModuleString *is_put; - RedisModuleString *data_size; - RedisModuleString *hash; - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "task", &task_id, "is_put", - &is_put, "data_size", &data_size, "hash", &hash, NULL); - - if (task_id == NULL || is_put == NULL) { - return RedisModule_ReplyWithNull(ctx); - } - - /* Check to make sure the is_put field was a 0 or a 1. 
*/ - long long is_put_integer; - if (RedisModule_StringToLongLong(is_put, &is_put_integer) != REDISMODULE_OK || - (is_put_integer != 0 && is_put_integer != 1)) { - return RedisModule_ReplyWithError( - ctx, "The is_put field must be either a 0 or a 1."); - } - - /* Make and return the flatbuffer reply. */ - flatbuffers::FlatBufferBuilder fbb; - long long data_size_value; - - if (data_size == NULL) { - data_size_value = -1; - } else { - RedisModule_StringToLongLong(data_size, &data_size_value); - RAY_CHECK(RedisModule_StringToLongLong(data_size, &data_size_value) == - REDISMODULE_OK); - } - - flatbuffers::Offset hash_str; - if (hash == NULL) { - hash_str = fbb.CreateString("", strlen("")); - } else { - hash_str = RedisStringToFlatbuf(fbb, hash); - } - - flatbuffers::Offset message = - CreateResultTableReply(fbb, RedisStringToFlatbuf(fbb, task_id), - bool(is_put_integer), data_size_value, hash_str); - - fbb.Finish(message); - RedisModuleString *reply = RedisModule_CreateString( - ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); - RedisModule_ReplyWithString(ctx, reply); - - return REDISMODULE_OK; -} - -int TaskTableWrite(RedisModuleCtx *ctx, - RedisModuleString *task_id, - RedisModuleString *state, - RedisModuleString *local_scheduler_id, - RedisModuleString *execution_dependencies, - RedisModuleString *spillback_count, - RedisModuleString *task_spec) { - /* Extract the scheduling state. */ - long long state_value; - if (RedisModule_StringToLongLong(state, &state_value) != REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "scheduling state must be integer"); - } - - long long spillback_count_value; - if (RedisModule_StringToLongLong(spillback_count, &spillback_count_value) != - REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "spillback count must be integer"); - } - /* Add the task to the task table. If no spec was provided, get the existing - * spec out of the task table so we can publish it. */ - RedisModuleString *existing_task_spec = NULL; - RedisModuleKey *key = - OpenPrefixedKey(ctx, TASK_PREFIX, task_id, REDISMODULE_WRITE); - if (task_spec == NULL) { - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, - "local_scheduler_id", local_scheduler_id, - "execution_dependencies", execution_dependencies, - "spillback_count", spillback_count, NULL); - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "TaskSpec", - &existing_task_spec, NULL); - if (existing_task_spec == NULL) { - return RedisModule_ReplyWithError( - ctx, "Cannot update a task that doesn't exist yet"); - } - } else { - RedisModule_HashSet( - key, REDISMODULE_HASH_CFIELDS, "state", state, "local_scheduler_id", - local_scheduler_id, "execution_dependencies", execution_dependencies, - "TaskSpec", task_spec, "spillback_count", spillback_count, NULL); - } - - if (static_cast(state_value) == TaskStatus::WAITING || - static_cast(state_value) == TaskStatus::SCHEDULED) { - /* Build the PUBLISH topic and message for task table subscribers. The - * topic is a string in the format - * "TASK_PREFIX::". The message is a serialized - * SubscribeToTasksReply flatbuffer object. */ - RedisModuleString *publish_topic = RedisString_Format( - ctx, "%s%S:%S", TASK_PREFIX, local_scheduler_id, state); - - /* Construct the flatbuffers object for the payload. */ - flatbuffers::FlatBufferBuilder fbb; - /* Use the old task spec if the current one is NULL. 
*/ - RedisModuleString *task_spec_to_use; - if (task_spec != NULL) { - task_spec_to_use = task_spec; - } else { - task_spec_to_use = existing_task_spec; - } - /* Create the flatbuffers message. */ - auto message = CreateTaskReply( - fbb, RedisStringToFlatbuf(fbb, task_id), state_value, - RedisStringToFlatbuf(fbb, local_scheduler_id), - RedisStringToFlatbuf(fbb, execution_dependencies), - RedisStringToFlatbuf(fbb, task_spec_to_use), spillback_count_value, - true); // The updated field is not used. - fbb.Finish(message); - - RedisModuleString *publish_message = RedisModule_CreateString( - ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); - - RedisModuleCallReply *reply = - RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message); - - if (reply == NULL) { - return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); - } - - /* See how many clients received this publish. */ - long long num_clients = RedisModule_CallReplyInteger(reply); - RAY_CHECK(num_clients <= 1) << "Published to " << num_clients - << " clients."; - - if (num_clients == 0) { - /* This reply will be received by redis_task_table_update_callback or - * redis_task_table_add_task_callback in redis.cc, which will then reissue - * the command. */ - return RedisModule_ReplyWithError(ctx, - "No subscribers received message."); - } - } - - RedisModule_ReplyWithSimpleString(ctx, "OK"); - - return REDISMODULE_OK; -} - -/** - * Add a new entry to the task table. This will overwrite any existing entry - * with the same task ID. - * - * This is called from a client with the command: - * - * RAY.TASK_TABLE_ADD <task_id> <state> <local_scheduler_id> - * <execution_dependencies> <spillback_count> <task_spec> - * - * @param task_id A string that is the ID of the task. - * @param state A string that is the current scheduling state (a - * scheduling_state enum instance). - * @param local_scheduler_id A string that is the ray client ID of the - * associated local scheduler, if any. - * @param execution_dependencies A string that is the list of execution - * dependencies. - * @param spillback_count A string that is the number of times the task has - * been spilled back so far. - * @param task_spec A string that is the specification of the task, which can - * be cast to a `task_spec`. - * @return OK if the operation was successful. - */ -int TaskTableAddTask_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 7) { - return RedisModule_WrongArity(ctx); - } - - return TaskTableWrite(ctx, argv[1], argv[2], argv[3], argv[4], argv[5], - argv[6]); -} - -/** - * Update an entry in the task table. This does not update the task - * specification in the table. - * - * This is called from a client with the command: - * - * RAY.TASK_TABLE_UPDATE <task_id> <state> <local_scheduler_id> - * <execution_dependencies> <spillback_count> - * - * @param task_id A string that is the ID of the task. - * @param state A string that is the current scheduling state (a - * scheduling_state enum instance). - * @param ray_client_id A string that is the ray client ID of the associated - * local scheduler, if any. - * @param execution_dependencies A string that is the list of execution - * dependencies. - * @param spillback_count A string that is the number of times the task has - * been spilled back so far. - * @return OK if the operation was successful. - */ -int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 6) { - return RedisModule_WrongArity(ctx); - } - - return TaskTableWrite(ctx, argv[1], argv[2], argv[3], argv[4], argv[5], NULL); -} - -/** - * Test and update an entry in the task table if the current value matches the - * test value bitmask. This does not update the task specification in the - * table.
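- * - * As a worked example with hypothetical bit values: if the current scheduling - * state is SCHEDULED with bit value 4 and the test bitmask is 4 | 8 = 12, the - * test 4 & 12 != 0 passes and the update is applied; with a bitmask of - * 8 | 16 = 24, the test 4 & 24 == 0 fails and the entry is left unchanged.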
- * - * This is called from a client with the command: - * - * RAY.TASK_TABLE_TEST_AND_UPDATE - * - * - * @param task_id A string that is the ID of the task. - * @param test_state_bitmask A string that is the test bitmask for the - * scheduling state. The update happens if and only if the current - * scheduling state AND-ed with the bitmask is greater than 0. - * @param state A string that is the scheduling state (a scheduling_state enum - * instance) to update the task entry with. - * @param ray_client_id A string that is the ray client ID of the associated - * local scheduler, if any, to update the task entry with. - * @param test_local_scheduler_id A string to test the local scheduler ID. If - * provided, and if the current local scheduler ID does not match it, - * then the update does not happen. - * @return Returns the task entry as a TaskReply. The reply will reflect the - * update, if it happened. - */ -int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc < 5 || argc > 6) { - return RedisModule_WrongArity(ctx); - } - /* If a sixth argument was provided, then we should also test the current - * local scheduler ID. */ - bool test_local_scheduler = (argc == 6); - - RedisModuleString *task_id = argv[1]; - RedisModuleString *test_state = argv[2]; - RedisModuleString *update_state = argv[3]; - RedisModuleString *local_scheduler_id = argv[4]; - - RedisModuleKey *key = OpenPrefixedKey(ctx, TASK_PREFIX, task_id, - REDISMODULE_READ | REDISMODULE_WRITE); - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { - return RedisModule_ReplyWithNull(ctx); - } - - /* If the key exists, look up the fields and return them in an array. */ - RedisModuleString *current_state = NULL; - RedisModuleString *current_local_scheduler_id = NULL; - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", ¤t_state, - "local_scheduler_id", ¤t_local_scheduler_id, NULL); - - long long current_state_integer; - if (RedisModule_StringToLongLong(current_state, ¤t_state_integer) != - REDISMODULE_OK) { - return RedisModule_ReplyWithError(ctx, "current_state must be integer"); - } - - if (current_state_integer < 0) { - return RedisModule_ReplyWithError(ctx, "Found invalid scheduling state."); - } - long long test_state_bitmask; - int status = RedisModule_StringToLongLong(test_state, &test_state_bitmask); - if (status != REDISMODULE_OK) { - return RedisModule_ReplyWithError( - ctx, "Invalid test value for scheduling state"); - } - - bool update = false; - if (current_state_integer & test_state_bitmask) { - if (test_local_scheduler) { - /* A test local scheduler ID was provided. Test whether it is equal to - * the current local scheduler ID before performing the update. */ - RedisModuleString *test_local_scheduler_id = argv[5]; - if (RedisModule_StringCompare(current_local_scheduler_id, - test_local_scheduler_id) == 0) { - /* If the current local scheduler ID does matches the test ID, then - * perform the update. */ - update = true; - } - } else { - /* No test local scheduler ID was provided. Perform the update. */ - update = true; - } - } - - /* If the scheduling state and local scheduler ID tests passed, then perform - * the update. */ - if (update) { - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", update_state, - "local_scheduler_id", local_scheduler_id, NULL); - } - - /* Construct a reply by getting the task from the task ID. 
*/ - return ReplyWithTask(ctx, task_id, update); -} - -/** - * Get an entry from the task table. - * - * This is called from a client with the command: - * - * RAY.TASK_TABLE_GET - * - * @param task_id A string of the task ID to look up. - * @return An array of strings representing the task fields in the following - * order: 1) (integer) scheduling state 2) (string) associated local - * scheduler ID, if any 3) (string) the task specification, which can be - * cast to a task_spec. If the task ID is not in the table, returns nil. - */ -int TaskTableGet_RedisCommand(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - RedisModule_AutoMemory(ctx); - - if (argc != 2) { - return RedisModule_WrongArity(ctx); - } - - /* Construct a reply by getting the task from the task ID. */ - return ReplyWithTask(ctx, argv[1], false); -} - -extern "C" { - -/* This function must be present on each Redis module. It is used in order to - * register the commands into the Redis server. */ -int RedisModule_OnLoad(RedisModuleCtx *ctx, - RedisModuleString **argv, - int argc) { - REDISMODULE_NOT_USED(argv); - REDISMODULE_NOT_USED(argc); - - if (RedisModule_Init(ctx, "ray", 1, REDISMODULE_APIVER_1) == - REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.connect", Connect_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.disconnect", Disconnect_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_add", TableAdd_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_append", - TableAppend_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_lookup", - TableLookup_RedisCommand, "readonly", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications", - TableRequestNotifications_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_cancel_notifications", - TableCancelNotifications_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update", - TableTestAndUpdate_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_table_lookup", - ObjectTableLookup_RedisCommand, "readonly", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_table_add", - ObjectTableAdd_RedisCommand, "write pubsub", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_table_remove", - ObjectTableRemove_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_table_request_notifications", - ObjectTableRequestNotifications_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.object_info_subscribe", - ObjectInfoSubscribe_RedisCommand, "pubsub", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, 
"ray.result_table_add", - ResultTableAdd_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.result_table_lookup", - ResultTableLookup_RedisCommand, "readonly", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.task_table_add", - TaskTableAddTask_RedisCommand, "write pubsub", - 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.task_table_update", - TaskTableUpdate_RedisCommand, "write pubsub", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.task_table_test_and_update", - TaskTableTestAndUpdate_RedisCommand, - "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - - if (RedisModule_CreateCommand(ctx, "ray.task_table_get", - TaskTableGet_RedisCommand, "readonly", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - -#if RAY_USE_NEW_GCS - // Chain-enabled commands that depend on ray-project/credis. - if (RedisModule_CreateCommand(ctx, "ray.chain.table_add", - ChainTableAdd_RedisCommand, "write pubsub", 0, - 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - if (RedisModule_CreateCommand(ctx, "ray.chain.table_append", - ChainTableAppend_RedisCommand, "write pubsub", - 0, 0, 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } -#endif - - return REDISMODULE_OK; -} - -} /* extern "C" */ diff --git a/src/common/shims/windows/getopt.c b/src/common/shims/windows/getopt.c deleted file mode 100644 index d9c4ae583307..000000000000 --- a/src/common/shims/windows/getopt.c +++ /dev/null @@ -1,69 +0,0 @@ -/* http://stackoverflow.com/a/17195644/541686 */ - -#include -#include - -int opterr = 1, /* if error message should be printed */ - optind = 1, /* index into parent argv vector */ - optopt, /* character checked for validity */ - optreset; /* reset getopt */ -char *optarg; /* argument associated with option */ - -#define BADCH (int) '?' -#define BADARG (int) ':' -#define EMSG "" - -/* -* getopt -- -* Parse argc/argv argument vector. -*/ -int getopt(int nargc, char *const nargv[], const char *ostr) { - static char *place = EMSG; /* option letter processing */ - const char *oli; /* option letter list index */ - - if (optreset || !*place) { /* update scanning pointer */ - optreset = 0; - if (optind >= nargc || *(place = nargv[optind]) != '-') { - place = EMSG; - return (-1); - } - if (place[1] && *++place == '-') { /* found "--" */ - ++optind; - place = EMSG; - return (-1); - } - } /* option letter okay? */ - if ((optopt = (int) *place++) == (int) ':' || !(oli = strchr(ostr, optopt))) { - /* - * if the user didn't specify '-' as an option, - * assume it means -1. 
- */ - if (optopt == (int) '-') - return (-1); - if (!*place) - ++optind; - if (opterr && *ostr != ':') - (void) printf("illegal option -- %c\n", optopt); - return (BADCH); - } - if (*++oli != ':') { /* don't need argument */ - optarg = NULL; - if (!*place) - ++optind; - } else { /* need an argument */ - if (*place) /* no white space */ - optarg = place; - else if (nargc <= ++optind) { /* no arg */ - place = EMSG; - if (*ostr == ':') - return (BADARG); - if (opterr) - (void) printf("option requires an argument -- %c\n", optopt); - return (BADCH); - } else /* white space */ - optarg = nargv[optind]; - place = EMSG; - ++optind; - } - return (optopt); /* dump back option letter */ -} diff --git a/src/common/shims/windows/getopt.h b/src/common/shims/windows/getopt.h deleted file mode 100644 index 1870fb87f793..000000000000 --- a/src/common/shims/windows/getopt.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef GETOPT_H -#define GETOPT_H - -#endif /* GETOPT_H */ diff --git a/src/common/shims/windows/msg.c b/src/common/shims/windows/msg.c deleted file mode 100644 index 5142c1aadf2e..000000000000 --- a/src/common/shims/windows/msg.c +++ /dev/null @@ -1,208 +0,0 @@ -#include - -int socketpair(int domain, int type, int protocol, int sv[2]) { - if ((domain != AF_UNIX && domain != AF_INET) || type != SOCK_STREAM) { - return INVALID_SOCKET; - } - SOCKET sockets[2]; - int r = dumb_socketpair(sockets); - sv[0] = (int) sockets[0]; - sv[1] = (int) sockets[1]; - return r; -} - -#pragma comment(lib, "IPHlpAPI.lib") - -struct _MIB_TCPROW2 { - DWORD dwState, dwLocalAddr, dwLocalPort, dwRemoteAddr, dwRemotePort, - dwOwningPid; - enum _TCP_CONNECTION_OFFLOAD_STATE dwOffloadState; -}; - -struct _MIB_TCPTABLE2 { - DWORD dwNumEntries; - struct _MIB_TCPROW2 table[1]; -}; - -DECLSPEC_IMPORT ULONG WINAPI GetTcpTable2(struct _MIB_TCPTABLE2 *TcpTable, - PULONG SizePointer, - BOOL Order); - -static DWORD getsockpid(SOCKET client) { - /* http://stackoverflow.com/a/25431340 */ - DWORD pid = 0; - - struct sockaddr_in Server = {0}; - int ServerSize = sizeof(Server); - - struct sockaddr_in Client = {0}; - int ClientSize = sizeof(Client); - - if ((getsockname(client, (struct sockaddr *) &Server, &ServerSize) == 0) && - (getpeername(client, (struct sockaddr *) &Client, &ClientSize) == 0)) { - struct _MIB_TCPTABLE2 *TcpTable = NULL; - ULONG TcpTableSize = 0; - ULONG result; - do { - result = GetTcpTable2(TcpTable, &TcpTableSize, TRUE); - if (result != ERROR_INSUFFICIENT_BUFFER) { - break; - } - free(TcpTable); - TcpTable = (struct _MIB_TCPTABLE2 *) malloc(TcpTableSize); - } while (TcpTable != NULL); - - if (result == NO_ERROR) { - for (DWORD dw = 0; dw < TcpTable->dwNumEntries; ++dw) { - struct _MIB_TCPROW2 *row = &(TcpTable->table[dw]); - if ((row->dwState == 5 /* MIB_TCP_STATE_ESTAB */) && - (row->dwLocalAddr == Client.sin_addr.s_addr) && - ((row->dwLocalPort & 0xFFFF) == Client.sin_port) && - (row->dwRemoteAddr == Server.sin_addr.s_addr) && - ((row->dwRemotePort & 0xFFFF) == Server.sin_port)) { - pid = row->dwOwningPid; - break; - } - } - } - - free(TcpTable); - } - - return pid; -} - -ssize_t sendmsg(int sockfd, struct msghdr *msg, int flags) { - ssize_t result = -1; - struct cmsghdr *header = CMSG_FIRSTHDR(msg); - if (header->cmsg_level == SOL_SOCKET && header->cmsg_type == SCM_RIGHTS) { - /* We're trying to send over a handle of some kind. - * We have to look up which process we're communicating with, - * open a handle to it, and then duplicate our handle into it. - * However, the first two steps cannot be done atomically. 
- * Therefore, this code HAS A RACE CONDITION and is therefore NOT SECURE. - * In the absence of a malicious actor, though, it is exceedingly unlikely - * that the child process closes AND that its process ID is reassigned - * to another existing process. - */ - struct msghdr const old_msg = *msg; - int *const pfd = (int *) CMSG_DATA(header); - msg->msg_control = NULL; - msg->msg_controllen = 0; - WSAPROTOCOL_INFO protocol_info = {0}; - BOOL const is_socket = !!FDAPI_GetSocketStatePtr(*pfd); - DWORD const target_pid = getsockpid(sockfd); - HANDLE target_process = NULL; - if (target_pid) { - if (!is_socket) { - /* This is a regular handle... fit it into the same struct */ - target_process = OpenProcess(PROCESS_DUP_HANDLE, FALSE, target_pid); - if (target_process) { - if (DuplicateHandle(GetCurrentProcess(), (HANDLE)(intptr_t) *pfd, - target_process, (HANDLE *) &protocol_info, 0, - TRUE, DUPLICATE_SAME_ACCESS)) { - result = 0; - } - } - } else { - /* This is a socket... */ - result = FDAPI_WSADuplicateSocket(*pfd, target_pid, &protocol_info); - } - } - if (result == 0) { - int const nbufs = msg->dwBufferCount + 1; - WSABUF *const bufs = - (struct _WSABUF *) _alloca(sizeof(*msg->lpBuffers) * nbufs); - bufs[0].buf = (char *) &protocol_info; - bufs[0].len = sizeof(protocol_info); - memcpy(&bufs[1], msg->lpBuffers, - msg->dwBufferCount * sizeof(*msg->lpBuffers)); - DWORD nb; - msg->lpBuffers = bufs; - msg->dwBufferCount = nbufs; - GUID const wsaid_WSASendMsg = { - 0xa441e712, - 0x754f, - 0x43ca, - {0x84, 0xa7, 0x0d, 0xee, 0x44, 0xcf, 0x60, 0x6d}}; - typedef INT PASCAL WSASendMsg_t( - SOCKET s, LPWSAMSG lpMsg, DWORD dwFlags, LPDWORD lpNumberOfBytesSent, - LPWSAOVERLAPPED lpOverlapped, - LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine); - WSASendMsg_t *WSASendMsg = NULL; - result = FDAPI_WSAIoctl(sockfd, SIO_GET_EXTENSION_FUNCTION_POINTER, - &wsaid_WSASendMsg, sizeof(wsaid_WSASendMsg), - &WSASendMsg, sizeof(WSASendMsg), &nb, NULL, 0); - if (result == 0) { - result = (*WSASendMsg)(sockfd, msg, flags, &nb, NULL, NULL) == 0 - ? (ssize_t)(nb - sizeof(protocol_info)) - : 0; - } - } - if (result != 0 && target_process && !is_socket) { - /* we failed to send the handle, and it needs cleaning up! */ - HANDLE duplicated_back = NULL; - if (DuplicateHandle(target_process, *(HANDLE *) &protocol_info, - GetCurrentProcess(), &duplicated_back, 0, FALSE, - DUPLICATE_CLOSE_SOURCE)) { - CloseHandle(duplicated_back); - } - } - if (target_process) { - CloseHandle(target_process); - } - *msg = old_msg; - } - return result; -} - -ssize_t recvmsg(int sockfd, struct msghdr *msg, int flags) { - int result = -1; - struct cmsghdr *header = CMSG_FIRSTHDR(msg); - if (msg->msg_controllen && - flags == 0 /* We can't send flags on Windows...
*/) { - struct msghdr const old_msg = *msg; - msg->msg_control = NULL; - msg->msg_controllen = 0; - WSAPROTOCOL_INFO protocol_info = {0}; - int const nbufs = msg->dwBufferCount + 1; - WSABUF *const bufs = - (struct _WSABUF *) _alloca(sizeof(*msg->lpBuffers) * nbufs); - bufs[0].buf = (char *) &protocol_info; - bufs[0].len = sizeof(protocol_info); - memcpy(&bufs[1], msg->lpBuffers, - msg->dwBufferCount * sizeof(*msg->lpBuffers)); - typedef INT PASCAL WSARecvMsg_t( - SOCKET s, LPWSAMSG lpMsg, LPDWORD lpNumberOfBytesRecvd, - LPWSAOVERLAPPED lpOverlapped, - LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine); - WSARecvMsg_t *WSARecvMsg = NULL; - DWORD nb; - GUID const wsaid_WSARecvMsg = { - 0xf689d7c8, - 0x6f1f, - 0x436b, - {0x8a, 0x53, 0xe5, 0x4f, 0xe3, 0x51, 0xc3, 0x22}}; - result = FDAPI_WSAIoctl(sockfd, SIO_GET_EXTENSION_FUNCTION_POINTER, - &wsaid_WSARecvMsg, sizeof(wsaid_WSARecvMsg), - &WSARecvMsg, sizeof(WSARecvMsg), &nb, NULL, 0); - if (result == 0) { - result = (*WSARecvMsg)(sockfd, msg, &nb, NULL, NULL) == 0 - ? (ssize_t)(nb - sizeof(protocol_info)) - : 0; - } - if (result == 0) { - int *const pfd = (int *) CMSG_DATA(header); - if (protocol_info.iSocketType == 0 && protocol_info.iProtocol == 0) { - *pfd = *(int *) &protocol_info; - } else { - *pfd = FDAPI_WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, - FROM_PROTOCOL_INFO, &protocol_info, 0, 0); - } - header->cmsg_level = SOL_SOCKET; - header->cmsg_type = SCM_RIGHTS; - } - *msg = old_msg; - } - return result; -} diff --git a/src/common/shims/windows/netdb.h b/src/common/shims/windows/netdb.h deleted file mode 100644 index 5dace165919a..000000000000 --- a/src/common/shims/windows/netdb.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef NETDB_H -#define NETDB_H - -#endif /* NETDB_H */ diff --git a/src/common/shims/windows/netinet/in.h b/src/common/shims/windows/netinet/in.h deleted file mode 100644 index a60db3e05dd6..000000000000 --- a/src/common/shims/windows/netinet/in.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef IN_H -#define IN_H - -#endif /* IN_H */ diff --git a/src/common/shims/windows/poll.h b/src/common/shims/windows/poll.h deleted file mode 100644 index 058e23adee64..000000000000 --- a/src/common/shims/windows/poll.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef POLL_H -#define POLL_H - -#endif /* POLL_H */ diff --git a/src/common/shims/windows/socketpair.c b/src/common/shims/windows/socketpair.c deleted file mode 100644 index e9fc792c15a7..000000000000 --- a/src/common/shims/windows/socketpair.c +++ /dev/null @@ -1,150 +0,0 @@ -/* socketpair.c -Copyright 2007, 2010 by Nathan C. Myers -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - The name of the author must not be used to endorse or promote products - derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -/* Changes: - * 2014-02-12: merge David Woodhouse, Ger Hobbelt improvements - * git.infradead.org/users/dwmw2/openconnect.git/commitdiff/bdeefa54 - * github.com/GerHobbelt/selectable-socketpair - * always init the socks[] to -1/INVALID_SOCKET on error, both on Win32/64 - * and UNIX/other platforms - * 2013-07-18: Change to BSD 3-clause license - * 2010-03-31: - * set addr to 127.0.0.1 because win32 getsockname does not always set it. - * 2010-02-25: - * set SO_REUSEADDR option to avoid leaking some windows resource. - * Windows System Error 10049, "Event ID 4226 TCP/IP has reached - * the security limit imposed on the number of concurrent TCP connect - * attempts." Bleah. - * 2007-04-25: - * preserve value of WSAGetLastError() on all error returns. - * 2007-04-22: (Thanks to Matthew Gregan ) - * s/EINVAL/WSAEINVAL/ fix trivial compile failure - * s/socket/WSASocket/ enable creation of sockets suitable as stdin/stdout - * of a child process. - * add argument make_overlapped - */ - -#include - -#ifdef WIN32 -#include /* socklen_t, et al (MSVC20xx) */ -#include -#include -#else -#ifdef _WIN32 -#include -#include -#endif -#include -#include -#include -#endif - -#ifdef WIN32 - -/* dumb_socketpair: - * If make_overlapped is nonzero, both sockets created will be usable for - * "overlapped" operations via WSASend etc. If make_overlapped is zero, - * socks[0] (only) will be usable with regular ReadFile etc., and thus - * suitable for use as stdin or stdout of a child process. Note that the - * sockets must be closed with closesocket() regardless. - */ - -int dumb_socketpair(SOCKET socks[2]) { - union { - struct sockaddr_in inaddr; - struct sockaddr addr; - } a; - SOCKET listener; - int e; - socklen_t addrlen = sizeof(a.inaddr); - int reuse = 1; - - if (socks == 0) { - return SOCKET_ERROR; - } - socks[0] = socks[1] = -1; - - listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (listener == -1) - return SOCKET_ERROR; - - memset(&a, 0, sizeof(a)); - a.inaddr.sin_family = AF_INET; - a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - a.inaddr.sin_port = 0; - - for (;;) { - if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR, (char *) &reuse, - (socklen_t) sizeof(reuse)) == -1) - break; - if (bind(listener, &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR) - break; - - memset(&a, 0, sizeof(a)); - if (getsockname(listener, &a.addr, &addrlen) == SOCKET_ERROR) - break; - // win32 getsockname may only set the port number, p=0.0005. 
- // ( http://msdn.microsoft.com/library/ms738543.aspx ): - a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - a.inaddr.sin_family = AF_INET; - - if (listen(listener, 1) == SOCKET_ERROR) - break; - - socks[0] = FDAPI_WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, 0); - if (socks[0] == -1) - break; - if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR) - break; - - socks[1] = accept(listener, NULL, NULL); - if (socks[1] == -1) - break; - - FDAPI_close(listener); - return 0; - } - - FDAPI_close(listener); - FDAPI_close(socks[0]); - FDAPI_close(socks[1]); - socks[0] = socks[1] = -1; - return SOCKET_ERROR; -} -#else -int dumb_socketpair(int socks[2], int dummy) { - if (socks == 0) { - errno = EINVAL; - return -1; - } - dummy = socketpair(AF_LOCAL, SOCK_STREAM, 0, socks); - if (dummy) - socks[0] = socks[1] = -1; - return dummy; -} -#endif diff --git a/src/common/shims/windows/strings.h b/src/common/shims/windows/strings.h deleted file mode 100644 index e264061c4e6e..000000000000 --- a/src/common/shims/windows/strings.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef STRINGS_H -#define STRINGS_H - -#endif /* STRINGS_H */ diff --git a/src/common/shims/windows/sys/ioctl.h b/src/common/shims/windows/sys/ioctl.h deleted file mode 100644 index 00f7a55ed77d..000000000000 --- a/src/common/shims/windows/sys/ioctl.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef IOCTL_H -#define IOCTL_H - -#endif /* IOCTL_H */ diff --git a/src/common/shims/windows/sys/mman.h b/src/common/shims/windows/sys/mman.h deleted file mode 100644 index a12df75fc7ea..000000000000 --- a/src/common/shims/windows/sys/mman.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef MMAN_H -#define MMAN_H - -#include - -#define MAP_SHARED 0x0010 /* share changes */ -#define MAP_FAILED ((void *) -1) -#define PROT_READ 0x04 /* pages can be read */ -#define PROT_WRITE 0x02 /* pages can be written */ -#define PROT_EXEC 0x01 /* pages can be executed */ - -static void *mmap(void *addr, - size_t len, - int prot, - int flags, - int fd, - off_t off) { - void *result = (void *) (-1); - if (!addr && prot == MAP_SHARED) { - /* HACK: we're assuming handle sizes can't exceed 32 bits, which is wrong... - * but works for now. */ - void *ptr = MapViewOfFile((HANDLE)(intptr_t) fd, FILE_MAP_ALL_ACCESS, - (DWORD)(off >> (CHAR_BIT * sizeof(DWORD))), - (DWORD) off, (SIZE_T) len); - if (ptr) { - result = ptr; - } - } - return result; -} -static int munmap(void *addr, size_t length) { - (void) length; - return UnmapViewOfFile(addr) ? 
0 : -1; -} - -#endif /* MMAN_H */ diff --git a/src/common/shims/windows/sys/select.h b/src/common/shims/windows/sys/select.h deleted file mode 100644 index 8aef7950e399..000000000000 --- a/src/common/shims/windows/sys/select.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef SELECT_H -#define SELECT_H - -#endif /* SELECT_H */ diff --git a/src/common/shims/windows/sys/socket.h b/src/common/shims/windows/sys/socket.h deleted file mode 100644 index ba9d656bb96d..000000000000 --- a/src/common/shims/windows/sys/socket.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef SOCKET_H -#define SOCKET_H - -typedef unsigned short sa_family_t; - -#include "../../src/Win32_Interop/Win32_FDAPI.h" -#include "../../src/Win32_Interop/Win32_APIs.h" - -#define cmsghdr _WSACMSGHDR -#undef CMSG_DATA -#define CMSG_DATA WSA_CMSG_DATA -#define CMSG_SPACE WSA_CMSG_SPACE -#define CMSG_FIRSTHDR WSA_CMSG_FIRSTHDR -#define CMSG_LEN WSA_CMSG_LEN -#define CMSG_NXTHDR WSA_CMSG_NXTHDR - -#define SCM_RIGHTS 1 - -#define iovec _WSABUF -#define iov_base buf -#define iov_len len -#define msghdr _WSAMSG -#define msg_name name -#define msg_namelen namelen -#define msg_iov lpBuffers -#define msg_iovlen dwBufferCount -#define msg_control Control.buf -#define msg_controllen Control.len -#define msg_flags dwFlags - -int dumb_socketpair(SOCKET socks[2]); -ssize_t sendmsg(int sockfd, struct msghdr *msg, int flags); -ssize_t recvmsg(int sockfd, struct msghdr *msg, int flags); -int socketpair(int domain, int type, int protocol, int sv[2]); - -#endif /* SOCKET_H */ diff --git a/src/common/shims/windows/sys/time.h b/src/common/shims/windows/sys/time.h deleted file mode 100644 index 976342bd2121..000000000000 --- a/src/common/shims/windows/sys/time.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef TIME_H -#define TIME_H - -#include /* timeval */ - -int gettimeofday_highres(struct timeval *tv, struct timezone *tz); - -static int gettimeofday(struct timeval *tv, struct timezone *tz) { - return gettimeofday_highres(tv, tz); -} - -#endif /* TIME_H */ diff --git a/src/common/shims/windows/sys/un.h b/src/common/shims/windows/sys/un.h deleted file mode 100644 index 91642683f72e..000000000000 --- a/src/common/shims/windows/sys/un.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef UN_H -#define UN_H - -#include - -struct sockaddr_un { - /** AF_UNIX. */ - sa_family_t sun_family; - /** The pathname. */ - char sun_path[108]; -}; - -#endif /* UN_H */ diff --git a/src/common/shims/windows/sys/wait.h b/src/common/shims/windows/sys/wait.h deleted file mode 100644 index 442218408f97..000000000000 --- a/src/common/shims/windows/sys/wait.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef WAIT_H -#define WAIT_H - -#endif /* WAIT_H */ diff --git a/src/common/shims/windows/unistd.h b/src/common/shims/windows/unistd.h deleted file mode 100644 index aab25417e199..000000000000 --- a/src/common/shims/windows/unistd.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef UNISTD_H -#define UNISTD_H - -extern char *optarg; -extern int optind, opterr, optopt; -int getopt(int nargc, char *const nargv[], const char *ostr); - -#include "../../src/Win32_Interop/Win32_FDAPI.h" -#define close(...) 
FDAPI_close(__VA_ARGS__) - -#endif /* UNISTD_H */ diff --git a/src/common/state/actor_notification_table.cc b/src/common/state/actor_notification_table.cc deleted file mode 100644 index 19cd7fddda41..000000000000 --- a/src/common/state/actor_notification_table.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "actor_notification_table.h" - -#include "common_protocol.h" -#include "redis.h" - -void publish_actor_creation_notification(DBHandle *db_handle, - const ActorID &actor_id, - const WorkerID &driver_id, - const DBClientID &local_scheduler_id) { - // Create a flatbuffer object to serialize and publish. - flatbuffers::FlatBufferBuilder fbb; - // Create the flatbuffers message. - auto message = CreateActorCreationNotification( - fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, driver_id), - to_flatbuf(fbb, local_scheduler_id)); - fbb.Finish(message); - - ActorCreationNotificationData *data = - (ActorCreationNotificationData *) malloc( - sizeof(ActorCreationNotificationData) + fbb.GetSize()); - data->size = fbb.GetSize(); - memcpy(&data->flatbuffer_data[0], fbb.GetBufferPointer(), fbb.GetSize()); - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(data), NULL, NULL, - redis_publish_actor_creation_notification, NULL); -} - -void actor_notification_table_subscribe( - DBHandle *db_handle, - actor_notification_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry) { - ActorNotificationTableSubscribeData *sub_data = - (ActorNotificationTableSubscribeData *) malloc( - sizeof(ActorNotificationTableSubscribeData)); - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(sub_data), retry, NULL, - redis_actor_notification_table_subscribe, NULL); -} - -void actor_table_mark_removed(DBHandle *db_handle, ActorID actor_id) { - redis_actor_table_mark_removed(db_handle, actor_id); -} diff --git a/src/common/state/actor_notification_table.h b/src/common/state/actor_notification_table.h deleted file mode 100644 index f6aa101cd0d0..000000000000 --- a/src/common/state/actor_notification_table.h +++ /dev/null @@ -1,74 +0,0 @@ -#ifndef ACTOR_NOTIFICATION_TABLE_H -#define ACTOR_NOTIFICATION_TABLE_H - -#include "task.h" -#include "db.h" -#include "table.h" - -/* - * ==== Subscribing to the actor notification table ==== - */ - -/* Callback for subscribing to the local scheduler table. */ -typedef void (*actor_notification_table_subscribe_callback)( - const ActorID &actor_id, - const WorkerID &driver_id, - const DBClientID &local_scheduler_id, - void *user_context); - -/// Publish an actor creation notification. This is published by a local -/// scheduler once it creates an actor. -/// -/// \param db_handle Database handle. -/// \param actor_id The ID of the actor that was created. -/// \param driver_id The ID of the driver that created the actor. -/// \param local_scheduler_id The ID of the local scheduler that created the -/// actor. -/// \return Void. -void publish_actor_creation_notification(DBHandle *db_handle, - const ActorID &actor_id, - const WorkerID &driver_id, - const DBClientID &local_scheduler_id); - -/// Data that is needed to publish an actor creation notification. -typedef struct { - /// The size of the flatbuffer object. - int64_t size; - /// The information to be sent. - uint8_t flatbuffer_data[0]; -} ActorCreationNotificationData; - -/** - * Register a callback to process actor notification events. 
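As a usage sketch only (not part of this patch): a process that wants to react to actor creation events could pair the publish call above with a subscription like the following. The handler name and body are hypothetical; the signature follows the actor_notification_table_subscribe_callback typedef, and passing NULL for the retry info mirrors other call sites in this code.

    /* Hypothetical handler; signature matches
     * actor_notification_table_subscribe_callback. */
    void on_actor_created(const ActorID &actor_id,
                          const WorkerID &driver_id,
                          const DBClientID &local_scheduler_id,
                          void *user_context) {
      /* e.g., record which local scheduler now hosts the actor. */
    }

    /* Register the handler (NULL retry/context are illustrative). */
    actor_notification_table_subscribe(db_handle, on_actor_created,
                                       /*subscribe_context=*/NULL,
                                       /*retry=*/NULL);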
- * - * @param db_handle Database handle. - * @param subscribe_callback Callback that will be called when the local - * scheduler event happens. - * @param subscribe_context Context that will be passed into the - * subscribe_callback. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void actor_notification_table_subscribe( - DBHandle *db_handle, - actor_notification_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry); - -/* Data that is needed to register local scheduler table subscribe callbacks - * with the state database. */ -typedef struct { - actor_notification_table_subscribe_callback subscribe_callback; - void *subscribe_context; -} ActorNotificationTableSubscribeData; - -/** - * Marks an actor as removed. This prevents the actor from being resurrected. - * - * @param db The database handle. - * @param actor_id The actor id to mark as removed. - * @return Void. - */ -void actor_table_mark_removed(DBHandle *db_handle, ActorID actor_id); - -#endif /* ACTOR_NOTIFICATION_TABLE_H */ diff --git a/src/common/state/db.h b/src/common/state/db.h deleted file mode 100644 index ac9960b89374..000000000000 --- a/src/common/state/db.h +++ /dev/null @@ -1,70 +0,0 @@ -#ifndef DB_H -#define DB_H - -#include - -#include "common.h" -#include "event_loop.h" - -typedef struct DBHandle DBHandle; - -/** - * Connect to the global system store. - * - * @param db_address The hostname to use to connect to the database. - * @param db_port The port to use to connect to the database. - * @param db_shards_addresses The list of database shard IP addresses. - * @param db_shards_ports The list of database shard ports, in the same order - * as db_shards_addresses. - * @param client_type The type of this client. - * @param node_ip_address The hostname of the client that is connecting. - * @param args A vector of extra arguments strings. They should alternate - * between the name of the argument and the value of the argument. For - * examples: "port", "1234", "socket_name", "/tmp/s1". This vector should - * have an even length. - * @return This returns a handle to the database, which must be freed with - * db_disconnect after use. - */ -DBHandle *db_connect(const std::string &db_primary_address, - int db_primary_port, - const char *client_type, - const char *node_ip_address, - const std::vector &args); - -/** - * Attach global system store connection to an event loop. Callbacks from - * queries to the global system store will trigger events in the event loop. - * - * @param db The handle to the database that is connected. - * @param loop The event loop the database gets connected to. - * @param reattach Can only be true in unit tests. If true, the database is - * reattached to the loop. - * @return Void. - */ -void db_attach(DBHandle *db, event_loop *loop, bool reattach); - -/** - * Disconnect from the global system store. - * - * @param db The database connection to close and clean up. - * @return Void. - */ -void db_disconnect(DBHandle *db); - -/** - * Free the database handle. - * - * @param db The database connection to clean up. - * @return Void. - */ -void DBHandle_free(DBHandle *db); - -/** - * Returns the db client ID. - * - * @param db The handle to the database. - * @returns int The db client ID for this connection to the database. 
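A minimal lifecycle sketch for this API, under the assumption that event_loop_create() and event_loop_run() exist in event_loop.h (they are not shown in this hunk); the addresses and extra arguments are placeholders.

    /* Connect, attach to an event loop, run, then disconnect. */
    std::vector<std::string> args = {"port", "1234"};  /* name/value pairs */
    DBHandle *db = db_connect("127.0.0.1", 6379, "plasma_manager",
                              "127.0.0.1", args);
    event_loop *loop = event_loop_create();  /* assumed from event_loop.h */
    db_attach(db, loop, /*reattach=*/false);
    event_loop_run(loop);  /* table callbacks fire from inside the loop */
    db_disconnect(db);     /* announces the disconnect and frees the handle */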
- */ -DBClientID get_db_client_id(DBHandle *db); - -#endif diff --git a/src/common/state/db_client_table.cc b/src/common/state/db_client_table.cc deleted file mode 100644 index b31e9d8c2d3a..000000000000 --- a/src/common/state/db_client_table.cc +++ /dev/null @@ -1,90 +0,0 @@ -#include "db_client_table.h" -#include "redis.h" - -void db_client_table_remove(DBHandle *db_handle, - DBClientID db_client_id, - RetryInfo *retry, - db_client_table_done_callback done_callback, - void *user_context) { - init_table_callback(db_handle, db_client_id, __func__, - new CommonCallbackData(NULL), retry, - (table_done_callback) done_callback, - redis_db_client_table_remove, user_context); -} - -void db_client_table_subscribe( - DBHandle *db_handle, - db_client_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry, - db_client_table_done_callback done_callback, - void *user_context) { - DBClientTableSubscribeData *sub_data = - (DBClientTableSubscribeData *) malloc(sizeof(DBClientTableSubscribeData)); - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(sub_data), retry, - (table_done_callback) done_callback, - redis_db_client_table_subscribe, user_context); -} - -const std::vector db_client_table_get_ip_addresses( - DBHandle *db_handle, - const std::vector &manager_ids) { - /* We time this function because in the past this loop has taken multiple - * seconds under stressful situations on hundreds of machines causing the - * plasma manager to die (because it went too long without sending - * heartbeats). */ - int64_t start_time = current_time_ms(); - - /* Construct the manager vector from the flatbuffers object. 
*/ - std::vector manager_vector; - - for (auto const &manager_id : manager_ids) { - DBClient client = redis_cache_get_db_client(db_handle, manager_id); - RAY_CHECK(!client.manager_address.empty()); - if (client.is_alive) { - manager_vector.push_back(client.manager_address); - } - } - - int64_t end_time = current_time_ms(); - if (end_time - start_time > RayConfig::instance().max_time_for_loop()) { - RAY_LOG(WARNING) << "calling redis_get_cached_db_client in a loop in with " - << manager_ids.size() << " manager IDs took " - << end_time - start_time << " milliseconds."; - } - - return manager_vector; -} - -void db_client_table_update_cache_callback(DBClient *db_client, - void *user_context) { - DBHandle *db_handle = (DBHandle *) user_context; - redis_cache_set_db_client(db_handle, *db_client); -} - -void db_client_table_cache_init(DBHandle *db_handle) { - db_client_table_subscribe(db_handle, db_client_table_update_cache_callback, - db_handle, NULL, NULL, NULL); -} - -DBClient db_client_table_cache_get(DBHandle *db_handle, DBClientID client_id) { - RAY_CHECK(!client_id.is_nil()); - return redis_cache_get_db_client(db_handle, client_id); -} - -void plasma_manager_send_heartbeat(DBHandle *db_handle) { - RetryInfo heartbeat_retry; - heartbeat_retry.num_retries = 0; - heartbeat_retry.timeout = - RayConfig::instance().heartbeat_timeout_milliseconds(); - heartbeat_retry.fail_callback = NULL; - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(NULL), - (RetryInfo *) &heartbeat_retry, NULL, - redis_plasma_manager_send_heartbeat, NULL); -} diff --git a/src/common/state/db_client_table.h b/src/common/state/db_client_table.h deleted file mode 100644 index d140ba770eee..000000000000 --- a/src/common/state/db_client_table.h +++ /dev/null @@ -1,120 +0,0 @@ -#ifndef DB_CLIENT_TABLE_H -#define DB_CLIENT_TABLE_H - -#include - -#include "db.h" -#include "table.h" - -typedef void (*db_client_table_done_callback)(DBClientID db_client_id, - void *user_context); - -/** - * Remove a client from the db clients table. - * - * @param db_handle Database handle. - * @param db_client_id The database client ID to remove. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - * - */ -void db_client_table_remove(DBHandle *db_handle, - DBClientID db_client_id, - RetryInfo *retry, - db_client_table_done_callback done_callback, - void *user_context); - -/* - * ==== Subscribing to the db client table ==== - */ - -/* An entry in the db client table. */ -typedef struct { - /** The database client ID. */ - DBClientID id; - /** The database client type. */ - std::string client_type; - /** An optional auxiliary address for the plasma manager associated with this - * database client. */ - std::string manager_address; - /** Whether or not the database client exists. If this is false for an entry, - * then it will never again be true. */ - bool is_alive; -} DBClient; - -/* Callback for subscribing to the db client table. */ -typedef void (*db_client_table_subscribe_callback)(DBClient *db_client, - void *user_context); - -/** - * Register a callback for a db client table event. - * - * @param db_handle Database handle. - * @param subscribe_callback Callback that will be called when the db client - * table is updated. 
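For illustration, a subscriber could look like the following; the handler is hypothetical, and the trailing NULL arguments mirror the call made in db_client_table_cache_init above.

    /* Hypothetical handler; signature matches
     * db_client_table_subscribe_callback. */
    void on_db_client_update(DBClient *db_client, void *user_context) {
      if (!db_client->is_alive) {
        /* The client was removed and will never come back. */
      }
    }

    db_client_table_subscribe(db_handle, on_db_client_update,
                              /*subscribe_context=*/NULL, /*retry=*/NULL,
                              /*done_callback=*/NULL, /*user_context=*/NULL);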
- * @param subscribe_context Context that will be passed into the - * subscribe_callback. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - */ -void db_client_table_subscribe( - DBHandle *db_handle, - db_client_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry, - db_client_table_done_callback done_callback, - void *user_context); - -/* Data that is needed to register db client table subscribe callbacks with the - * state database. */ -typedef struct { - db_client_table_subscribe_callback subscribe_callback; - void *subscribe_context; -} DBClientTableSubscribeData; - -const std::vector db_client_table_get_ip_addresses( - DBHandle *db, - const std::vector &manager_ids); - -/** - * Initialize the db client cache. The cache is updated with each notification - * from the db client table. - * - * @param db_handle Database handle. - * @return Void. - */ -void db_client_table_cache_init(DBHandle *db_handle); - -/** - * Get a db client from the cache. If the requested client is not there, - * request the latest entry from the db client table. - * - * @param db_handle Database handle. - * @param client_id The ID of the client to look up in the cache. - * @return The database client in the cache. - */ -DBClient db_client_table_cache_get(DBHandle *db_handle, DBClientID client_id); - -/* - * ==== Plasma manager heartbeats ==== - */ - -/** - * Start sending heartbeats to the plasma_managers channel. Each - * heartbeat contains this database client's ID. Heartbeats can be subscribed - * to through the plasma_managers channel. Once called, this "retries" the - * heartbeat operation forever, every heartbeat_timeout_milliseconds - * milliseconds. - * - * @param db_handle Database handle. - * @return Void. - */ -void plasma_manager_send_heartbeat(DBHandle *db_handle); - -#endif /* DB_CLIENT_TABLE_H */ diff --git a/src/common/state/driver_table.cc b/src/common/state/driver_table.cc deleted file mode 100644 index b8732e9863b2..000000000000 --- a/src/common/state/driver_table.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "driver_table.h" -#include "redis.h" - -void driver_table_subscribe(DBHandle *db_handle, - driver_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry) { - DriverTableSubscribeData *sub_data = - (DriverTableSubscribeData *) malloc(sizeof(DriverTableSubscribeData)); - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(sub_data), retry, NULL, - redis_driver_table_subscribe, NULL); -} - -void driver_table_send_driver_death(DBHandle *db_handle, - WorkerID driver_id, - RetryInfo *retry) { - init_table_callback(db_handle, driver_id, __func__, - new CommonCallbackData(NULL), retry, NULL, - redis_driver_table_send_driver_death, NULL); -} diff --git a/src/common/state/driver_table.h b/src/common/state/driver_table.h deleted file mode 100644 index c8c6a6c32382..000000000000 --- a/src/common/state/driver_table.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef DRIVER_TABLE_H -#define DRIVER_TABLE_H - -#include "db.h" -#include "table.h" -#include "task.h" - -/* - * ==== Subscribing to the driver table ==== - */ - -/* Callback for subscribing to the driver table. 
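As a sketch, a monitoring process might announce a dead driver with the send call declared later in this header; the driver ID here is fabricated for illustration, assuming WorkerID shares UniqueID's from_random() helper.

    WorkerID driver_id = WorkerID::from_random();  /* placeholder ID */
    driver_table_send_driver_death(db_handle, driver_id, /*retry=*/NULL);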
*/ -typedef void (*driver_table_subscribe_callback)(WorkerID driver_id, - void *user_context); - -/** - * Register a callback for a driver table event. - * - * @param db_handle Database handle. - * @param subscribe_callback Callback that will be called when the driver event - * happens. - * @param subscribe_context Context that will be passed into the - * subscribe_callback. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void driver_table_subscribe(DBHandle *db_handle, - driver_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry); - -/* Data that is needed to register driver table subscribe callbacks with the - * state database. */ -typedef struct { - driver_table_subscribe_callback subscribe_callback; - void *subscribe_context; -} DriverTableSubscribeData; - -/** - * Send driver death update to all subscribers. - * - * @param db_handle Database handle. - * @param driver_id The ID of the driver that died. - * @param retry Information about retrying the request to the database. - */ -void driver_table_send_driver_death(DBHandle *db_handle, - WorkerID driver_id, - RetryInfo *retry); - -#endif /* DRIVER_TABLE_H */ diff --git a/src/common/state/error_table.cc b/src/common/state/error_table.cc deleted file mode 100644 index d0fd9bdff5e9..000000000000 --- a/src/common/state/error_table.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "error_table.h" -#include "redis.h" - -const char *error_types[] = {"object_hash_mismatch", "put_reconstruction", - "worker_died", "actor_not_created"}; - -void push_error(DBHandle *db_handle, - DBClientID driver_id, - ErrorIndex error_type, - const std::string &error_message) { - int64_t message_size = error_message.size(); - - /* Allocate a struct to hold the error information. */ - ErrorInfo *info = (ErrorInfo *) malloc(sizeof(ErrorInfo) + message_size); - info->driver_id = driver_id; - info->error_type = error_type; - info->error_key = UniqueID::from_random(); - info->size = message_size; - memcpy(info->error_message, error_message.data(), message_size); - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(info), NULL, NULL, - redis_push_error, NULL); -} diff --git a/src/common/state/error_table.h b/src/common/state/error_table.h deleted file mode 100644 index 908d7f4d0eaa..000000000000 --- a/src/common/state/error_table.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef ERROR_TABLE_H -#define ERROR_TABLE_H - -#include "db.h" -#include "table.h" - -/// An ErrorIndex may be used as an index into error_types. -enum class ErrorIndex : int32_t { - /// An object was added with a different hash from the existing one. - OBJECT_HASH_MISMATCH = 0, - /// An object that was created through a ray.put is lost. - PUT_RECONSTRUCTION, - /// A worker died or was killed while executing a task. - WORKER_DIED, - /// An actor hasn't been created for a while. - ACTOR_NOT_CREATED, - /// The total number of error types. - MAX -}; - -/// Data that is needed to push an error. -typedef struct { - /// The ID of the driver to push the error to. - DBClientID driver_id; - /// An index into the error_types array indicating the type of the error. - ErrorIndex error_type; - /// The key to use for the error message in Redis. - UniqueID error_key; - /// The length of the error message. - int64_t size; - /// The error message. - uint8_t error_message[0]; -} ErrorInfo; - -extern const char *error_types[]; - -/// Push an error to the given Python driver. 
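A usage sketch, assuming db_handle and driver_id are already in scope:

    /* Report a dead worker to the driver that submitted the task. */
    push_error(db_handle, driver_id, ErrorIndex::WORKER_DIED,
               "A worker died while executing a task.");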
-/// -/// \param db_handle Database handle. -/// \param driver_id The ID of the Python driver to push the error to. -/// \param error_type An index specifying the type of the error. This should -/// be a value from the ErrorIndex enum. -/// \param error_message The error message to print. -/// \return Void. -void push_error(DBHandle *db_handle, - DBClientID driver_id, - ErrorIndex error_type, - const std::string &error_message); - -#endif diff --git a/src/common/state/local_scheduler_table.cc b/src/common/state/local_scheduler_table.cc deleted file mode 100644 index 075d52102807..000000000000 --- a/src/common/state/local_scheduler_table.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "local_scheduler_table.h" - -#include "common_protocol.h" -#include "redis.h" - -void local_scheduler_table_subscribe( - DBHandle *db_handle, - local_scheduler_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry) { - LocalSchedulerTableSubscribeData *sub_data = - (LocalSchedulerTableSubscribeData *) malloc( - sizeof(LocalSchedulerTableSubscribeData)); - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(sub_data), retry, NULL, - redis_local_scheduler_table_subscribe, NULL); -} - -void local_scheduler_table_send_info(DBHandle *db_handle, - LocalSchedulerInfo *info, - RetryInfo *retry) { - /* Create a flatbuffer object to serialize and publish. */ - flatbuffers::FlatBufferBuilder fbb; - /* Create the flatbuffers message. */ - auto message = CreateLocalSchedulerInfoMessage( - fbb, to_flatbuf(fbb, db_handle->client), info->total_num_workers, - info->task_queue_length, info->available_workers, - map_to_flatbuf(fbb, info->static_resources), - map_to_flatbuf(fbb, info->dynamic_resources), false); - fbb.Finish(message); - - LocalSchedulerTableSendInfoData *data = - (LocalSchedulerTableSendInfoData *) malloc( - sizeof(LocalSchedulerTableSendInfoData) + fbb.GetSize()); - data->size = fbb.GetSize(); - memcpy(&data->flatbuffer_data[0], fbb.GetBufferPointer(), fbb.GetSize()); - - init_table_callback(db_handle, UniqueID::nil(), __func__, - new CommonCallbackData(data), retry, NULL, - redis_local_scheduler_table_send_info, NULL); -} - -void local_scheduler_table_disconnect(DBHandle *db_handle) { - redis_local_scheduler_table_disconnect(db_handle); -} diff --git a/src/common/state/local_scheduler_table.h b/src/common/state/local_scheduler_table.h deleted file mode 100644 index 239b84d0fa48..000000000000 --- a/src/common/state/local_scheduler_table.h +++ /dev/null @@ -1,98 +0,0 @@ -#ifndef LOCAL_SCHEDULER_TABLE_H -#define LOCAL_SCHEDULER_TABLE_H - -#include - -#include "db.h" -#include "table.h" -#include "task.h" - -/** This struct is sent with heartbeat messages from the local scheduler to the - * global scheduler, and it contains information about the load on the local - * scheduler. */ -typedef struct { - /** The total number of workers that are connected to this local scheduler. */ - int total_num_workers; - /** The number of tasks queued in this local scheduler. */ - int task_queue_length; - /** The number of workers that are available and waiting for tasks. */ - int available_workers; - /** The resource vector of resources generally available to this local - * scheduler. */ - std::unordered_map static_resources; - /** The resource vector of resources currently available to this local - * scheduler. 
*/ - std::unordered_map dynamic_resources; - /** Whether the local scheduler is dead. If true, then all other fields - * should be ignored. */ - bool is_dead; -} LocalSchedulerInfo; - -/* - * ==== Subscribing to the local scheduler table ==== - */ - -/* Callback for subscribing to the local scheduler table. */ -typedef void (*local_scheduler_table_subscribe_callback)( - DBClientID client_id, - LocalSchedulerInfo info, - void *user_context); - -/** - * Register a callback for a local scheduler table event. - * - * @param db_handle Database handle. - * @param subscribe_callback Callback that will be called when the local - * scheduler event happens. - * @param subscribe_context Context that will be passed into the - * subscribe_callback. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void local_scheduler_table_subscribe( - DBHandle *db_handle, - local_scheduler_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry); - -/* Data that is needed to register local scheduler table subscribe callbacks - * with the state database. */ -typedef struct { - local_scheduler_table_subscribe_callback subscribe_callback; - void *subscribe_context; -} LocalSchedulerTableSubscribeData; - -/** - * Send a heartbeat to all subscribers to the local scheduler table. This - * heartbeat contains some information about the load on the local scheduler. - * - * @param db_handle Database handle. - * @param info Information about the local scheduler, including the load on the - * local scheduler. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void local_scheduler_table_send_info(DBHandle *db_handle, - LocalSchedulerInfo *info, - RetryInfo *retry); - -/* Data that is needed to publish local scheduler heartbeats to the local - * scheduler table. */ -typedef struct { - /* The size of the flatbuffer object. */ - int64_t size; - /* The information to be sent. */ - uint8_t flatbuffer_data[0]; -} LocalSchedulerTableSendInfoData; - -/** - * Send a null heartbeat to all subscribers to the local scheduler table to - * notify them that we are about to exit. This operation is performed - * synchronously. - * - * @param db_handle Database handle. - * @return Void. 
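A heartbeat sketch using local_scheduler_table_send_info declared above. The resource maps lost their template arguments in formatting; this assumes they are keyed by resource name with double-valued quantities, and all numbers are illustrative.

    LocalSchedulerInfo info;
    info.total_num_workers = 4;
    info.task_queue_length = 10;
    info.available_workers = 2;
    info.static_resources["CPU"] = 4.0;   /* total CPUs on this node */
    info.dynamic_resources["CPU"] = 2.0;  /* CPUs currently free */
    info.is_dead = false;
    local_scheduler_table_send_info(db_handle, &info, /*retry=*/NULL);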
- */ -void local_scheduler_table_disconnect(DBHandle *db_handle); - -#endif /* LOCAL_SCHEDULER_TABLE_H */ diff --git a/src/common/state/object_table.cc b/src/common/state/object_table.cc deleted file mode 100644 index fcd527e62f6a..000000000000 --- a/src/common/state/object_table.cc +++ /dev/null @@ -1,119 +0,0 @@ -#include "object_table.h" -#include "redis.h" - -void object_table_lookup(DBHandle *db_handle, - ObjectID object_id, - RetryInfo *retry, - object_table_lookup_done_callback done_callback, - void *user_context) { - RAY_CHECK(db_handle != NULL); - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(NULL), retry, - (table_done_callback) done_callback, - redis_object_table_lookup, user_context); -} - -void object_table_add(DBHandle *db_handle, - ObjectID object_id, - int64_t object_size, - unsigned char digest[], - RetryInfo *retry, - object_table_done_callback done_callback, - void *user_context) { - RAY_CHECK(db_handle != NULL); - - ObjectTableAddData *info = - (ObjectTableAddData *) malloc(sizeof(ObjectTableAddData)); - info->object_size = object_size; - memcpy(&info->digest[0], digest, DIGEST_SIZE); - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(info), retry, - (table_done_callback) done_callback, - redis_object_table_add, user_context); -} - -void object_table_remove(DBHandle *db_handle, - ObjectID object_id, - DBClientID *client_id, - RetryInfo *retry, - object_table_done_callback done_callback, - void *user_context) { - RAY_CHECK(db_handle != NULL); - /* Copy the client ID, if one was provided. */ - DBClientID *client_id_copy = NULL; - if (client_id != NULL) { - client_id_copy = (DBClientID *) malloc(sizeof(DBClientID)); - *client_id_copy = *client_id; - } - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(client_id_copy), retry, - (table_done_callback) done_callback, - redis_object_table_remove, user_context); -} - -void object_table_subscribe_to_notifications( - DBHandle *db_handle, - bool subscribe_all, - object_table_object_available_callback object_available_callback, - void *subscribe_context, - RetryInfo *retry, - object_table_lookup_done_callback done_callback, - void *user_context) { - RAY_CHECK(db_handle != NULL); - ObjectTableSubscribeData *sub_data = - (ObjectTableSubscribeData *) malloc(sizeof(ObjectTableSubscribeData)); - sub_data->object_available_callback = object_available_callback; - sub_data->subscribe_context = subscribe_context; - sub_data->subscribe_all = subscribe_all; - - init_table_callback( - db_handle, ObjectID::nil(), __func__, new CommonCallbackData(sub_data), - retry, (table_done_callback) done_callback, - redis_object_table_subscribe_to_notifications, user_context); -} - -void object_table_request_notifications(DBHandle *db_handle, - int num_object_ids, - ObjectID object_ids[], - RetryInfo *retry) { - RAY_CHECK(db_handle != NULL); - RAY_CHECK(num_object_ids > 0); - ObjectTableRequestNotificationsData *data = - (ObjectTableRequestNotificationsData *) malloc( - sizeof(ObjectTableRequestNotificationsData) + - num_object_ids * sizeof(ObjectID)); - data->num_object_ids = num_object_ids; - memcpy(data->object_ids, object_ids, num_object_ids * sizeof(ObjectID)); - - init_table_callback(db_handle, ObjectID::nil(), __func__, - new CommonCallbackData(data), retry, NULL, - redis_object_table_request_notifications, NULL); -} - -void result_table_add(DBHandle *db_handle, - ObjectID object_id, - TaskID task_id, - bool is_put, - RetryInfo *retry, - result_table_done_callback 
done_callback, - void *user_context) { - ResultTableAddInfo *info = - (ResultTableAddInfo *) malloc(sizeof(ResultTableAddInfo)); - info->task_id = task_id; - info->is_put = is_put; - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(info), retry, - (table_done_callback) done_callback, - redis_result_table_add, user_context); -} - -void result_table_lookup(DBHandle *db_handle, - ObjectID object_id, - RetryInfo *retry, - result_table_lookup_callback done_callback, - void *user_context) { - init_table_callback(db_handle, object_id, __func__, - new CommonCallbackData(NULL), retry, - (table_done_callback) done_callback, - redis_result_table_lookup, user_context); -} diff --git a/src/common/state/object_table.h b/src/common/state/object_table.h deleted file mode 100644 index 77a299dfd30a..000000000000 --- a/src/common/state/object_table.h +++ /dev/null @@ -1,242 +0,0 @@ -#ifndef OBJECT_TABLE_H -#define OBJECT_TABLE_H - -#include "common.h" -#include "table.h" -#include "db.h" -#include "task.h" - -/* - * ==== Lookup call and callback ==== - */ - -/* Callback called when the lookup completes. The callback should free - * the manager_vector array, but NOT the strings they are pointing to. If there - * was no entry at all for the object (the object had never been created - * before), then never_created will be true. - */ -typedef void (*object_table_lookup_done_callback)( - ObjectID object_id, - bool never_created, - const std::vector &manager_ids, - void *user_context); - -/* Callback called when object ObjectID is available. */ -typedef void (*object_table_object_available_callback)( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_ids, - void *user_context); - -/** - * Return the list of nodes storing object_id in their plasma stores. - * - * @param db_handle Handle to object_table database. - * @param object_id ID of the object being looked up. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Context passed by the caller. - * @return Void. - */ -void object_table_lookup(DBHandle *db_handle, - ObjectID object_id, - RetryInfo *retry, - object_table_lookup_done_callback done_callback, - void *user_context); - -/* - * ==== Add object call and callback ==== - */ - -/** - * Callback called when the object add/remove operation completes. - * - * @param object_id The ID of the object that was added or removed. - * @param success Whether the operation was successful or not. If this is false - * and the operation was an addition, the object was added, but there - * was a hash mismatch. - * @param user_context The user context that was passed into the add/remove - * call. - */ -typedef void (*object_table_done_callback)(ObjectID object_id, - bool success, - void *user_context); - -/** - * Add the plasma manager that created the db_handle to the - * list of plasma managers that have the object_id. - * - * @param db_handle Handle to db. - * @param object_id Object unique identifier. - * @param data_size Object data size. - * @param retry Information about retrying the request to the database. - * @param done_callback Callback to be called when lookup completes. - * @param user_context User context to be passed in the callbacks. - * @return Void. 
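A sketch of the add path declared below: the digest normally comes from the plasma store's hash of the object, so the zero bytes here are placeholders.

    unsigned char digest[DIGEST_SIZE] = {0};  /* placeholder digest */
    object_table_add(db_handle, object_id, /*object_size=*/1024, digest,
                     /*retry=*/NULL, /*done_callback=*/NULL,
                     /*user_context=*/NULL);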
- */ -void object_table_add(DBHandle *db_handle, - ObjectID object_id, - int64_t object_size, - unsigned char digest[], - RetryInfo *retry, - object_table_done_callback done_callback, - void *user_context); - -/** Data that is needed to add new objects to the object table. */ -typedef struct { - int64_t object_size; - unsigned char digest[DIGEST_SIZE]; -} ObjectTableAddData; - -/* - * ==== Remove object call and callback ==== - */ - -/** - * Object remove function. - * - * @param db_handle Handle to db. - * @param object_id Object unique identifier. - * @param client_id A pointer to the database client ID to remove. If this is - * set to NULL, then the client ID associated with db_handle will be - * removed. - * @param retry Information about retrying the request to the database. - * @param done_callback Callback to be called when lookup completes. - * @param user_context User context to be passed in the callbacks. - * @return Void. - */ -void object_table_remove(DBHandle *db_handle, - ObjectID object_id, - DBClientID *client_id, - RetryInfo *retry, - object_table_done_callback done_callback, - void *user_context); - -/* - * ==== Subscribe to be announced when new object available ==== - */ - -/** - * Set up a client-specific channel for receiving notifications about available - * objects from the object table. The callback will be called once per - * notification received on this channel. - * - * @param db_handle Handle to db. - * @param object_available_callback Callback to be called when new object - * becomes available. - * @param subscribe_context Caller context which will be passed to the - * object_available_callback. - * @param retry Information about retrying the request to the database. - * @param done_callback Callback to be called when subscription is installed. - * This is only used for the tests. - * @param user_context User context to be passed into the done callback. This is - * only used for the tests. - * @return Void. - */ -void object_table_subscribe_to_notifications( - DBHandle *db_handle, - bool subscribe_all, - object_table_object_available_callback object_available_callback, - void *subscribe_context, - RetryInfo *retry, - object_table_lookup_done_callback done_callback, - void *user_context); - -/** - * Request notifications about the availability of some objects from the object - * table. The notifications will be published to this client's object - * notification channel, which was set up by the method - * object_table_subscribe_to_notifications. - * - * @param db_handle Handle to db. - * @param object_ids The object IDs to receive notifications about. - * @param retry Information about retrying the request to the database. - * @return Void. - */ -void object_table_request_notifications(DBHandle *db, - int num_object_ids, - ObjectID object_ids[], - RetryInfo *retry); - -/** Data that is needed to run object_request_notifications requests. */ -typedef struct { - /** The number of object IDs. */ - int num_object_ids; - /** This field is used to store a variable number of object IDs. */ - ObjectID object_ids[0]; -} ObjectTableRequestNotificationsData; - -/** Data that is needed to register new object available callbacks with the - * state database. */ -typedef struct { - bool subscribe_all; - object_table_object_available_callback object_available_callback; - void *subscribe_context; -} ObjectTableSubscribeData; - -/* - * ==== Result table ==== - */ - -/** - * Callback called when the add/remove operation for a result table entry - * completes. 
*/ -typedef void (*result_table_done_callback)(ObjectID object_id, - void *user_context); - -/** Information about a result table entry to add. */ -typedef struct { - /** The task ID of the task that created the requested object. */ - TaskID task_id; - /** True if the object was created through a put, and false if created by - * return value. */ - bool is_put; -} ResultTableAddInfo; - -/** - * Add information about a new object to the object table. This - * is immutable information like the ID of the task that - * created the object. - * - * @param db_handle Handle to object_table database. - * @param object_id ID of the object to add. - * @param task_id ID of the task that creates this object. - * @param is_put A boolean that is true if the object was created through a - * ray.put, and false if the object was created by return value. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Context passed by the caller. - * @return Void. - */ -void result_table_add(DBHandle *db_handle, - ObjectID object_id, - TaskID task_id, - bool is_put, - RetryInfo *retry, - result_table_done_callback done_callback, - void *user_context); - -/** Callback called when the result table lookup completes. */ -typedef void (*result_table_lookup_callback)(ObjectID object_id, - TaskID task_id, - bool is_put, - void *user_context); - -/** - * Lookup the task that created an object in the result table. The return value - * is the task ID. - * - * @param db_handle Handle to object_table database. - * @param object_id ID of the object to lookup. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Context passed by the caller. - * @return Void. - */ -void result_table_lookup(DBHandle *db_handle, - ObjectID object_id, - RetryInfo *retry, - result_table_lookup_callback done_callback, - void *user_context); - -#endif /* OBJECT_TABLE_H */ diff --git a/src/common/state/redis.cc b/src/common/state/redis.cc deleted file mode 100644 index 17a3c8ce2d3a..000000000000 --- a/src/common/state/redis.cc +++ /dev/null @@ -1,1692 +0,0 @@ -/* Redis implementation of the global state store */ - -#include -#include -#include -#include - -extern "C" { -/* Including hiredis here is necessary on Windows for typedefs used in ae.h. */ -#include "hiredis/hiredis.h" -#include "hiredis/adapters/ae.h" -} - -#include "common.h" -#include "db.h" -#include "db_client_table.h" -#include "actor_notification_table.h" -#include "driver_table.h" -#include "local_scheduler_table.h" -#include "object_table.h" -#include "task.h" -#include "task_table.h" -#include "error_table.h" -#include "event_loop.h" -#include "redis.h" -#include "io.h" -#include "net.h" - -#include "format/common_generated.h" - -#include "common_protocol.h" - -#ifndef _WIN32 -/* This function is actually not declared in standard POSIX, so declare it. */ -extern int usleep(useconds_t usec); -#endif - -#define CHECK_REDIS_CONNECT(CONTEXT_TYPE, context, M, ...) \ - do { \ - CONTEXT_TYPE *_context = (context); \ - if (!_context) { \ - RAY_LOG(FATAL) << "could not allocate redis context"; \ - } \ - if (_context->err) { \ - RAY_LOG(ERROR) << M; \ - LOG_REDIS_ERROR(_context, ""); \ - exit(-1); \ - } \ - } while (0) - -/** - * A header for callbacks of a single Redis asynchronous command. 
The user must
- * pass in the table operation's timer ID as the asynchronous command's
- * privdata field when executing the asynchronous command. The user must define
- * variable names for DB and CB_DATA. After this piece of code runs, DB
- * will hold a reference to the database handle, CB_DATA will hold a reference
- * to the callback data for this table operation. The user must pass in the
- * redisReply pointer as the REPLY argument.
- *
- * This header also short-circuits the entire callback if: (1) there was no
- * reply from Redis, or (2) the callback data for this table operation was
- * already removed, meaning that the operation was already marked as succeeded
- * or failed.
- */
-#define REDIS_CALLBACK_HEADER(DB, CB_DATA, REPLY) \
-  if ((REPLY) == NULL) { \
-    return; \
-  } \
-  DBHandle *DB = (DBHandle *) c->data; \
-  TableCallbackData *CB_DATA = outstanding_callbacks_find((int64_t) privdata); \
-  if (CB_DATA == NULL) { \
-    /* the callback data structure has been \
-     * already freed; just ignore this reply */ \
-    return; \
-  } \
-  do { \
-  } while (0)
-
-redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id) {
-  /* NOTE: The hash function used here must match the one in
-   * PyObjectID_redis_shard_hash in src/common/lib/python/common_extension.cc.
-   * Changes to the hash function should only be made through
-   * std::hash<ray::UniqueID> in src/common/common.h */
-  std::hash<ray::UniqueID> index;
-  return db->contexts[index(id) % db->contexts.size()];
-}
-
-redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id) {
-  std::hash<ray::UniqueID> index;
-  return db->subscribe_contexts[index(id) % db->subscribe_contexts.size()];
-}
-
-void get_redis_shards(redisContext *context,
-                      std::vector<std::string> &db_shards_addresses,
-                      std::vector<int> &db_shards_ports) {
-  /* Get the total number of Redis shards in the system. */
-  int num_attempts = 0;
-  redisReply *reply = NULL;
-  while (num_attempts < RayConfig::instance().redis_db_connect_retries()) {
-    /* Try to read the number of Redis shards from the primary shard. If the
-     * entry is present, exit. */
-    reply = (redisReply *) redisCommand(context, "GET NumRedisShards");
-    if (reply->type != REDIS_REPLY_NIL) {
-      break;
-    }
-
-    /* Sleep for a little, and try again if the entry isn't there yet. */
-    freeReplyObject(reply);
-    usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000);
-    num_attempts++;
-    continue;
-  }
-  RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries())
-      << "No entry found for NumRedisShards";
-  RAY_CHECK(reply->type == REDIS_REPLY_STRING)
-      << "Expected string, found Redis type " << reply->type
-      << " for NumRedisShards";
-  int num_redis_shards = atoi(reply->str);
-  RAY_CHECK(num_redis_shards >= 1) << "Expected at least one Redis shard, "
-                                   << "found " << num_redis_shards;
-  freeReplyObject(reply);
-
-  /* Get the addresses of all of the Redis shards. */
-  num_attempts = 0;
-  while (num_attempts < RayConfig::instance().redis_db_connect_retries()) {
-    /* Try to read the Redis shard locations from the primary shard. If we find
-     * that all of them are present, exit. */
-    reply = (redisReply *) redisCommand(context, "LRANGE RedisShards 0 -1");
-    if (static_cast<int>(reply->elements) == num_redis_shards) {
-      break;
-    }
-
-    /* Sleep for a little, and try again if not all Redis shard addresses have
-     * been added yet.
*/ - freeReplyObject(reply); - usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000); - num_attempts++; - continue; - } - RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries()) - << "Expected " << num_redis_shards << " Redis shard addresses, found " - << reply->elements; - - /* Parse the Redis shard addresses. */ - char db_shard_address[16]; - int db_shard_port; - for (size_t i = 0; i < reply->elements; ++i) { - /* Parse the shard addresses and ports. */ - RAY_CHECK(reply->element[i]->type == REDIS_REPLY_STRING); - RAY_CHECK(parse_ip_addr_port(reply->element[i]->str, db_shard_address, - &db_shard_port) == 0); - db_shards_addresses.push_back(std::string(db_shard_address)); - db_shards_ports.push_back(db_shard_port); - } - freeReplyObject(reply); -} - -void db_connect_shard(const std::string &db_address, - int db_port, - DBClientID client, - const char *client_type, - const char *node_ip_address, - const std::vector &args, - DBHandle *db, - redisAsyncContext **context_out, - redisAsyncContext **subscribe_context_out, - redisContext **sync_context_out) { - /* Synchronous connection for initial handshake */ - redisReply *reply; - int connection_attempts = 0; - redisContext *sync_context = redisConnect(db_address.c_str(), db_port); - while (sync_context == NULL || sync_context->err) { - if (connection_attempts >= - RayConfig::instance().redis_db_connect_retries()) { - break; - } - RAY_LOG(WARNING) << "Failed to connect to Redis, retrying."; - /* Sleep for a little. */ - usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000); - sync_context = redisConnect(db_address.c_str(), db_port); - connection_attempts += 1; - } - CHECK_REDIS_CONNECT(redisContext, sync_context, - "could not establish synchronous connection to redis " - "%s:%d", - db_address.c_str(), db_port); - /* Configure Redis to generate keyspace notifications for list events. This - * should only need to be done once (by whoever started Redis), but since - * Redis may be started in multiple places (e.g., for testing or when starting - * processes by hand), it is easier to do it multiple times. */ - reply = (redisReply *) redisCommand(sync_context, - "CONFIG SET notify-keyspace-events Kl"); - RAY_CHECK(reply != NULL) << "db_connect failed on CONFIG SET"; - freeReplyObject(reply); - /* Also configure Redis to not run in protected mode, so clients on other - * hosts can connect to it. */ - reply = - (redisReply *) redisCommand(sync_context, "CONFIG SET protected-mode no"); - RAY_CHECK(reply != NULL) << "db_connect failed on CONFIG SET"; - freeReplyObject(reply); - - /* Construct the argument arrays for RAY.CONNECT. */ - int argc = args.size() + 4; - const char **argv = (const char **) malloc(sizeof(char *) * argc); - size_t *argvlen = (size_t *) malloc(sizeof(size_t) * argc); - /* Set the command name argument. */ - argv[0] = "RAY.CONNECT"; - argvlen[0] = strlen(argv[0]); - /* Set the client ID argument. */ - argv[1] = (char *) client.data(); - argvlen[1] = sizeof(client); - /* Set the node IP address argument. */ - argv[2] = node_ip_address; - argvlen[2] = strlen(node_ip_address); - /* Set the client type argument. */ - argv[3] = client_type; - argvlen[3] = strlen(client_type); - /* Set the remaining arguments. */ - for (size_t i = 0; i < args.size(); ++i) { - argv[4 + i] = args[i].c_str(); - argvlen[4 + i] = strlen(args[i].c_str()); - } - - /* Register this client with Redis. RAY.CONNECT is a custom Redis command that - * we've defined. 
*/ - reply = (redisReply *) redisCommandArgv(sync_context, argc, argv, argvlen); - RAY_CHECK(reply != NULL) << "db_connect failed on RAY.CONNECT"; - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - freeReplyObject(reply); - free(argv); - free(argvlen); - - *sync_context_out = sync_context; - - /* Establish connection for control data. */ - redisAsyncContext *context = redisAsyncConnect(db_address.c_str(), db_port); - CHECK_REDIS_CONNECT(redisAsyncContext, context, - "could not establish asynchronous connection to redis " - "%s:%d", - db_address.c_str(), db_port); - context->data = (void *) db; - *context_out = context; - - /* Establish async connection for subscription. */ - redisAsyncContext *subscribe_context = - redisAsyncConnect(db_address.c_str(), db_port); - CHECK_REDIS_CONNECT(redisAsyncContext, subscribe_context, - "could not establish asynchronous subscription " - "connection to redis %s:%d", - db_address.c_str(), db_port); - subscribe_context->data = (void *) db; - *subscribe_context_out = subscribe_context; -} - -DBHandle *db_connect(const std::string &db_primary_address, - int db_primary_port, - const char *client_type, - const char *node_ip_address, - const std::vector &args) { - /* Check that the number of args is even. These args will be passed to the - * RAY.CONNECT Redis command, which takes arguments in pairs. */ - if (args.size() % 2 != 0) { - RAY_LOG(FATAL) << "The number of extra args must be divisible by two."; - } - - /* Create a client ID for this client. */ - DBClientID client = DBClientID::from_random(); - - DBHandle *db = new DBHandle(); - - db->client_type = strdup(client_type); - db->client = client; - - redisAsyncContext *context; - redisAsyncContext *subscribe_context; - redisContext *sync_context; - - /* Connect to the primary redis instance. */ - db_connect_shard(db_primary_address, db_primary_port, client, client_type, - node_ip_address, args, db, &context, &subscribe_context, - &sync_context); - db->context = context; - db->subscribe_context = subscribe_context; - db->sync_context = sync_context; - - /* Get the shard locations. */ - std::vector db_shards_addresses; - std::vector db_shards_ports; - get_redis_shards(db->sync_context, db_shards_addresses, db_shards_ports); - RAY_CHECK(db_shards_addresses.size() > 0) << "No Redis shards found"; - /* Connect to the shards. */ - for (size_t i = 0; i < db_shards_addresses.size(); ++i) { - db_connect_shard(db_shards_addresses[i], db_shards_ports[i], client, - client_type, node_ip_address, args, db, &context, - &subscribe_context, &sync_context); - db->contexts.push_back(context); - db->subscribe_contexts.push_back(subscribe_context); - redisFree(sync_context); - } - - return db; -} - -void DBHandle_free(DBHandle *db) { - /* Clean up the primary Redis connection state. */ - redisFree(db->sync_context); - redisAsyncFree(db->context); - redisAsyncFree(db->subscribe_context); - - /* Clean up the Redis shards. */ - RAY_CHECK(db->contexts.size() == db->subscribe_contexts.size()); - for (size_t i = 0; i < db->contexts.size(); ++i) { - redisAsyncFree(db->contexts[i]); - redisAsyncFree(db->subscribe_contexts[i]); - } - - free(db->client_type); - delete db; -} - -void db_disconnect(DBHandle *db) { - /* Notify others that this client is disconnecting from Redis. If a client of - * the same type on the same node wants to reconnect again, they must - * reconnect and get assigned a different client ID. 
*/ - redisReply *reply = - (redisReply *) redisCommand(db->sync_context, "RAY.DISCONNECT %b", - db->client.data(), sizeof(db->client)); - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - freeReplyObject(reply); - - DBHandle_free(db); -} - -void db_attach(DBHandle *db, event_loop *loop, bool reattach) { - db->loop = loop; - /* Attach primary redis instance to the event loop. */ - int err = redisAeAttach(loop, db->context); - /* If the database is reattached in the tests, redis normally gives - * an error which we can safely ignore. */ - if (!reattach) { - RAY_CHECK(err == REDIS_OK) << "failed to attach the event loop"; - } - err = redisAeAttach(loop, db->subscribe_context); - if (!reattach) { - RAY_CHECK(err == REDIS_OK) << "failed to attach the event loop"; - } - /* Attach other redis shards to the event loop. */ - RAY_CHECK(db->contexts.size() == db->subscribe_contexts.size()); - for (size_t i = 0; i < db->contexts.size(); ++i) { - int err = redisAeAttach(loop, db->contexts[i]); - /* If the database is reattached in the tests, redis normally gives - * an error which we can safely ignore. */ - if (!reattach) { - RAY_CHECK(err == REDIS_OK) << "failed to attach the event loop"; - } - err = redisAeAttach(loop, db->subscribe_contexts[i]); - if (!reattach) { - RAY_CHECK(err == REDIS_OK) << "failed to attach the event loop"; - } - } -} - -/* - * ==== object_table callbacks ==== - */ - -void redis_object_table_add_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - /* Do some minimal checking. */ - redisReply *reply = (redisReply *) r; - bool success = (strcmp(reply->str, "hash mismatch") != 0); - if (!success) { - /* If our object hash doesn't match the one recorded in the table, report - * the error back to the user and exit immediately. */ - RAY_LOG(WARNING) << "Found objects with different value but same object " - << "ID, most likely because a nondeterministic task was " - << "executed twice, either for reconstruction or for " - << "speculation."; - } else { - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " - << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - } - /* Call the done callback if there is one. */ - if (callback_data->done_callback != NULL) { - object_table_done_callback done_callback = - (object_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, success, callback_data->user_context); - } - /* Clean up the timer and callback. 
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_object_table_add(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - - ObjectTableAddData *info = (ObjectTableAddData *) callback_data->data->Get(); - ObjectID obj_id = callback_data->id; - int64_t object_size = info->object_size; - unsigned char *digest = info->digest; - - redisAsyncContext *context = get_redis_context(db, obj_id); - - int status = redisAsyncCommand( - context, redis_object_table_add_callback, - (void *) callback_data->timer_id, "RAY.OBJECT_TABLE_ADD %b %lld %b %b", - obj_id.data(), sizeof(obj_id), (long long) object_size, digest, - (size_t) DIGEST_SIZE, db->client.data(), sizeof(db->client)); - - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_object_table_add"); - } -} - -void redis_object_table_remove_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - /* Do some minimal checking. */ - redisReply *reply = (redisReply *) r; - if (strcmp(reply->str, "object not found") == 0) { - /* If our object entry was not in the table, it's probably a race - * condition with an object_table_add. */ - return; - } - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - /* Call the done callback if there is one. */ - if (callback_data->done_callback != NULL) { - object_table_done_callback done_callback = - (object_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, true, callback_data->user_context); - } - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_object_table_remove(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - - ObjectID obj_id = callback_data->id; - /* If the caller provided a manager ID to delete, use it. Otherwise, use our - * own client ID as the ID to delete. */ - DBClientID *client_id = (DBClientID *) callback_data->data->Get(); - if (client_id == NULL) { - client_id = &db->client; - } - - redisAsyncContext *context = get_redis_context(db, obj_id); - - int status = redisAsyncCommand( - context, redis_object_table_remove_callback, - (void *) callback_data->timer_id, "RAY.OBJECT_TABLE_REMOVE %b %b", - obj_id.data(), sizeof(obj_id), client_id->data(), sizeof(*client_id)); - - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_object_table_remove"); - } -} - -void redis_object_table_lookup(TableCallbackData *callback_data) { - RAY_CHECK(callback_data); - DBHandle *db = callback_data->db_handle; - - ObjectID obj_id = callback_data->id; - - redisAsyncContext *context = get_redis_context(db, obj_id); - - int status = redisAsyncCommand(context, redis_object_table_lookup_callback, - (void *) callback_data->timer_id, - "RAY.OBJECT_TABLE_LOOKUP %b", obj_id.data(), - sizeof(obj_id)); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in object_table lookup"); - } -} - -void redis_result_table_add_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - /* Check that the command succeeded. 
*/ - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strncmp(reply->str, "OK", strlen("OK")) == 0) << "reply->str is " - << reply->str; - /* Call the done callback if there is one. */ - if (callback_data->done_callback) { - result_table_done_callback done_callback = - (result_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - destroy_timer_callback(db->loop, callback_data); -} - -void redis_result_table_add(TableCallbackData *callback_data) { - RAY_CHECK(callback_data); - DBHandle *db = callback_data->db_handle; - ObjectID id = callback_data->id; - ResultTableAddInfo *info = (ResultTableAddInfo *) callback_data->data->Get(); - int is_put = info->is_put ? 1 : 0; - - redisAsyncContext *context = get_redis_context(db, id); - - /* Add the result entry to the result table. */ - int status = - redisAsyncCommand(context, redis_result_table_add_callback, - (void *) callback_data->timer_id, - "RAY.RESULT_TABLE_ADD %b %b %d", id.data(), sizeof(id), - info->task_id.data(), sizeof(info->task_id), is_put); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "Error in result table add"); - } -} - -/* This allocates a task which must be freed by the caller, unless the returned - * task is NULL. This is used by both redis_result_table_lookup_callback and - * redis_task_table_get_task_callback. */ -Task *parse_and_construct_task_from_redis_reply(redisReply *reply) { - Task *task = NULL; - if (reply->type == REDIS_REPLY_NIL) { - /* There is no task in the reply, so return NULL. */ - } else if (reply->type == REDIS_REPLY_STRING) { - /* The reply is a flatbuffer TaskReply object. Parse it and construct the - * task. */ - auto message = flatbuffers::GetRoot(reply->str); - TaskSpec *spec = (TaskSpec *) message->task_spec()->data(); - int64_t task_spec_size = message->task_spec()->size(); - auto execution_dependencies = - flatbuffers::GetRoot( - message->execution_dependencies()->data()); - task = Task_alloc( - spec, task_spec_size, static_cast(message->state()), - from_flatbuf(*message->local_scheduler_id()), - from_flatbuf(*execution_dependencies->execution_dependencies())); - } else { - RAY_LOG(FATAL) << "Unexpected reply type " << reply->type; - } - /* Return the task. If it is not NULL, then it must be freed by the caller. */ - return task; -} - -void redis_result_table_lookup_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_NIL || reply->type == REDIS_REPLY_STRING) - << "Unexpected reply type " << reply->type << " in " - << "redis_result_table_lookup_callback"; - /* Parse the task from the reply. */ - TaskID result_id = TaskID::nil(); - bool is_put = false; - if (reply->type == REDIS_REPLY_STRING) { - auto message = flatbuffers::GetRoot(reply->str); - result_id = from_flatbuf(*message->task_id()); - is_put = message->is_put(); - } - - /* Call the done callback if there is one. */ - result_table_lookup_callback done_callback = - (result_table_lookup_callback) callback_data->done_callback; - if (done_callback != NULL) { - done_callback(callback_data->id, result_id, is_put, - callback_data->user_context); - } - /* Clean up timer and callback. 
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_result_table_lookup(TableCallbackData *callback_data) { - RAY_CHECK(callback_data); - DBHandle *db = callback_data->db_handle; - ObjectID id = callback_data->id; - redisAsyncContext *context = get_redis_context(db, id); - int status = - redisAsyncCommand(context, redis_result_table_lookup_callback, - (void *) callback_data->timer_id, - "RAY.RESULT_TABLE_LOOKUP %b", id.data(), sizeof(id)); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "Error in result table lookup"); - } -} - -DBClient redis_db_client_table_get(DBHandle *db, - const unsigned char *client_id, - size_t client_id_len) { - redisReply *reply = - (redisReply *) redisCommand(db->sync_context, "HGETALL %s%b", - DB_CLIENT_PREFIX, client_id, client_id_len); - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements > 0); - DBClient db_client; - int num_fields = 0; - /* Parse the fields into a DBClient. */ - for (size_t j = 0; j < reply->elements; j = j + 2) { - const char *key = reply->element[j]->str; - const char *value = reply->element[j + 1]->str; - if (strcmp(key, "ray_client_id") == 0) { - memcpy(db_client.id.mutable_data(), value, sizeof(db_client.id)); - num_fields++; - } else if (strcmp(key, "client_type") == 0) { - db_client.client_type = std::string(value); - num_fields++; - } else if (strcmp(key, "manager_address") == 0) { - db_client.manager_address = std::string(value); - num_fields++; - } else if (strcmp(key, "deleted") == 0) { - bool is_deleted = atoi(value); - db_client.is_alive = !is_deleted; - num_fields++; - } - } - freeReplyObject(reply); - /* The client ID, type, and whether it is deleted are all - * mandatory fields. Auxiliary address is optional. */ - RAY_CHECK(num_fields >= 3); - return db_client; -} - -void redis_cache_set_db_client(DBHandle *db, DBClient client) { - db->db_client_cache[client.id] = client; -} - -/** - * Get an entry from the db client cache. If the entry is not already cached, - * look it up in the db client table in Redis and cache the result. - * - * @param db The database handle. - * @param db_client_id The ID of the db client to look up. - * @return The db client entry. - */ -DBClient redis_cache_get_db_client(DBHandle *db, DBClientID db_client_id) { - auto it = db->db_client_cache.find(db_client_id); - if (it == db->db_client_cache.end()) { - DBClient db_client = redis_db_client_table_get(db, db_client_id.data(), - sizeof(db_client_id)); - db->db_client_cache[db_client_id] = db_client; - it = db->db_client_cache.find(db_client_id); - } - return it->second; -} - -void redis_object_table_lookup_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - RAY_LOG(DEBUG) << "Object table lookup callback"; - RAY_CHECK(reply->type == REDIS_REPLY_NIL || reply->type == REDIS_REPLY_ARRAY); - - object_table_lookup_done_callback done_callback = - (object_table_lookup_done_callback) callback_data->done_callback; - - ObjectID obj_id = callback_data->id; - - /* Parse the Redis reply. */ - if (reply->type == REDIS_REPLY_NIL) { - /* The object entry did not exist. */ - if (done_callback) { - done_callback(obj_id, true, std::vector(), - callback_data->user_context); - } - } else if (reply->type == REDIS_REPLY_ARRAY) { - /* Extract the manager IDs from the response into a vector. 
*/ - std::vector manager_ids; - - for (size_t j = 0; j < reply->elements; ++j) { - RAY_CHECK(reply->element[j]->type == REDIS_REPLY_STRING); - DBClientID manager_id; - memcpy(manager_id.mutable_data(), reply->element[j]->str, - sizeof(manager_id)); - manager_ids.push_back(manager_id); - } - - if (done_callback) { - done_callback(obj_id, false, manager_ids, callback_data->user_context); - } - } else { - RAY_LOG(FATAL) << "Unexpected reply type from object table lookup."; - } - - /* Clean up timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void object_table_redis_subscribe_to_notifications_callback( - redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - /* Replies to the SUBSCRIBE command have 3 elements. There are two - * possibilities. Either the reply is the initial acknowledgment of the - * subscribe command, or it is a message. If it is the initial acknowledgment, - * then - * - reply->element[0]->str is "subscribe" - * - reply->element[1]->str is the name of the channel - * - reply->element[2]->str is null. - * If it is an actual message, then - * - reply->element[0]->str is "message" - * - reply->element[1]->str is the name of the channel - * - reply->element[2]->str is the contents of the message. - */ - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements == 3); - redisReply *message_type = reply->element[0]; - RAY_LOG(DEBUG) << "Object table subscribe to notifications callback, message " - << message_type->str; - - if (strcmp(message_type->str, "message") == 0) { - /* We received an object notification. Parse the payload. */ - auto message = flatbuffers::GetRoot( - reply->element[2]->str); - /* Extract the object ID. */ - ObjectID obj_id = from_flatbuf(*message->object_id()); - /* Extract the data size. */ - int64_t data_size = message->object_size(); - int manager_count = message->manager_ids()->size(); - - /* Extract the manager IDs from the response into a vector. */ - std::vector manager_ids; - for (int i = 0; i < manager_count; ++i) { - DBClientID manager_id = from_flatbuf(*message->manager_ids()->Get(i)); - manager_ids.push_back(manager_id); - } - - /* Call the subscribe callback. */ - ObjectTableSubscribeData *data = - (ObjectTableSubscribeData *) callback_data->data->Get(); - if (data->object_available_callback) { - data->object_available_callback(obj_id, data_size, manager_ids, - data->subscribe_context); - } - } else if (strcmp(message_type->str, "subscribe") == 0) { - /* The reply for the initial SUBSCRIBE command. */ - /* Call the done callback if there is one. This code path should only be - * used in the tests. */ - if (callback_data->done_callback != NULL) { - object_table_lookup_done_callback done_callback = - (object_table_lookup_done_callback) callback_data->done_callback; - done_callback(ray::UniqueID::nil(), false, std::vector(), - callback_data->user_context); - } - /* If the initial SUBSCRIBE was successful, clean up the timer, but don't - * destroy the callback data. */ - remove_timer_callback(db->loop, callback_data); - } else { - RAY_LOG(FATAL) << "Unexpected reply type from object table subscribe to " - << "notifications."; - } -} - -void redis_object_table_subscribe_to_notifications( - TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - /* The object channel prefix must match the value defined in - * src/common/redismodule/ray_redis_module.cc. 
*/ - const char *object_channel_prefix = "OC:"; - const char *object_channel_bcast = "BCAST"; - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - int status = REDIS_OK; - /* Subscribe to notifications from the object table. This uses the client ID - * as the channel name so this channel is specific to this client. - * TODO(rkn): - * The channel name should probably be the client ID with some prefix. */ - RAY_CHECK(callback_data->data->Get() != NULL) - << "Object table subscribe data passed as NULL."; - if (((ObjectTableSubscribeData *) (callback_data->data->Get())) - ->subscribe_all) { - /* Subscribe to the object broadcast channel. */ - status = redisAsyncCommand( - db->subscribe_contexts[i], - object_table_redis_subscribe_to_notifications_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%s", - object_channel_prefix, object_channel_bcast); - } else { - status = redisAsyncCommand( - db->subscribe_contexts[i], - object_table_redis_subscribe_to_notifications_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%b", - object_channel_prefix, db->client.data(), sizeof(db->client)); - } - - if ((status == REDIS_ERR) || db->subscribe_contexts[i]->err) { - LOG_REDIS_DEBUG(db->subscribe_contexts[i], - "error in redis_object_table_subscribe_to_notifications"); - } - } -} - -void redis_object_table_request_notifications_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - /* Do some minimal checking. */ - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - RAY_CHECK(callback_data->done_callback == NULL); - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_object_table_request_notifications( - TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - - ObjectTableRequestNotificationsData *request_data = - (ObjectTableRequestNotificationsData *) callback_data->data->Get(); - int num_object_ids = request_data->num_object_ids; - ObjectID *object_ids = request_data->object_ids; - - for (int i = 0; i < num_object_ids; ++i) { - redisAsyncContext *context = get_redis_context(db, object_ids[i]); - - /* Create the arguments for the Redis command. */ - int num_args = 1 + 1 + 1; - const char **argv = (const char **) malloc(sizeof(char *) * num_args); - size_t *argvlen = (size_t *) malloc(sizeof(size_t) * num_args); - /* Set the command name argument. */ - argv[0] = "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS"; - argvlen[0] = strlen(argv[0]); - /* Set the client ID argument. */ - argv[1] = (char *) db->client.data(); - argvlen[1] = sizeof(db->client); - /* Set the object ID arguments. */ - argv[2] = (char *) object_ids[i].data(); - argvlen[2] = sizeof(object_ids[i]); - - int status = redisAsyncCommandArgv( - context, redis_object_table_request_notifications_callback, - (void *) callback_data->timer_id, num_args, argv, argvlen); - free(argv); - free(argvlen); - - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, - "error in redis_object_table_request_notifications"); - } - } -} - -/* - * ==== task_table callbacks ==== - */ - -void redis_task_table_get_task_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - /* Parse the task from the reply. 
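``redis_object_table_request_notifications`` above builds its argument vector with ``malloc``/``free`` because the command mixes a fixed name with binary IDs. Below is a sketch of the same ``redisAsyncCommandArgv`` call written with RAII containers instead; the helper name and ``std::string`` parameters are hypothetical, the semantics are the same:

.. code-block:: cpp

    #include <cstring>
    #include <string>
    #include <vector>
    #include <hiredis/async.h>

    int request_notifications(redisAsyncContext *context, redisCallbackFn *cb,
                              void *privdata, const std::string &client_id,
                              const std::string &object_id) {
      // Each argv[i]/argvlen[i] pair is passed through verbatim, so the
      // binary IDs need no escaping. The vectors free themselves on return.
      std::vector<const char *> argv = {
          "RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", client_id.data(),
          object_id.data()};
      std::vector<size_t> argvlen = {std::strlen(argv[0]), client_id.size(),
                                     object_id.size()};
      return redisAsyncCommandArgv(context, cb, privdata,
                                   static_cast<int>(argv.size()), argv.data(),
                                   argvlen.data());
    }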
*/ - Task *task = parse_and_construct_task_from_redis_reply(reply); - /* Call the done callback if there is one. */ - task_table_get_callback done_callback = - (task_table_get_callback) callback_data->done_callback; - if (done_callback != NULL) { - done_callback(task, callback_data->user_context); - } - /* Free the task if it is not NULL. */ - if (task != NULL) { - Task_free(task); - } - - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_task_table_get_task(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - RAY_CHECK(callback_data->data->Get() == NULL); - TaskID task_id = callback_data->id; - - redisAsyncContext *context = get_redis_context(db, task_id); - - int status = redisAsyncCommand(context, redis_task_table_get_task_callback, - (void *) callback_data->timer_id, - "RAY.TASK_TABLE_GET %b", task_id.data(), - sizeof(task_id)); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_task_table_get_task"); - } -} - -void redis_task_table_add_task_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - // If no subscribers received the message, call the failure callback. The - // caller should decide whether to retry the add. NOTE(swang): The caller - // should check whether the receiving subscriber is still alive in the - // db_client table before retrying the add. - if (reply->type == REDIS_REPLY_ERROR && - strcmp(reply->str, "No subscribers received message.") == 0) { - RAY_LOG(WARNING) << "No subscribers received the task_table_add message."; - if (callback_data->retry.fail_callback != NULL) { - callback_data->retry.fail_callback(callback_data->id, - callback_data->user_context, - callback_data->data->Get()); - } - } else { - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " - << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - /* Call the done callback if there is one. */ - if (callback_data->done_callback != NULL) { - task_table_done_callback done_callback = - (task_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - } - - /* Clean up the timer and callback. 
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_task_table_add_task(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - Task *task = (Task *) callback_data->data->Get(); - RAY_CHECK(task != NULL) << "NULL task passed to redis_task_table_add_task."; - - TaskID task_id = Task_task_id(task); - DBClientID local_scheduler_id = Task_local_scheduler(task); - redisAsyncContext *context = get_redis_context(db, task_id); - int state = static_cast(Task_state(task)); - - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies = CreateTaskExecutionDependencies( - fbb, to_flatbuf(fbb, execution_spec->ExecutionDependencies())); - fbb.Finish(execution_dependencies); - - int status = redisAsyncCommand( - context, redis_task_table_add_task_callback, - (void *) callback_data->timer_id, "RAY.TASK_TABLE_ADD %b %d %b %b %d %b", - task_id.data(), sizeof(task_id), state, local_scheduler_id.data(), - sizeof(local_scheduler_id), fbb.GetBufferPointer(), - (size_t) fbb.GetSize(), - static_cast(execution_spec->SpillbackCount()), spec, - execution_spec->SpecSize()); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_task_table_add_task"); - } -} - -void redis_task_table_update_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - // If no subscribers received the message, call the failure callback. The - // caller should decide whether to retry the update. NOTE(swang): Retrying a - // task table update can race with the liveness monitor. Do not retry the - // update unless the caller is sure that the receiving subscriber is still - // alive in the db_client table. - if (reply->type == REDIS_REPLY_ERROR) { - RAY_LOG(WARNING) << "task_table_update failed with " << reply->str; - if (callback_data->retry.fail_callback != NULL) { - callback_data->retry.fail_callback(callback_data->id, - callback_data->user_context, - callback_data->data->Get()); - } else { - RAY_LOG(FATAL) << "task_table_update failed and no fail_callback is set"; - } - } else { - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - - /* Call the done callback if there is one. */ - if (callback_data->done_callback != NULL) { - task_table_done_callback done_callback = - (task_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - } - - /* Clean up the timer and callback. 
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_task_table_update(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - Task *task = (Task *) callback_data->data->Get(); - RAY_CHECK(task != NULL) << "NULL task passed to redis_task_table_update."; - - TaskID task_id = Task_task_id(task); - redisAsyncContext *context = get_redis_context(db, task_id); - DBClientID local_scheduler_id = Task_local_scheduler(task); - int state = static_cast(Task_state(task)); - - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies = CreateTaskExecutionDependencies( - fbb, to_flatbuf(fbb, execution_spec->ExecutionDependencies())); - fbb.Finish(execution_dependencies); - - int status = redisAsyncCommand( - context, redis_task_table_update_callback, - (void *) callback_data->timer_id, "RAY.TASK_TABLE_UPDATE %b %d %b %b %d", - task_id.data(), sizeof(task_id), state, local_scheduler_id.data(), - sizeof(local_scheduler_id), fbb.GetBufferPointer(), - (size_t) fbb.GetSize(), - static_cast(execution_spec->SpillbackCount())); - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_task_table_update"); - } -} - -void redis_task_table_test_and_update_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - /* Parse the task from the reply. */ - Task *task = parse_and_construct_task_from_redis_reply(reply); - if (task == NULL) { - /* A NULL task means that the task was not in the task table. NOTE(swang): - * For normal tasks, this is not expected behavior, but actor tasks may be - * delayed when added to the task table if they are submitted to a local - * scheduler before it receives the notification that maps the actor to a - * local scheduler. */ - RAY_LOG(ERROR) << "No task found during task_table_test_and_update for " - << "task with ID " << callback_data->id; - return; - } - /* Determine whether the update happened. */ - auto message = flatbuffers::GetRoot(reply->str); - bool updated = message->updated(); - - /* Call the done callback if there is one. */ - task_table_test_and_update_callback done_callback = - (task_table_test_and_update_callback) callback_data->done_callback; - if (done_callback != NULL) { - done_callback(task, callback_data->user_context, updated); - } - /* Free the task if it is not NULL. */ - if (task != NULL) { - Task_free(task); - } - /* Clean up timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_task_table_test_and_update(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - TaskID task_id = callback_data->id; - redisAsyncContext *context = get_redis_context(db, task_id); - TaskTableTestAndUpdateData *update_data = - (TaskTableTestAndUpdateData *) callback_data->data->Get(); - - int status; - /* If the test local scheduler ID is NIL, then ignore it. 
*/ - if (update_data->test_local_scheduler_id.is_nil()) { - status = redisAsyncCommand( - context, redis_task_table_test_and_update_callback, - (void *) callback_data->timer_id, - "RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b", task_id.data(), - sizeof(task_id), update_data->test_state_bitmask, - update_data->update_state, update_data->local_scheduler_id.data(), - sizeof(update_data->local_scheduler_id)); - } else { - status = redisAsyncCommand( - context, redis_task_table_test_and_update_callback, - (void *) callback_data->timer_id, - "RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b %b", task_id.data(), - sizeof(task_id), update_data->test_state_bitmask, - update_data->update_state, update_data->local_scheduler_id.data(), - sizeof(update_data->local_scheduler_id), - update_data->test_local_scheduler_id.data(), - sizeof(update_data->test_local_scheduler_id)); - } - - if ((status == REDIS_ERR) || context->err) { - LOG_REDIS_DEBUG(context, "error in redis_task_table_test_and_update"); - } -} - -void redis_task_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - /* The number of elements is 3 for a reply to SUBSCRIBE, and 4 for a reply to - * PSUBSCRIBE. */ - RAY_CHECK(reply->elements == 3 || reply->elements == 4) - << "reply->elements is " << reply->elements; - /* The first element is the message type and the last entry is the payload. - * The middle one or middle two elements describe the channel that was - * published on. */ - redisReply *message_type = reply->element[0]; - redisReply *payload = reply->element[reply->elements - 1]; - if (strcmp(message_type->str, "message") == 0 || - strcmp(message_type->str, "pmessage") == 0) { - /* Handle a task table event. Parse the payload and call the callback. */ - auto message = flatbuffers::GetRoot(payload->str); - /* Extract the scheduling state. */ - TaskStatus state = static_cast(message->state()); - /* Extract the local scheduler ID. */ - DBClientID local_scheduler_id = - from_flatbuf(*message->local_scheduler_id()); - /* Extract the execution dependencies. */ - auto execution_dependencies = - flatbuffers::GetRoot( - message->execution_dependencies()->data()); - /* Extract the task spec. */ - TaskSpec *spec = (TaskSpec *) message->task_spec()->data(); - int64_t task_spec_size = message->task_spec()->size(); - /* Extract the spillback information. */ - int spillback_count = message->spillback_count(); - /* Create a task. */ - /* Allocate the task execution spec on the stack and use it to construct - * the task. - */ - TaskExecutionSpec execution_spec( - from_flatbuf(*execution_dependencies->execution_dependencies()), spec, - task_spec_size, spillback_count); - Task *task = Task_alloc(execution_spec, state, local_scheduler_id); - - /* Call the subscribe callback if there is one. */ - TaskTableSubscribeData *data = - (TaskTableSubscribeData *) callback_data->data->Get(); - if (data->subscribe_callback != NULL) { - data->subscribe_callback(task, data->subscribe_context); - } - Task_free(task); - } else if (strcmp(message_type->str, "subscribe") == 0 || - strcmp(message_type->str, "psubscribe") == 0) { - /* If this condition is true, we got the initial message that acknowledged - * the subscription. 
*/ - if (callback_data->done_callback != NULL) { - task_table_done_callback done_callback = - (task_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - /* Note that we do not destroy the callback data yet because the - * subscription callback needs this data. */ - remove_timer_callback(db->loop, callback_data); - } else { - RAY_LOG(FATAL) << "Unexpected reply type from task table subscribe. " - << "Message type is " << message_type->str; - } -} - -void redis_task_table_subscribe(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - TaskTableSubscribeData *data = - (TaskTableSubscribeData *) callback_data->data->Get(); - /* TASK_CHANNEL_PREFIX is defined in ray_redis_module.cc and must be kept in - * sync with that file. */ - const char *TASK_CHANNEL_PREFIX = "TT:"; - /* In the new code path, subscriptions currently go through the - * primary redis shard. */ - for (auto subscribe_context : db->subscribe_contexts) { - int status; - if (data->local_scheduler_id.is_nil()) { - /* TODO(swang): Implement the state_filter by translating the bitmask into - * a Redis key-matching pattern. */ - status = redisAsyncCommand( - subscribe_context, redis_task_table_subscribe_callback, - (void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d", - TASK_CHANNEL_PREFIX, data->state_filter); - } else { - DBClientID local_scheduler_id = data->local_scheduler_id; - status = redisAsyncCommand( - subscribe_context, redis_task_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d", - TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.data(), - sizeof(local_scheduler_id), data->state_filter); - } - if ((status == REDIS_ERR) || subscribe_context->err) { - LOG_REDIS_DEBUG(subscribe_context, "error in redis_task_table_subscribe"); - } - } -} - -/* - * ==== db client table callbacks ==== - */ - -void redis_db_client_table_remove_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - - /* Call the done callback if there is one. */ - db_client_table_done_callback done_callback = - (db_client_table_done_callback) callback_data->done_callback; - if (done_callback) { - done_callback(callback_data->id, callback_data->user_context); - } - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_db_client_table_remove(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = - redisAsyncCommand(db->context, redis_db_client_table_remove_callback, - (void *) callback_data->timer_id, "RAY.DISCONNECT %b", - callback_data->id.data(), sizeof(callback_data->id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in db_client_table_remove"); - } -} - -void redis_db_client_table_scan(DBHandle *db, - std::vector &db_clients) { - /* TODO(swang): Integrate this functionality with the Ray Redis module. To do - * this, we need the KEYS or SCAN command in Redis modules. */ - /* Get all the database client keys. */ - redisReply *reply = (redisReply *) redisCommand(db->sync_context, "KEYS %s*", - DB_CLIENT_PREFIX); - if (reply->type == REDIS_REPLY_NIL) { - return; - } - /* Get all the database client information. 
*/ - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - for (size_t i = 0; i < reply->elements; ++i) { - /* Strip the database client table prefix. */ - unsigned char *key = (unsigned char *) reply->element[i]->str; - key += strlen(DB_CLIENT_PREFIX); - size_t key_len = reply->element[i]->len; - key_len -= strlen(DB_CLIENT_PREFIX); - /* Get the database client's information. */ - DBClient db_client = redis_db_client_table_get(db, key, key_len); - db_clients.push_back(db_client); - } - freeReplyObject(reply); -} - -void redis_db_client_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements > 2); - /* First entry is message type, then possibly the regex we psubscribed to, - * then topic, then payload. */ - redisReply *payload = reply->element[reply->elements - 1]; - /* If this condition is true, we got the initial message that acknowledged the - * subscription. */ - if (payload->str == NULL) { - if (callback_data->done_callback) { - db_client_table_done_callback done_callback = - (db_client_table_done_callback) callback_data->done_callback; - done_callback(callback_data->id, callback_data->user_context); - } - /* Note that we do not destroy the callback data yet because the - * subscription callback needs this data. */ - remove_timer_callback(db->loop, callback_data); - - /* Get the current db client table entries, in case we missed notifications - * before the initial subscription. This must be done before we process any - * notifications from the subscription channel, so that we don't readd an - * entry that has already been deleted. */ - std::vector db_clients; - redis_db_client_table_scan(db, db_clients); - /* Call the subscription callback for all entries that we missed. */ - DBClientTableSubscribeData *data = - (DBClientTableSubscribeData *) callback_data->data->Get(); - for (auto db_client : db_clients) { - data->subscribe_callback(&db_client, data->subscribe_context); - } - return; - } - /* Otherwise, parse the payload and call the callback. */ - auto message = - flatbuffers::GetRoot(payload->str); - - /* Parse the client type and auxiliary address from the response. If there is - * only client type, then the update was a delete. */ - DBClient db_client; - db_client.id = from_flatbuf(*message->db_client_id()); - db_client.client_type = std::string(message->client_type()->data()); - db_client.manager_address = std::string(message->manager_address()->data()); - db_client.is_alive = message->is_insertion(); - - /* Call the subscription callback. 
*/ - DBClientTableSubscribeData *data = - (DBClientTableSubscribeData *) callback_data->data->Get(); - if (data->subscribe_callback) { - data->subscribe_callback(&db_client, data->subscribe_context); - } -} - -void redis_db_client_table_subscribe(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = redisAsyncCommand( - db->subscribe_context, redis_db_client_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE db_clients"); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, - "error in db_client_table_register_callback"); - } -} - -void redis_local_scheduler_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements == 3); - redisReply *message_type = reply->element[0]; - RAY_LOG(DEBUG) << "Local scheduler table subscribe callback, message " - << message_type->str; - - if (strcmp(message_type->str, "message") == 0) { - /* Handle a local scheduler heartbeat. Parse the payload and call the - * subscribe callback. */ - auto message = - flatbuffers::GetRoot(reply->element[2]->str); - - /* Extract the client ID. */ - DBClientID client_id = from_flatbuf(*message->db_client_id()); - /* Extract the fields of the local scheduler info struct. */ - LocalSchedulerInfo info; - if (message->is_dead()) { - /* If the local scheduler is dead, then ignore all other fields in the - * message. */ - info.is_dead = true; - } else { - /* If the local scheduler is alive, collect load information. */ - info.is_dead = false; - info.total_num_workers = message->total_num_workers(); - info.task_queue_length = message->task_queue_length(); - info.available_workers = message->available_workers(); - - info.static_resources = map_from_flatbuf(*message->static_resources()); - info.dynamic_resources = map_from_flatbuf(*message->dynamic_resources()); - } - - /* Call the subscribe callback. */ - LocalSchedulerTableSubscribeData *data = - (LocalSchedulerTableSubscribeData *) callback_data->data->Get(); - if (data->subscribe_callback) { - data->subscribe_callback(client_id, info, data->subscribe_context); - } - } else if (strcmp(message_type->str, "subscribe") == 0) { - /* The reply for the initial SUBSCRIBE command. */ - RAY_CHECK(callback_data->done_callback == NULL); - /* If the initial SUBSCRIBE was successful, clean up the timer, but don't - * destroy the callback data. 
*/ - remove_timer_callback(db->loop, callback_data); - - } else { - RAY_LOG(FATAL) << "Unexpected reply type from local scheduler subscribe."; - } -} - -void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = redisAsyncCommand( - db->subscribe_context, redis_local_scheduler_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE local_schedulers"); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, - "error in redis_local_scheduler_table_subscribe"); - } -} - -void redis_local_scheduler_table_send_info_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - RAY_LOG(DEBUG) << reply->integer << " subscribers received this publish."; - - RAY_CHECK(callback_data->done_callback == NULL); - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_local_scheduler_table_send_info(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - LocalSchedulerTableSendInfoData *data = - (LocalSchedulerTableSendInfoData *) callback_data->data->Get(); - - int64_t size = data->size; - uint8_t *flatbuffer_data = data->flatbuffer_data; - - int status = redisAsyncCommand( - db->context, redis_local_scheduler_table_send_info_callback, - (void *) callback_data->timer_id, "PUBLISH local_schedulers %b", - flatbuffer_data, size); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, - "error in redis_local_scheduler_table_send_info"); - } -} - -void redis_local_scheduler_table_disconnect(DBHandle *db) { - flatbuffers::FlatBufferBuilder fbb; - /* Create the flatbuffers message. */ - std::unordered_map empty_resource_map; - /* Most of the flatbuffer message fields don't matter here. Only the - * db_client_id and the is_dead field matter. */ - auto message = CreateLocalSchedulerInfoMessage( - fbb, to_flatbuf(fbb, db->client), 0, 0, 0, - map_to_flatbuf(fbb, empty_resource_map), - map_to_flatbuf(fbb, empty_resource_map), true); - fbb.Finish(message); - - redisReply *reply = (redisReply *) redisCommand( - db->sync_context, "PUBLISH local_schedulers %b", fbb.GetBufferPointer(), - (size_t) fbb.GetSize()); - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - RAY_LOG(DEBUG) << reply->integer << " subscribers received this publish."; - freeReplyObject(reply); -} - -void redis_driver_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements == 3); - redisReply *message_type = reply->element[0]; - RAY_LOG(DEBUG) << "Driver table subscribe callback, message " - << message_type->str; - - if (strcmp(message_type->str, "message") == 0) { - /* Handle a driver heartbeat. Parse the payload and call the subscribe - * callback. */ - auto message = - flatbuffers::GetRoot(reply->element[2]->str); - /* Extract the client ID. */ - WorkerID driver_id = from_flatbuf(*message->driver_id()); - - /* Call the subscribe callback. 
*/ - DriverTableSubscribeData *data = - (DriverTableSubscribeData *) callback_data->data->Get(); - if (data->subscribe_callback) { - data->subscribe_callback(driver_id, data->subscribe_context); - } - } else if (strcmp(message_type->str, "subscribe") == 0) { - /* The reply for the initial SUBSCRIBE command. */ - RAY_CHECK(callback_data->done_callback == NULL); - /* If the initial SUBSCRIBE was successful, clean up the timer, but don't - * destroy the callback data. */ - remove_timer_callback(db->loop, callback_data); - - } else { - RAY_LOG(FATAL) << "Unexpected reply type from driver subscribe."; - } -} - -void redis_driver_table_subscribe(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = redisAsyncCommand( - db->subscribe_context, redis_driver_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE driver_deaths"); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, - "error in redis_driver_table_subscribe"); - } -} - -void redis_driver_table_send_driver_death_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - RAY_LOG(DEBUG) << reply->integer << " subscribers received this publish."; - /* At the very least, the local scheduler that publishes this message should - * also receive it. */ - RAY_CHECK(reply->integer >= 1); - - RAY_CHECK(callback_data->done_callback == NULL); - /* Clean up the timer and callback. */ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_driver_table_send_driver_death(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - WorkerID driver_id = callback_data->id; - - /* Create a flatbuffer object to serialize and publish. */ - flatbuffers::FlatBufferBuilder fbb; - /* Create the flatbuffers message. */ - auto message = CreateDriverTableMessage(fbb, to_flatbuf(fbb, driver_id)); - fbb.Finish(message); - - int status = redisAsyncCommand( - db->context, redis_driver_table_send_driver_death_callback, - (void *) callback_data->timer_id, "PUBLISH driver_deaths %b", - fbb.GetBufferPointer(), (size_t) fbb.GetSize()); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, - "error in redis_driver_table_send_driver_death"); - } -} - -void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - /* NOTE(swang): We purposefully do not provide a callback, leaving the table - * operation and timer active. This allows us to send a new heartbeat every - * heartbeat_timeout_milliseconds without having to allocate and deallocate - * memory for callback data each time. */ - int status = redisAsyncCommand( - db->context, NULL, (void *) callback_data->timer_id, - "PUBLISH plasma_managers %b", db->client.data(), sizeof(db->client)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, - "error in redis_plasma_manager_send_heartbeat"); - } - /* Clean up the timer and callback. 
*/ - destroy_timer_callback(db->loop, callback_data); -} - -void redis_publish_actor_creation_notification_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - RAY_LOG(DEBUG) << reply->integer << " subscribers received this publish."; - // At the very least, the local scheduler that publishes this message should - // also receive it. - RAY_CHECK(reply->integer >= 1); - - RAY_CHECK(callback_data->done_callback == NULL); - // Clean up the timer and callback. - destroy_timer_callback(db->loop, callback_data); -} - -void redis_publish_actor_creation_notification( - TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - - ActorCreationNotificationData *data = - (ActorCreationNotificationData *) callback_data->data->Get(); - - int status = redisAsyncCommand( - db->context, redis_publish_actor_creation_notification_callback, - (void *) callback_data->timer_id, "PUBLISH actor_notifications %b", - &data->flatbuffer_data[0], data->size); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, - "error in redis_publish_actor_creation_notification"); - } -} - -void redis_actor_notification_table_subscribe_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - - redisReply *reply = (redisReply *) r; - RAY_CHECK(reply->type == REDIS_REPLY_ARRAY); - RAY_CHECK(reply->elements == 3); - redisReply *message_type = reply->element[0]; - RAY_LOG(DEBUG) << "Actor notification table subscribe callback, message " - << message_type->str; - - if (strcmp(message_type->str, "message") == 0) { - // Handle an actor notification message. Parse the payload and call the - // subscribe callback. - redisReply *payload = reply->element[2]; - ActorNotificationTableSubscribeData *data = - (ActorNotificationTableSubscribeData *) callback_data->data->Get(); - - auto message = - flatbuffers::GetRoot(payload->str); - ActorID actor_id = from_flatbuf(*message->actor_id()); - WorkerID driver_id = from_flatbuf(*message->driver_id()); - DBClientID local_scheduler_id = - from_flatbuf(*message->local_scheduler_id()); - - if (data->subscribe_callback) { - data->subscribe_callback(actor_id, driver_id, local_scheduler_id, - data->subscribe_context); - } - } else if (strcmp(message_type->str, "subscribe") == 0) { - /* The reply for the initial SUBSCRIBE command. */ - RAY_CHECK(callback_data->done_callback == NULL); - /* If the initial SUBSCRIBE was successful, clean up the timer, but don't - * destroy the callback data. 
*/ - remove_timer_callback(db->loop, callback_data); - - } else { - RAY_LOG(FATAL) << "Unexpected reply type from actor notification " - << "subscribe."; - } -} - -void redis_actor_notification_table_subscribe( - TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - int status = redisAsyncCommand( - db->subscribe_context, redis_actor_notification_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE actor_notifications"); - if ((status == REDIS_ERR) || db->subscribe_context->err) { - LOG_REDIS_DEBUG(db->subscribe_context, - "error in redis_actor_notification_table_subscribe"); - } -} - -void redis_actor_table_mark_removed(DBHandle *db, ActorID actor_id) { - int status = - redisAsyncCommand(db->context, NULL, NULL, "HSET Actor:%b removed \"1\"", - actor_id.data(), sizeof(actor_id)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_actor_table_mark_removed"); - } -} - -void redis_push_error_rpush_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - /* The reply should be the length of the errors list after our RPUSH. */ - RAY_CHECK(reply->type == REDIS_REPLY_INTEGER); - destroy_timer_callback(db->loop, callback_data); -} - -void redis_push_error_hmset_callback(redisAsyncContext *c, - void *r, - void *privdata) { - REDIS_CALLBACK_HEADER(db, callback_data, r); - redisReply *reply = (redisReply *) r; - - /* Make sure we were able to add the error information. */ - RAY_CHECK(reply->type != REDIS_REPLY_ERROR) << "reply->str is " << reply->str; - RAY_CHECK(strcmp(reply->str, "OK") == 0) << "reply->str is " << reply->str; - - /* Add the error to this driver's list of errors. */ - ErrorInfo *info = (ErrorInfo *) callback_data->data->Get(); - int status = redisAsyncCommand( - db->context, redis_push_error_rpush_callback, - (void *) callback_data->timer_id, "RPUSH ErrorKeys Error:%b:%b", - info->driver_id.data(), sizeof(info->driver_id), info->error_key.data(), - sizeof(info->error_key)); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_push_error rpush"); - } -} - -void redis_push_error(TableCallbackData *callback_data) { - DBHandle *db = callback_data->db_handle; - ErrorInfo *info = (ErrorInfo *) callback_data->data->Get(); - RAY_CHECK(info->error_type < ErrorIndex::MAX && - info->error_type >= ErrorIndex::OBJECT_HASH_MISMATCH); - /// Look up the error type. - const char *error_type = error_types[static_cast(info->error_type)]; - - /* Set the error information. 
*/ - int status = redisAsyncCommand( - db->context, redis_push_error_hmset_callback, - (void *) callback_data->timer_id, - "HMSET Error:%b:%b type %s message %b data %b", info->driver_id.data(), - sizeof(info->driver_id), info->error_key.data(), sizeof(info->error_key), - error_type, info->error_message, info->size, "None", strlen("None")); - if ((status == REDIS_ERR) || db->context->err) { - LOG_REDIS_DEBUG(db->context, "error in redis_push_error hmset"); - } -} - -DBClientID get_db_client_id(DBHandle *db) { - RAY_CHECK(db != NULL); - return db->client; -} diff --git a/src/common/state/redis.h b/src/common/state/redis.h deleted file mode 100644 index 164069740d3e..000000000000 --- a/src/common/state/redis.h +++ /dev/null @@ -1,356 +0,0 @@ -#ifndef REDIS_H -#define REDIS_H - -#include - -#include "db.h" -#include "db_client_table.h" -#include "object_table.h" -#include "task_table.h" - -#include "hiredis/hiredis.h" -#include "hiredis/async.h" - -#define LOG_REDIS_ERROR(context, M, ...) \ - RAY_LOG(ERROR) << "Redis error " << context->err << " " << context->errstr \ - << "; " << M - -#define LOG_REDIS_DEBUG(context, M, ...) \ - RAY_LOG(DEBUG) << "Redis error " << context->err << " " << context->errstr \ - << "; " << M; - -struct DBHandle { - /** String that identifies this client type. */ - char *client_type; - /** Unique ID for this client. */ - DBClientID client; - /** Primary redis context for all non-subscribe connections. This is used for - * the database client table, heartbeats, and errors that should be pushed to - * the driver. */ - redisAsyncContext *context; - /** Primary redis context for "subscribe" communication. A separate context - * is needed for this communication (see - * https://github.com/redis/hiredis/issues/55). This is used for the - * database client table, heartbeats, and errors that should be pushed to - * the driver. */ - redisAsyncContext *subscribe_context; - /** Redis contexts for shards for all non-subscribe connections. All requests - * to the object table, task table, and event table should be directed here. - * The correct shard can be retrieved using get_redis_context below. */ - std::vector contexts; - /** Redis contexts for shards for "subscribe" communication. All requests - * to the object table, task table, and event table should be directed here. - * The correct shard can be retrieved using get_redis_context below. */ - std::vector subscribe_contexts; - /** The event loop this global state store connection is part of. */ - event_loop *loop; - /** Index of the database connection in the event loop */ - int64_t db_index; - /** Cache for the IP addresses of db clients. This is an unordered map mapping - * client IDs to addresses. */ - std::unordered_map db_client_cache; - /** Redis context for synchronous connections. This should only be used very - * rarely, it is not asynchronous. */ - redisContext *sync_context; -}; - -/** - * Get the Redis asynchronous context responsible for non-subscription - * communication for the given UniqueID. - * - * @param db The database handle. - * @param id The ID whose location we are querying for. - * @return The redisAsyncContext responsible for the given ID. - */ -redisAsyncContext *get_redis_context(DBHandle *db, UniqueID id); - -/** - * Get the Redis asynchronous context responsible for subscription - * communication for the given UniqueID. - * - * @param db The database handle. - * @param id The ID whose location we are querying for. 
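``get_redis_context`` and ``get_redis_subscribe_context`` route each ID to one of the shard contexts in ``DBHandle``. The actual hash lives in ``redis.cc``; the sketch below only illustrates the contract, using a hypothetical FNV-1a hash and a hypothetical ``pick_shard`` helper, on the assumption that every client must map the same ID to the same shard:

.. code-block:: cpp

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    template <typename Context>
    Context *pick_shard(const std::vector<Context *> &contexts,
                        const uint8_t *id, size_t id_len) {
      // Any deterministic function of the ID works, as long as all clients
      // agree on it. Assumes contexts is non-empty.
      uint64_t h = 14695981039346656037ULL;  // FNV-1a offset basis.
      for (size_t i = 0; i < id_len; ++i) {
        h = (h ^ id[i]) * 1099511628211ULL;  // FNV-1a prime.
      }
      return contexts[h % contexts.size()];
    }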
- * @return The redisAsyncContext responsible for the given ID. - */ -redisAsyncContext *get_redis_subscribe_context(DBHandle *db, UniqueID id); - -/** - * Get a list of Redis shard IP addresses from the primary shard. - * - * @param context A Redis context connected to the primary shard. - * @param db_shards_addresses The IP addresses for the shards registered - * with the primary shard will be added to this vector. - * @param db_shards_ports The IP ports for the shards registered with the - * primary shard will be added to this vector, in the same order as - * db_shards_addresses. - */ -void get_redis_shards(redisContext *context, - std::vector &db_shards_addresses, - std::vector &db_shards_ports); - -void redis_cache_set_db_client(DBHandle *db, DBClient client); - -DBClient redis_cache_get_db_client(DBHandle *db, DBClientID db_client_id); - -void redis_object_table_get_entry(redisAsyncContext *c, - void *r, - void *privdata); - -void object_table_lookup_callback(redisAsyncContext *c, - void *r, - void *privdata); - -/* - * ==== Redis object table functions ==== - */ - -/** - * Lookup object table entry in redis. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_lookup(TableCallbackData *callback_data); - -/** - * Add a location entry to the object table in redis. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_add(TableCallbackData *callback_data); - -/** - * Remove a location entry from the object table in redis. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_remove(TableCallbackData *callback_data); - -/** - * Create a client-specific channel for receiving notifications from the object - * table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_subscribe_to_notifications( - TableCallbackData *callback_data); - -/** - * Request notifications about when certain objects become available. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_object_table_request_notifications(TableCallbackData *callback_data); - -/** - * Add a new object to the object table in redis. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_result_table_add(TableCallbackData *callback_data); - -/** - * Lookup the task that created the object in redis. The result is the task ID. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_result_table_lookup(TableCallbackData *callback_data); - -/** - * Callback invoked when the reply from the object table lookup command is - * received. - * - * @param c Redis context. - * @param r Reply. - * @param privdata Data associated to the callback. - * @return Void. - */ -void redis_object_table_lookup_callback(redisAsyncContext *c, - void *r, - void *privdata); - -/* - * ==== Redis task table function ===== - */ - -/** - * Get a task table entry, including the task spec and the task's scheduling - * information. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. 
- */ -void redis_task_table_get_task(TableCallbackData *callback_data); - -/** - * Add a task table entry with a new task spec and the task's scheduling - * information. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_task_table_add_task(TableCallbackData *callback_data); - -/** - * Update a task table entry with the task's scheduling information. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_task_table_update(TableCallbackData *callback_data); - -/** - * Update a task table entry with the task's scheduling information, if the - * task's current scheduling information matches the test value. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_task_table_test_and_update(TableCallbackData *callback_data); - -/** - * Callback invoked when the reply from the task push command is received. - * - * @param c Redis context. - * @param r Reply (not used). - * @param privdata Data associated to the callback. - * @return Void. - */ -void redis_task_table_publish_push_callback(redisAsyncContext *c, - void *r, - void *privdata); - -/** - * Callback invoked when the reply from the task publish command is received. - * - * @param c Redis context. - * @param r Reply (not used). - * @param privdata Data associated to the callback. - * @return Void. - */ -void redis_task_table_publish_publish_callback(redisAsyncContext *c, - void *r, - void *privdata); - -/** - * Subscribe to updates of the task table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_task_table_subscribe(TableCallbackData *callback_data); - -/** - * Remove a client from the db clients table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_db_client_table_remove(TableCallbackData *callback_data); - -/** - * Subscribe to updates from the db client table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_db_client_table_subscribe(TableCallbackData *callback_data); - -/** - * Subscribe to updates from the local scheduler table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data); - -/** - * Publish an update to the local scheduler table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_local_scheduler_table_send_info(TableCallbackData *callback_data); - -/** - * Synchronously publish a null update to the local scheduler table signifying - * that we are about to exit. - * - * @param db The database handle of the dying local scheduler. - * @return Void. - */ -void redis_local_scheduler_table_disconnect(DBHandle *db); - -/** - * Subscribe to updates from the driver table. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_driver_table_subscribe(TableCallbackData *callback_data); - -/** - * Publish an update to the driver table. 
- * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_driver_table_send_driver_death(TableCallbackData *callback_data); - -void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data); - -/** - * Marks an actor as removed. This prevents the actor from being resurrected. - * - * @param db The database handle. - * @param actor_id The actor id to mark as removed. - * @return Void. - */ -void redis_actor_table_mark_removed(DBHandle *db, ActorID actor_id); - -/// Publish an actor creation notification. -/// -/// \param callback_data Data structure containing redis connection and timeout -/// information. -/// \return Void. -void redis_publish_actor_creation_notification( - TableCallbackData *callback_data); - -/** - * Subscribe to updates about newly created actors. - * - * @param callback_data Data structure containing redis connection and timeout - * information. - * @return Void. - */ -void redis_actor_notification_table_subscribe(TableCallbackData *callback_data); - -void redis_object_info_subscribe(TableCallbackData *callback_data); - -void redis_push_error(TableCallbackData *callback_data); - -#endif /* REDIS_H */ diff --git a/src/common/state/table.cc b/src/common/state/table.cc deleted file mode 100644 index 8269c2b1e739..000000000000 --- a/src/common/state/table.cc +++ /dev/null @@ -1,200 +0,0 @@ -#include "table.h" - -#include -#include -#include "redis.h" - -BaseCallbackData::BaseCallbackData(void *data) { - data_ = data; -} - -BaseCallbackData::~BaseCallbackData(void) {} - -void *BaseCallbackData::Get(void) { - return data_; -} - -CommonCallbackData::CommonCallbackData(void *data) : BaseCallbackData(data) {} - -CommonCallbackData::~CommonCallbackData(void) { - free(data_); -} - -TaskCallbackData::TaskCallbackData(Task *task_data) - : BaseCallbackData(task_data) {} - -TaskCallbackData::~TaskCallbackData(void) { - Task *task = (Task *) data_; - Task_free(task); -} - -/* The default behavior is to retry every ten seconds forever. */ -static const RetryInfo default_retry = {.num_retries = -1, - .timeout = 10000, - .fail_callback = NULL}; - -static int64_t callback_data_id = 0; - -TableCallbackData *init_table_callback(DBHandle *db_handle, - UniqueID id, - const char *label, - OWNER BaseCallbackData *data, - RetryInfo *retry, - table_done_callback done_callback, - table_retry_callback retry_callback, - void *user_context) { - RAY_CHECK(db_handle); - RAY_CHECK(db_handle->loop); - RAY_CHECK(data); - /* If no retry info is provided, use the default retry info. */ - if (retry == NULL) { - retry = (RetryInfo *) &default_retry; - } - RAY_CHECK(retry); - /* Allocate and initialize callback data structure for object table */ - TableCallbackData *callback_data = - (TableCallbackData *) malloc(sizeof(TableCallbackData)); - RAY_CHECK(callback_data != NULL) << "Memory allocation error!"; - callback_data->id = id; - callback_data->label = label; - callback_data->retry = *retry; - callback_data->done_callback = done_callback; - callback_data->retry_callback = retry_callback; - callback_data->data = data; - callback_data->requests_info = NULL; - callback_data->user_context = user_context; - callback_data->db_handle = db_handle; - /* TODO(ekl) set a retry timer once we've figured out the retry conditions - * and have a solution to the O(n^2) ae timers issue. For now, use a dummy - * timer id to uniquely id this callback. 
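The default retry policy above (``num_retries = -1``, ``timeout = 10000``) means "retry every ten seconds forever", and the ``-1`` sentinel is easy to misread. A standalone sketch of just the retry arithmetic performed by ``table_timeout_handler`` below, with ``0`` standing in for ``EVENT_LOOP_TIMER_DONE``:

.. code-block:: cpp

    #include <cstdint>

    struct Retry {
      int num_retries;      // -1 means retry forever.
      uint64_t timeout_ms;  // Delay before the next retry.
    };

    // Returns the next timer delay, or 0 once retries are exhausted (the
    // real handler returns EVENT_LOOP_TIMER_DONE and destroys the callback).
    uint64_t on_timeout(Retry &retry) {
      if (retry.num_retries == 0) {
        return 0;  // Give up: fire fail_callback, free the callback data.
      }
      if (retry.num_retries != -1) {
        retry.num_retries--;  // Finite budget: consume one retry.
      }
      return retry.timeout_ms;  // Re-arm the timer and reissue the command.
    }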
*/ - callback_data->timer_id = callback_data_id++; - outstanding_callbacks_add(callback_data); - - RAY_LOG(DEBUG) << "Initializing table command " << callback_data->label - << " with timer ID " << callback_data->timer_id; - callback_data->retry_callback(callback_data); - - return callback_data; -} - -void destroy_timer_callback(event_loop *loop, - TableCallbackData *callback_data) { - /* This is commented out because we no longer add timers to the event loop for - * each Redis command. */ - // event_loop_remove_timer(loop, callback_data->timer_id); - destroy_table_callback(callback_data); -} - -void remove_timer_callback(event_loop *loop, TableCallbackData *callback_data) { - /* This is commented out because we no longer add timers to the event loop for - * each Redis command. */ - // event_loop_remove_timer(loop, callback_data->timer_id); -} - -void destroy_table_callback(TableCallbackData *callback_data) { - RAY_CHECK(callback_data != NULL); - - if (callback_data->requests_info) - free(callback_data->requests_info); - - RAY_CHECK(callback_data->data != NULL); - delete callback_data->data; - callback_data->data = NULL; - - outstanding_callbacks_remove(callback_data); - - /* Timer is removed via EVENT_LOOP_TIMER_DONE in the timeout callback. */ - free(callback_data); -} - -int64_t table_timeout_handler(event_loop *loop, - int64_t timer_id, - void *user_context) { - RAY_CHECK(loop != NULL); - RAY_CHECK(user_context != NULL); - TableCallbackData *callback_data = (TableCallbackData *) user_context; - - RAY_CHECK(callback_data->retry.num_retries >= 0 || - callback_data->retry.num_retries == -1); - RAY_LOG(WARNING) << "retrying operation " << callback_data->label - << ", retry_count = " << callback_data->retry.num_retries; - - if (callback_data->retry.num_retries == 0) { - /* We didn't get a response from the database after exhausting all retries; - * let user know, cleanup the state, and remove the timer. */ - RAY_LOG(WARNING) << "Table command " << callback_data->label - << " with timer ID " << timer_id << " failed"; - if (callback_data->retry.fail_callback) { - callback_data->retry.fail_callback(callback_data->id, - callback_data->user_context, - callback_data->data->Get()); - } - destroy_table_callback(callback_data); - return EVENT_LOOP_TIMER_DONE; - } - - /* Decrement retry count and try again. We use -1 to indicate infinite - * retries. */ - if (callback_data->retry.num_retries != -1) { - callback_data->retry.num_retries--; - } - callback_data->retry_callback(callback_data); - return callback_data->retry.timeout; -} - -/** - * Unordered map maintaining the outstanding callbacks. - * - * This unordered map is used to handle the following case: - * - a table command is issued with an associated callback and a callback data - * structure; - * - the last timeout associated to this command expires, as a result the - * callback data structure is freed; - * - a reply arrives, but now the callback data structure is gone, so we have - * to ignore this reply; - * - * This unordered map enables us to ignore such replies. The operations on the - * unordered map are as follows. - * - * When we issue a table command and a timeout event to wait for the reply, we - * add a new entry to the unordered map that is keyed by the ID of the timer. - * Note that table commands must have unique timer IDs, which are assigned by - * the Redis ae event loop. - * - * When we receive the reply, we check whether the callback still exists in - * this unordered map, and if not we just ignore the reply. 
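Concretely, that "ignore the reply" rule means every reply handler treats its timer ID as a weak reference. A sketch of the guard, in terms of ``outstanding_callbacks_find`` defined just below (the handler name is illustrative, and the sketch assumes the declarations from this file):

.. code-block:: cpp

    // A late reply whose command already exhausted its retries finds no
    // entry in the map and returns early instead of touching freed memory.
    void on_reply_for_timer(int64_t timer_id) {
      TableCallbackData *callback_data = outstanding_callbacks_find(timer_id);
      if (callback_data == NULL) {
        return;  // Command already timed out; nothing left to notify.
      }
      // ... invoke done_callback, then destroy_timer_callback(...) ...
    }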
If the callback does
- * exist, the reply receiver is responsible for removing the timer and the
- * entry associated to the callback, or else the timeout handler will continue
- * firing.
- *
- * When the last timeout associated to the command expires we remove the entry
- * associated to the callback.
- */
-static std::unordered_map<int64_t, TableCallbackData *> outstanding_callbacks;
-
-void outstanding_callbacks_add(TableCallbackData *callback_data) {
-  outstanding_callbacks[callback_data->timer_id] = callback_data;
-}
-
-TableCallbackData *outstanding_callbacks_find(int64_t key) {
-  auto it = outstanding_callbacks.find(key);
-  if (it != outstanding_callbacks.end()) {
-    return it->second;
-  }
-  return NULL;
-}
-
-void outstanding_callbacks_remove(TableCallbackData *callback_data) {
-  outstanding_callbacks.erase(callback_data->timer_id);
-}
-
-void destroy_outstanding_callbacks(event_loop *loop) {
-  /* We have to be careful because destroy_timer_callback modifies
-   * outstanding_callbacks in place. */
-  auto it = outstanding_callbacks.begin();
-  while (it != outstanding_callbacks.end()) {
-    auto next_it = std::next(it, 1);
-    destroy_timer_callback(loop, it->second);
-    it = next_it;
-  }
-}
diff --git a/src/common/state/table.h b/src/common/state/table.h
deleted file mode 100644
index 1fadcf339cef..000000000000
--- a/src/common/state/table.h
+++ /dev/null
@@ -1,216 +0,0 @@
-#ifndef TABLE_H
-#define TABLE_H
-
-#include "common.h"
-#include "task.h"
-#include "db.h"
-
-typedef struct TableCallbackData TableCallbackData;
-
-/* An abstract class for any data passed by the user into a table operation.
- * This class wraps arbitrary pointers and allows the caller to define a custom
- * destructor, for data that is not allocated with malloc. */
-class BaseCallbackData {
- public:
-  BaseCallbackData(void *data);
-  virtual ~BaseCallbackData(void) = 0;
-
-  /* Return the pointer to the data. */
-  void *Get(void);
-
- protected:
-  /* The pointer to the data. */
-  void *data_;
-};
-
-/* A common class for malloc'ed data passed by the user into a table operation.
- * This should ONLY be used when only a free is necessary. */
-class CommonCallbackData : public BaseCallbackData {
- public:
-  CommonCallbackData(void *data);
-  ~CommonCallbackData(void);
-};
-
-/* A class for Task data passed by the user into a table operation. This calls
- * task cleanup in the destructor. */
-class TaskCallbackData : public BaseCallbackData {
- public:
-  TaskCallbackData(Task *task_data);
-  ~TaskCallbackData(void);
-};
-
-typedef void *table_done_callback;
-
-/* The callback called when the database operation hasn't completed after
- * the number of retries specified for the operation.
- *
- * @param id The unique ID that identifies this callback. Examples include an
- *        object ID or task ID.
- * @param user_context The state context for the callback. This is equivalent
- *        to the user_context field in TableCallbackData.
- * @param user_data A data argument for the callback. This is equivalent to the
- *        data field in TableCallbackData. The user is responsible for
- *        freeing user_data.
- */
-typedef void (*table_fail_callback)(UniqueID id,
-                                    void *user_context,
-                                    void *user_data);
-
-typedef void (*table_retry_callback)(TableCallbackData *callback_data);
-
-/**
- * Data structure consolidating the retry related variables. If a NULL
- * RetryInfo struct is used, the default behavior will be to retry infinitely
- * many times.
- */
-typedef struct {
-  /** Number of retries.
This field will be decremented every time a retry - * occurs (unless the value is -1). If this value is -1, then there will be - * infinitely many retries. */ - int num_retries; - /** Timeout, in milliseconds. */ - uint64_t timeout; - /** The callback that will be called if there are no more retries left. */ - table_fail_callback fail_callback; -} RetryInfo; - -struct TableCallbackData { - /** ID of the entry in the table that we are going to look up, remove or add. - */ - UniqueID id; - /** A label to identify the original request for logging purposes. */ - const char *label; - /** The callback that will be called when results is returned. */ - table_done_callback done_callback; - /** The callback that will be called to initiate the next try. */ - table_retry_callback retry_callback; - /** Retry information containing the remaining number of retries, the timeout - * before the next retry, and a pointer to the failure callback. - */ - RetryInfo retry; - /** Pointer to the data that is entered into the table. This can be used to - * pass the result of the call to the callback. The callback takes ownership - * over this data and will free it. */ - BaseCallbackData *data; - /** Pointer to the data used internally to handle multiple database requests. - */ - void *requests_info; - /** User context. */ - void *user_context; - /** Handle to db. */ - DBHandle *db_handle; - /** Handle to timer. */ - int64_t timer_id; -}; - -/** - * Function to handle the timeout event. - * - * @param loop Event loop. - * @param timer_id Timer identifier. - * @param context Pointer to the callback data for the object table - * @return Timeout to reset the timer if we need to try again, or - * EVENT_LOOP_TIMER_DONE if retry_count == 0. - */ -int64_t table_timeout_handler(event_loop *loop, - int64_t timer_id, - void *context); - -/** - * Initialize the table callback and call the retry_callback for the first time. - * - * @param db_handle Database handle. - * @param id ID of the object that is looked up, added or removed. - * @param label A string label to identify the type of table request for - * logging purposes. - * @param data Data entered into the table. Shall be freed by the user. Caller - * must specify a destructor by wrapping a void *pointer in a - * BaseCallbackData class. - * @param retry Retry relevant information: retry timeout, number of remaining - * retries, and retry callback. - * @param done_callback Function to be called when database returns result. - * @param fail_callback Function to be called when number of retries is - * exhausted. - * @param user_context Context that can be provided by the user and will be - * passed on to the various callbacks. - * @return New table callback data struct. - */ -TableCallbackData *init_table_callback(DBHandle *db_handle, - UniqueID id, - const char *label, - OWNER BaseCallbackData *data, - RetryInfo *retry, - table_done_callback done_callback, - table_retry_callback retry_callback, - void *user_context); - -/** - * Destroy any state associated with the callback data. This removes all - * associated state from the outstanding callbacks unordered map and frees any - * associated memory. This does not remove any associated timer events. - * - * @param callback_data The pointer to the data structure of the callback we - * want to remove. - * @return Void. - */ -void destroy_table_callback(TableCallbackData *callback_data); - -/** - * Destroy all state events associated with the callback data, including memory - * and timer events. 
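These lifetime functions compose into the teardown order that the tests later in this patch use; a compressed sketch (the connection arguments mirror those tests):

.. code-block:: cpp

    void my_shutdown_example(void) {
      event_loop *loop = event_loop_create();
      DBHandle *db =
          db_connect(std::string("127.0.0.1"), 6379, "local_scheduler",
                     "127.0.0.1", std::vector<std::string>());
      db_attach(db, loop, false);
      /* ... issue table commands, event_loop_run(loop) ... */
      db_disconnect(db);
      destroy_outstanding_callbacks(loop); /* free callbacks awaiting replies */
      event_loop_destroy(loop);
    }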
- * - * @param loop The event loop. - * @param callback_data The pointer to the data structure of the callback we - * want to remove. - * @return Void. - */ -void destroy_timer_callback(event_loop *loop, TableCallbackData *callback_data); - -/** - * Remove the callback timer without destroying the callback data. - * - * @param loop The event loop. - * @param callback_data The pointer to the data structure of the callback. - * @return Void. - */ -void remove_timer_callback(event_loop *loop, TableCallbackData *callback_data); - -/** - * Add an outstanding callback entry. - * - * @param callback_data The pointer to the data structure of the callback we - * want to insert. - * @return None. - */ -void outstanding_callbacks_add(TableCallbackData *callback_data); - -/** - * Find an outstanding callback entry. - * - * @param key The key for the outstanding callbacks unordered map. We use the - * timer ID assigned by the Redis ae event loop. - * @return Returns the callback data if found, NULL otherwise. - */ -TableCallbackData *outstanding_callbacks_find(int64_t key); - -/** - * Remove an outstanding callback entry. This only removes the callback entry - * from the unordered map. It does not free the entry or remove any associated - * timer events. - * - * @param callback_data The pointer to the data structure of the callback we - * want to remove. - * @return Void. - */ -void outstanding_callbacks_remove(TableCallbackData *callback_data); - -/** - * Destroy all outstanding callbacks and remove their associated timer events - * from the event loop. - * - * @param loop The event loop from which we want to remove the timer events. - * @return Void. - */ -void destroy_outstanding_callbacks(event_loop *loop); - -#endif /* TABLE_H */ diff --git a/src/common/state/task_table.cc b/src/common/state/task_table.cc deleted file mode 100644 index 514350b08353..000000000000 --- a/src/common/state/task_table.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "task_table.h" -#include "redis.h" - -#define NUM_DB_REQUESTS 2 - -void task_table_get_task(DBHandle *db_handle, - TaskID task_id, - RetryInfo *retry, - task_table_get_callback get_callback, - void *user_context) { - init_table_callback( - db_handle, task_id, __func__, new CommonCallbackData(NULL), retry, - (void *) get_callback, redis_task_table_get_task, user_context); -} - -void task_table_add_task(DBHandle *db_handle, - OWNER Task *task, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context) { - init_table_callback(db_handle, Task_task_id(task), __func__, - new TaskCallbackData(task), retry, - (table_done_callback) done_callback, - redis_task_table_add_task, user_context); -} - -void task_table_update(DBHandle *db_handle, - OWNER Task *task, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context) { - init_table_callback(db_handle, Task_task_id(task), __func__, - new TaskCallbackData(task), retry, - (table_done_callback) done_callback, - redis_task_table_update, user_context); -} - -void task_table_test_and_update( - DBHandle *db_handle, - TaskID task_id, - DBClientID test_local_scheduler_id, - TaskStatus test_state_bitmask, - TaskStatus update_state, - RetryInfo *retry, - task_table_test_and_update_callback done_callback, - void *user_context) { - TaskTableTestAndUpdateData *update_data = - (TaskTableTestAndUpdateData *) malloc(sizeof(TaskTableTestAndUpdateData)); - update_data->test_local_scheduler_id = test_local_scheduler_id; - update_data->test_state_bitmask = test_state_bitmask; - 
update_data->update_state = update_state; - /* Update the task entry's local scheduler with this client's ID. */ - update_data->local_scheduler_id = db_handle->client; - init_table_callback(db_handle, task_id, __func__, - new CommonCallbackData(update_data), retry, - (table_done_callback) done_callback, - redis_task_table_test_and_update, user_context); -} - -/* TODO(swang): A corresponding task_table_unsubscribe. */ -void task_table_subscribe(DBHandle *db_handle, - DBClientID local_scheduler_id, - TaskStatus state_filter, - task_table_subscribe_callback subscribe_callback, - void *subscribe_context, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context) { - TaskTableSubscribeData *sub_data = - (TaskTableSubscribeData *) malloc(sizeof(TaskTableSubscribeData)); - sub_data->local_scheduler_id = local_scheduler_id; - sub_data->state_filter = state_filter; - sub_data->subscribe_callback = subscribe_callback; - sub_data->subscribe_context = subscribe_context; - - init_table_callback(db_handle, local_scheduler_id, __func__, - new CommonCallbackData(sub_data), retry, - (table_done_callback) done_callback, - redis_task_table_subscribe, user_context); -} diff --git a/src/common/state/task_table.h b/src/common/state/task_table.h deleted file mode 100644 index 3884ddece893..000000000000 --- a/src/common/state/task_table.h +++ /dev/null @@ -1,190 +0,0 @@ -#ifndef task_table_H -#define task_table_H - -#include "db.h" -#include "table.h" -#include "task.h" - -/** - * The task table is a message bus that is used for communication between local - * and global schedulers (and also persisted to the state database). Here are - * examples of events that are recorded by the task table: - * - * 1) Local schedulers write to it when submitting a task to the global - * scheduler. - * 2) The global scheduler subscribes to updates to the task table to get tasks - * submitted by local schedulers. - * 3) The global scheduler writes to it when assigning a task to a local - * scheduler. - * 4) Local schedulers subscribe to updates to the task table to get tasks - * assigned to them by the global scheduler. - * 5) Local schedulers write to it when a task finishes execution. - */ - -/* Callback called when a task table write operation completes. */ -typedef void (*task_table_done_callback)(TaskID task_id, void *user_context); - -/* Callback called when a task table read operation completes. If the task ID - * was not in the task table, then the task pointer will be NULL. */ -typedef void (*task_table_get_callback)(Task *task, void *user_context); - -/* Callback called when a task table test-and-update operation completes. If - * the task ID was not in the task table, then the task pointer will be NULL. - * If the update succeeded, the updated field will be set to true. */ -typedef void (*task_table_test_and_update_callback)(Task *task, - void *user_context, - bool updated); - -/** - * Get a task's entry from the task table. - * - * @param db_handle Database handle. - * @param task_id The ID of the task we want to look up. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. 
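To make the message-bus description above concrete: a local scheduler's side of event 4 would look roughly like this sketch (the callback body and retry constants are placeholders; the call shape matches the declarations in this header):

.. code-block:: cpp

    void my_on_assigned(Task *task, void *user_context) {
      /* The global scheduler assigned us a task; queue it locally. */
    }

    void my_listen(DBHandle *db_handle, DBClientID my_id) {
      RetryInfo retry = {
          .num_retries = 10, .timeout = 100, .fail_callback = NULL};
      /* Event 4: hear about tasks the global scheduler assigns to us. */
      task_table_subscribe(db_handle, my_id, TaskStatus::SCHEDULED,
                           my_on_assigned, NULL, &retry, NULL, NULL);
    }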
- */ -void task_table_get_task(DBHandle *db, - TaskID task_id, - RetryInfo *retry, - task_table_get_callback get_callback, - void *user_context); - -/** - * Add a task entry, including task spec and scheduling information, to the task - * table. This will overwrite any task already in the task table with the same - * task ID. - * - * @param db_handle Database handle. - * @param task The task entry to add to the table. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - */ -void task_table_add_task(DBHandle *db_handle, - OWNER Task *task, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context); - -/* - * ==== Publish the task table ==== - */ - -/** - * Update a task's scheduling information in the task table. This assumes that - * the task spec already exists in the task table entry. - * - * @param db_handle Database handle. - * @param task The task entry to add to the table. The task spec in the entry is - * ignored. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - */ -void task_table_update(DBHandle *db_handle, - OWNER Task *task, - RetryInfo *retry, - task_table_done_callback done_callback, - void *user_context); - -/** - * Update a task's scheduling information in the task table, if the current - * value matches the given test value. If the update succeeds, it also updates - * the task entry's local scheduler ID with the ID of the client who called - * this function. This assumes that the task spec already exists in the task - * table entry. - * - * @param db_handle Database handle. - * @param task_id The task ID of the task entry to update. - * @param test_local_scheduler_id The local scheduler ID to test the current - * local scheduler ID against. If not NIL_ID, and if the current local - * scheduler ID does not match it, then the update will not happen. - * @param test_state_bitmask The bitmask to apply to the task entry's current - * scheduling state. The update happens if and only if the current - * scheduling state AND-ed with the bitmask is greater than 0 and the - * local scheduler ID test passes. - * @param update_state The value to update the task entry's scheduling state - * with, if the current state matches test_state_bitmask. - * @param retry Information about retrying the request to the database. - * @param done_callback Function to be called when database returns result. - * @param user_context Data that will be passed to done_callback and - * fail_callback. - * @return Void. - */ -void task_table_test_and_update( - DBHandle *db_handle, - TaskID task_id, - DBClientID test_local_scheduler_id, - TaskStatus test_state_bitmask, - TaskStatus update_state, - RetryInfo *retry, - task_table_test_and_update_callback done_callback, - void *user_context); - -/* Data that is needed to test and set the task's scheduling state. */ -typedef struct { - /** The value to test the current local scheduler ID against. This field is - * ignored if equal to NIL_ID. 
*/
-  DBClientID test_local_scheduler_id;
-  TaskStatus test_state_bitmask;
-  TaskStatus update_state;
-  DBClientID local_scheduler_id;
-} TaskTableTestAndUpdateData;
-
-/*
- *  ==== Subscribing to the task table ====
- */
-
-/* Callback for subscribing to the task table. */
-typedef void (*task_table_subscribe_callback)(Task *task, void *user_context);
-
-/**
- * Register a callback for a task event. An event is any update of a task in
- * the task table, produced by task_table_add_task or task_table_update.
- * Events include changes to the task's scheduling state or changes to the
- * task's local scheduler ID.
- *
- * @param db_handle Database handle.
- * @param subscribe_callback Callback that will be called when the task table is
- *        updated.
- * @param subscribe_context Context that will be passed into the
- *        subscribe_callback.
- * @param local_scheduler_id The db_client_id of the local scheduler whose
- *        events we want to listen to. If you want to subscribe to updates from
- *        all local schedulers, pass in NIL_ID.
- * @param state_filter Events we want to listen to. Can have values from the
- *        enum "scheduling_state" in task.h.
- *        TODO(pcm): Make it possible to combine these using flags like
- *        TASK_STATUS_WAITING | TASK_STATUS_SCHEDULED.
- * @param retry Information about retrying the request to the database.
- * @param done_callback Function to be called when database returns result.
- * @param user_context Data that will be passed to done_callback and
- *        fail_callback.
- * @return Void.
- */
-void task_table_subscribe(DBHandle *db_handle,
-                          DBClientID local_scheduler_id,
-                          TaskStatus state_filter,
-                          task_table_subscribe_callback subscribe_callback,
-                          void *subscribe_context,
-                          RetryInfo *retry,
-                          task_table_done_callback done_callback,
-                          void *user_context);
-
-/* Data that is needed to register task table subscribe callbacks with the state
- * database. */
-typedef struct {
-  DBClientID local_scheduler_id;
-  TaskStatus state_filter;
-  task_table_subscribe_callback subscribe_callback;
-  void *subscribe_context;
-} TaskTableSubscribeData;
-
-#endif /* task_table_H */
diff --git a/src/common/task.cc b/src/common/task.cc
deleted file mode 100644
index 60110fe22543..000000000000
--- a/src/common/task.cc
+++ /dev/null
@@ -1,606 +0,0 @@
-#include 
-
-#include "common_protocol.h"
-
-#include "task.h"
-
-extern "C" {
-#include "sha256.h"
-}
-
-ObjectID task_compute_return_id(TaskID task_id, int64_t return_index) {
-  /* Here, return_indices need to be >= 0, so we can use negative
-   * indices for put. */
-  RAY_DCHECK(return_index >= 0);
-  /* TODO(rkn): This line requires object and task IDs to be the same size. */
-  ObjectID return_id = task_id;
-  int64_t *first_bytes = (int64_t *) &return_id;
-  /* XOR the first bytes of the object ID with the return index. We add one so
-   * the first return ID is not the same as the task ID. */
-  *first_bytes = *first_bytes ^ (return_index + 1);
-  return return_id;
-}
-
-ObjectID task_compute_put_id(TaskID task_id, int64_t put_index) {
-  RAY_DCHECK(put_index >= 0);
-  /* TODO(pcm): This line requires object and task IDs to be the same size. */
-  ObjectID put_id = task_id;
-  int64_t *first_bytes = (int64_t *) &put_id;
-  /* XOR the first bytes of the object ID with the negated put index. We
-   * subtract one so the first put ID is not the same as the task ID.
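- * For example, with put_index = 0 the mask is -1 (all ones) and with
- * put_index = 1 it is -2, while return IDs use the masks +1, +2, and so
- * on. The positive and negative masks are disjoint and never zero, so put
- * IDs, return IDs, and the task ID itself can never coincide.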
*/ - *first_bytes = *first_bytes ^ (-put_index - 1); - return put_id; -} - -class TaskBuilder { - public: - void Start(UniqueID driver_id, - TaskID parent_task_id, - int64_t parent_counter, - ActorID actor_creation_id, - ObjectID actor_creation_dummy_object_id, - ActorID actor_id, - ActorHandleID actor_handle_id, - int64_t actor_counter, - bool is_actor_checkpoint_method, - FunctionID function_id, - int64_t num_returns) { - driver_id_ = driver_id; - parent_task_id_ = parent_task_id; - parent_counter_ = parent_counter; - actor_creation_id_ = actor_creation_id; - actor_creation_dummy_object_id_ = actor_creation_dummy_object_id; - actor_id_ = actor_id; - actor_handle_id_ = actor_handle_id; - actor_counter_ = actor_counter; - is_actor_checkpoint_method_ = is_actor_checkpoint_method; - function_id_ = function_id; - num_returns_ = num_returns; - - /* Compute hashes. */ - sha256_init(&ctx); - sha256_update(&ctx, (BYTE *) &driver_id, sizeof(driver_id)); - sha256_update(&ctx, (BYTE *) &parent_task_id, sizeof(parent_task_id)); - sha256_update(&ctx, (BYTE *) &parent_counter, sizeof(parent_counter)); - sha256_update(&ctx, (BYTE *) &actor_creation_id, sizeof(actor_creation_id)); - sha256_update(&ctx, (BYTE *) &actor_creation_dummy_object_id, - sizeof(actor_creation_dummy_object_id)); - sha256_update(&ctx, (BYTE *) &actor_id, sizeof(actor_id)); - sha256_update(&ctx, (BYTE *) &actor_counter, sizeof(actor_counter)); - sha256_update(&ctx, (BYTE *) &is_actor_checkpoint_method, - sizeof(is_actor_checkpoint_method)); - sha256_update(&ctx, (BYTE *) &function_id, sizeof(function_id)); - } - - void NextReferenceArgument(ObjectID object_ids[], int num_object_ids) { - args.push_back( - CreateArg(fbb, to_flatbuf(fbb, &object_ids[0], num_object_ids))); - sha256_update(&ctx, (BYTE *) &object_ids[0], - sizeof(object_ids[0]) * num_object_ids); - } - - void NextValueArgument(uint8_t *value, int64_t length) { - auto arg = fbb.CreateString((const char *) value, length); - auto empty_ids = fbb.CreateVectorOfStrings({}); - args.push_back(CreateArg(fbb, empty_ids, arg)); - sha256_update(&ctx, (BYTE *) value, length); - } - - void SetRequiredResource(const std::string &resource_name, double value) { - RAY_CHECK(resource_map_.count(resource_name) == 0); - resource_map_[resource_name] = value; - } - - uint8_t *Finish(int64_t *size) { - /* Add arguments. */ - auto arguments = fbb.CreateVector(args); - /* Update hash. */ - BYTE buff[DIGEST_SIZE]; - sha256_final(&ctx, buff); - TaskID task_id; - RAY_CHECK(sizeof(task_id) <= DIGEST_SIZE); - memcpy(&task_id, buff, sizeof(task_id)); - /* Add return object IDs. */ - std::vector> returns; - for (int64_t i = 0; i < num_returns_; i++) { - ObjectID return_id = task_compute_return_id(task_id, i); - returns.push_back(to_flatbuf(fbb, return_id)); - } - /* Create TaskInfo. */ - auto message = CreateTaskInfo( - fbb, to_flatbuf(fbb, driver_id_), to_flatbuf(fbb, task_id), - to_flatbuf(fbb, parent_task_id_), parent_counter_, - to_flatbuf(fbb, actor_creation_id_), - to_flatbuf(fbb, actor_creation_dummy_object_id_), - to_flatbuf(fbb, actor_id_), to_flatbuf(fbb, actor_handle_id_), - actor_counter_, is_actor_checkpoint_method_, - to_flatbuf(fbb, function_id_), arguments, fbb.CreateVector(returns), - map_to_flatbuf(fbb, resource_map_)); - /* Finish the TaskInfo. 
*/ - fbb.Finish(message); - *size = fbb.GetSize(); - uint8_t *result = (uint8_t *) malloc(*size); - memcpy(result, fbb.GetBufferPointer(), *size); - fbb.Clear(); - args.clear(); - resource_map_.clear(); - return result; - } - - private: - flatbuffers::FlatBufferBuilder fbb; - std::vector> args; - SHA256_CTX ctx; - - /* Data for the builder. */ - UniqueID driver_id_; - TaskID parent_task_id_; - int64_t parent_counter_; - ActorID actor_creation_id_; - ObjectID actor_creation_dummy_object_id_; - ActorID actor_id_; - ActorID actor_handle_id_; - int64_t actor_counter_; - bool is_actor_checkpoint_method_; - FunctionID function_id_; - int64_t num_returns_; - std::unordered_map resource_map_; -}; - -TaskBuilder *make_task_builder(void) { - return new TaskBuilder(); -} - -void free_task_builder(TaskBuilder *builder) { - delete builder; -} - -bool TaskID_equal(TaskID first_id, TaskID second_id) { - return first_id == second_id; -} - -bool TaskID_is_nil(TaskID id) { - return id.is_nil(); -} - -bool ActorID_equal(ActorID first_id, ActorID second_id) { - return first_id == second_id; -} - -bool FunctionID_equal(FunctionID first_id, FunctionID second_id) { - return first_id == second_id; -} - -bool FunctionID_is_nil(FunctionID id) { - return id.is_nil(); -} - -/* Functions for building tasks. */ - -void TaskSpec_start_construct(TaskBuilder *builder, - UniqueID driver_id, - TaskID parent_task_id, - int64_t parent_counter, - ActorID actor_creation_id, - ObjectID actor_creation_dummy_object_id, - ActorID actor_id, - ActorID actor_handle_id, - int64_t actor_counter, - bool is_actor_checkpoint_method, - FunctionID function_id, - int64_t num_returns) { - builder->Start(driver_id, parent_task_id, parent_counter, actor_creation_id, - actor_creation_dummy_object_id, actor_id, actor_handle_id, - actor_counter, is_actor_checkpoint_method, function_id, - num_returns); -} - -TaskSpec *TaskSpec_finish_construct(TaskBuilder *builder, int64_t *size) { - return reinterpret_cast(builder->Finish(size)); -} - -void TaskSpec_args_add_ref(TaskBuilder *builder, - ObjectID object_ids[], - int num_object_ids) { - builder->NextReferenceArgument(&object_ids[0], num_object_ids); -} - -void TaskSpec_args_add_val(TaskBuilder *builder, - uint8_t *value, - int64_t length) { - builder->NextValueArgument(value, length); -} - -void TaskSpec_set_required_resource(TaskBuilder *builder, - const std::string &resource_name, - double value) { - builder->SetRequiredResource(resource_name, value); -} - -/* Functions for reading tasks. 
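Stepping back from ``Finish``, the builder is consumed in a fixed start/add/finish order. A sketch of one pass; the IDs, the ``"CPU"`` resource name, and ``NIL_OBJECT_ID`` are placeholder values, with the nil actor constant being the one the header documentation refers to:

.. code-block:: cpp

    void my_build_task(UniqueID driver_id, TaskID parent_task_id,
                       FunctionID function_id, uint8_t *arg, int64_t arg_len) {
      TaskBuilder *builder = make_task_builder();
      TaskSpec_start_construct(builder, driver_id, parent_task_id,
                               /*parent_counter=*/0, NIL_ACTOR_ID,
                               NIL_OBJECT_ID, NIL_ACTOR_ID, NIL_ACTOR_ID,
                               /*actor_counter=*/0,
                               /*is_actor_checkpoint_method=*/false,
                               function_id, /*num_returns=*/1);
      TaskSpec_args_add_val(builder, arg, arg_len); /* one by-value argument */
      TaskSpec_set_required_resource(builder, "CPU", 1.0);
      /* Finishing computes the task ID and its return object IDs from the
       * accumulated SHA256 hash. */
      int64_t size;
      TaskSpec *spec = TaskSpec_finish_construct(builder, &size);
      /* ... use spec ... */
      TaskSpec_free(spec);
      free_task_builder(builder);
    }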
*/
-
-TaskID TaskSpec_task_id(const TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(*message->task_id());
-}
-
-FunctionID TaskSpec_function(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(*message->function_id());
-}
-
-ActorID TaskSpec_actor_id(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(*message->actor_id());
-}
-
-ActorID TaskSpec_actor_handle_id(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(*message->actor_handle_id());
-}
-
-bool TaskSpec_is_actor_task(TaskSpec *spec) {
-  return !TaskSpec_actor_id(spec).is_nil();
-}
-
-ActorID TaskSpec_actor_creation_id(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(*message->actor_creation_id());
-}
-
-ObjectID TaskSpec_actor_creation_dummy_object_id(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  // The task must be an actor method.
-  RAY_CHECK(TaskSpec_is_actor_task(spec));
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(*message->actor_creation_dummy_object_id());
-}
-
-bool TaskSpec_is_actor_creation_task(TaskSpec *spec) {
-  return !TaskSpec_actor_creation_id(spec).is_nil();
-}
-
-int64_t TaskSpec_actor_counter(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return std::abs(message->actor_counter());
-}
-
-bool TaskSpec_is_actor_checkpoint_method(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return message->is_actor_checkpoint_method();
-}
-
-ObjectID TaskSpec_actor_dummy_object(TaskSpec *spec) {
-  RAY_CHECK(TaskSpec_is_actor_task(spec));
-  /* The last return value for actor tasks is the dummy object that
-   * represents that this task has completed execution. */
-  int64_t num_returns = TaskSpec_num_returns(spec);
-  return TaskSpec_return(spec, num_returns - 1);
-}
-
-UniqueID TaskSpec_driver_id(const TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(*message->driver_id());
-}
-
-TaskID TaskSpec_parent_task_id(const TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(*message->parent_task_id());
-}
-
-int64_t TaskSpec_parent_counter(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return message->parent_counter();
-}
-
-int64_t TaskSpec_num_args(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return message->args()->size();
-}
-
-int64_t TaskSpec_num_args_by_ref(TaskSpec *spec) {
-  int64_t num_args = TaskSpec_num_args(spec);
-  int64_t num_args_by_ref = 0;
-  for (int64_t i = 0; i < num_args; i++) {
-    if (TaskSpec_arg_by_ref(spec, i)) {
-      num_args_by_ref++;
-    }
-  }
-  return num_args_by_ref;
-}
-
-int TaskSpec_arg_id_count(TaskSpec *spec, int64_t arg_index) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  auto ids = message->args()->Get(arg_index)->object_ids();
-  if (ids == nullptr) {
-    return 0;
-  } else {
-    return ids->size();
-  }
-}
-
-ObjectID TaskSpec_arg_id(TaskSpec *spec, int64_t arg_index, int64_t id_index) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(
-      *message->args()->Get(arg_index)->object_ids()->Get(id_index));
-}
-
-const uint8_t *TaskSpec_arg_val(TaskSpec *spec, int64_t arg_index) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return (uint8_t *) message->args()->Get(arg_index)->data()->c_str();
-}
-
-int64_t TaskSpec_arg_length(TaskSpec *spec, int64_t arg_index) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return message->args()->Get(arg_index)->data()->size();
-}
-
-int64_t TaskSpec_num_returns(TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return message->returns()->size();
-}
-
-bool TaskSpec_arg_by_ref(TaskSpec *spec, int64_t arg_index) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return message->args()->Get(arg_index)->object_ids()->size() != 0;
-}
-
-ObjectID TaskSpec_return(TaskSpec *spec, int64_t return_index) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return from_flatbuf(*message->returns()->Get(return_index));
-}
-
-double TaskSpec_get_required_resource(const TaskSpec *spec,
-                                      const std::string &resource_name) {
-  // This is a bit ugly. However it shouldn't be much of a performance issue
-  // because there shouldn't be many distinct resources in a single task spec.
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  for (size_t i = 0; i < message->required_resources()->size(); i++) {
-    const ResourcePair *resource_pair = message->required_resources()->Get(i);
-    if (string_from_flatbuf(*resource_pair->key()) == resource_name) {
-      return resource_pair->value();
-    }
-  }
-  return 0;
-}
-
-const std::unordered_map<std::string, double> TaskSpec_get_required_resources(
-    const TaskSpec *spec) {
-  RAY_CHECK(spec);
-  auto message = flatbuffers::GetRoot<TaskInfo>(spec);
-  return map_from_flatbuf(*message->required_resources());
-}
-
-TaskSpec *TaskSpec_copy(TaskSpec *spec, int64_t task_spec_size) {
-  TaskSpec *copy = (TaskSpec *) malloc(task_spec_size);
-  memcpy(copy, spec, task_spec_size);
-  return copy;
-}
-
-void TaskSpec_free(TaskSpec *spec) {
-  free(spec);
-}
-
-TaskExecutionSpec::TaskExecutionSpec(
-    const std::vector<ObjectID> &execution_dependencies,
-    const TaskSpec *spec,
-    int64_t task_spec_size,
-    int spillback_count)
-    : execution_dependencies_(execution_dependencies),
-      task_spec_size_(task_spec_size),
-      last_timestamp_(0),
-      spillback_count_(spillback_count) {
-  TaskSpec *spec_copy = new TaskSpec[task_spec_size_];
-  memcpy(spec_copy, spec, task_spec_size);
-  spec_ = std::unique_ptr<TaskSpec[]>(spec_copy);
-}
-
-TaskExecutionSpec::TaskExecutionSpec(
-    const std::vector<ObjectID> &execution_dependencies,
-    const TaskSpec *spec,
-    int64_t task_spec_size)
-    : TaskExecutionSpec(execution_dependencies, spec, task_spec_size, 0) {}
-
-TaskExecutionSpec::TaskExecutionSpec(TaskExecutionSpec *other)
-    : execution_dependencies_(other->execution_dependencies_),
-      task_spec_size_(other->task_spec_size_),
-      last_timestamp_(other->last_timestamp_),
-      spillback_count_(other->spillback_count_) {
-  TaskSpec *spec_copy = new TaskSpec[task_spec_size_];
-  memcpy(spec_copy, other->spec_.get(), task_spec_size_);
-  spec_ = std::unique_ptr<TaskSpec[]>(spec_copy);
-}
-
-const std::vector<ObjectID> &TaskExecutionSpec::ExecutionDependencies() const {
-  return execution_dependencies_;
-}
-
-void TaskExecutionSpec::SetExecutionDependencies(
-    const std::vector<ObjectID> &dependencies) {
-  execution_dependencies_ = dependencies;
-}
-
-int64_t TaskExecutionSpec::SpecSize() const {
-  return task_spec_size_;
-}
-
-int TaskExecutionSpec::SpillbackCount() const {
-  return spillback_count_;
-}
-
-void TaskExecutionSpec::IncrementSpillbackCount() {
-  ++spillback_count_;
-}
-
-int64_t TaskExecutionSpec::LastTimeStamp() const {
-  return last_timestamp_;
-}
-
-void TaskExecutionSpec::SetLastTimeStamp(int64_t new_timestamp) {
-  last_timestamp_ = new_timestamp;
-}
-
-TaskSpec *TaskExecutionSpec::Spec() const {
-  return spec_.get();
-}
-
-int64_t TaskExecutionSpec::NumDependencies() const {
-  TaskSpec *spec = Spec();
-  int64_t num_dependencies = TaskSpec_num_args(spec);
-  num_dependencies += execution_dependencies_.size();
-  return num_dependencies;
-}
-
-int TaskExecutionSpec::DependencyIdCount(int64_t dependency_index) const {
-  TaskSpec *spec = Spec();
-  /* The first dependencies are the arguments of the task itself, followed by
-   * the execution dependencies. Find the total number of task arguments so
-   * that we can index into the correct list. */
-  int64_t num_args = TaskSpec_num_args(spec);
-  if (dependency_index < num_args) {
-    /* Index into the task arguments. */
-    return TaskSpec_arg_id_count(spec, dependency_index);
-  } else {
-    /* Index into the execution dependencies. */
-    dependency_index -= num_args;
-    RAY_CHECK((size_t) dependency_index < execution_dependencies_.size());
-    /* All elements in the execution dependency list have exactly one ID.
*/ - return 1; - } -} - -ObjectID TaskExecutionSpec::DependencyId(int64_t dependency_index, - int64_t id_index) const { - TaskSpec *spec = Spec(); - /* The first dependencies are the arguments of the task itself, followed by - * the execution dependencies. Find the total number of task arguments so - * that we can index into the correct list. */ - int64_t num_args = TaskSpec_num_args(spec); - if (dependency_index < num_args) { - /* Index into the task arguments. */ - return TaskSpec_arg_id(spec, dependency_index, id_index); - } else { - /* Index into the execution dependencies. */ - dependency_index -= num_args; - RAY_CHECK((size_t) dependency_index < execution_dependencies_.size()); - return execution_dependencies_[dependency_index]; - } -} - -bool TaskExecutionSpec::DependsOn(ObjectID object_id) const { - // Iterate through the task arguments to see if it contains object_id. - TaskSpec *spec = Spec(); - int64_t num_args = TaskSpec_num_args(spec); - for (int i = 0; i < num_args; ++i) { - int count = TaskSpec_arg_id_count(spec, i); - for (int j = 0; j < count; j++) { - ObjectID arg_id = TaskSpec_arg_id(spec, i, j); - if (arg_id == object_id) { - return true; - } - } - } - // Iterate through the execution dependencies to see if it contains object_id. - for (auto dependency_id : execution_dependencies_) { - if (dependency_id == object_id) { - return true; - } - } - // The requested object ID was not a task argument or an execution dependency. - // This task is not dependent on it. - return false; -} - -bool TaskExecutionSpec::IsStaticDependency(int64_t dependency_index) const { - TaskSpec *spec = Spec(); - /* The first dependencies are the arguments of the task itself, followed by - * the execution dependencies. If the requested dependency index is a task - * argument, then it is a task dependency. 
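The argument-then-execution-dependency index space implemented above can be walked uniformly; a small sketch using only the methods of this class (``my_visit_dependencies`` is hypothetical):

.. code-block:: cpp

    /* Visit every object ID a task depends on: first the immutable task
     * arguments, then the mutable execution dependencies. */
    void my_visit_dependencies(const TaskExecutionSpec &exec_spec) {
      for (int64_t i = 0; i < exec_spec.NumDependencies(); ++i) {
        for (int j = 0; j < exec_spec.DependencyIdCount(i); ++j) {
          ObjectID object_id = exec_spec.DependencyId(i, j);
          /* IsStaticDependency(i) distinguishes task arguments from
           * execution dependencies. */
        }
      }
    }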
*/ - int64_t num_args = TaskSpec_num_args(spec); - return (dependency_index < num_args); -} - -/* TASK INSTANCES */ - -Task *Task_alloc(const TaskSpec *spec, - int64_t task_spec_size, - TaskStatus state, - DBClientID local_scheduler_id, - const std::vector &execution_dependencies) { - Task *result = new Task(); - auto execution_spec = - new TaskExecutionSpec(execution_dependencies, spec, task_spec_size); - result->execution_spec = std::unique_ptr(execution_spec); - result->state = state; - result->local_scheduler_id = local_scheduler_id; - return result; -} - -Task *Task_alloc(TaskExecutionSpec &execution_spec, - TaskStatus state, - DBClientID local_scheduler_id) { - Task *result = new Task(); - result->execution_spec = std::unique_ptr( - new TaskExecutionSpec(&execution_spec)); - result->state = state; - result->local_scheduler_id = local_scheduler_id; - return result; -} - -Task *Task_copy(Task *other) { - return Task_alloc(*Task_task_execution_spec(other), other->state, - other->local_scheduler_id); -} - -int64_t Task_size(Task *task_arg) { - return sizeof(Task) - sizeof(TaskSpec) + task_arg->execution_spec->SpecSize(); -} - -TaskStatus Task_state(Task *task) { - return task->state; -} - -void Task_set_state(Task *task, TaskStatus state) { - task->state = state; -} - -DBClientID Task_local_scheduler(Task *task) { - return task->local_scheduler_id; -} - -void Task_set_local_scheduler(Task *task, DBClientID local_scheduler_id) { - task->local_scheduler_id = local_scheduler_id; -} - -TaskExecutionSpec *Task_task_execution_spec(Task *task) { - return task->execution_spec.get(); -} - -TaskID Task_task_id(Task *task) { - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - return TaskSpec_task_id(spec); -} - -void Task_free(Task *task) { - delete task; -} diff --git a/src/common/task.h b/src/common/task.h deleted file mode 100644 index 3984cfdd5119..000000000000 --- a/src/common/task.h +++ /dev/null @@ -1,609 +0,0 @@ -#ifndef TASK_H -#define TASK_H - -#include - -#include -#include -#include "common.h" - -#include - -#include "format/common_generated.h" - -using namespace ray; - -typedef char TaskSpec; - -class TaskExecutionSpec { - public: - TaskExecutionSpec(const std::vector &execution_dependencies, - const TaskSpec *spec, - int64_t task_spec_size); - TaskExecutionSpec(const std::vector &execution_dependencies, - const TaskSpec *spec, - int64_t task_spec_size, - int spillback_count); - TaskExecutionSpec(TaskExecutionSpec *execution_spec); - - /// Get the task's execution dependencies. - /// - /// @return A vector of object IDs representing this task's execution - /// dependencies. - const std::vector &ExecutionDependencies() const; - - /// Set the task's execution dependencies. - /// - /// @param dependencies The value to set the execution dependencies to. - /// @return Void. - void SetExecutionDependencies(const std::vector &dependencies); - - /// Get the task spec size. - /// - /// @return The size of the immutable task spec. - int64_t SpecSize() const; - - /// Get the task's spillback count, which tracks the number of times - /// this task was spilled back from local to the global scheduler. - /// - /// @return The spillback count for this task. - int SpillbackCount() const; - - /// Increment the spillback count for this task. - /// - /// @return Void. - void IncrementSpillbackCount(); - - /// Get the task's last timestamp. - /// - /// @return The timestamp when this task was last received for scheduling. 
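The ``Task`` functions defined above pair with the accessors declared in this header; the tests later in this patch exercise them in essentially this order (a compressed sketch):

.. code-block:: cpp

    void my_task_lifecycle(TaskExecutionSpec &exec_spec,
                           DBClientID scheduler_id) {
      Task *task = Task_alloc(exec_spec, TaskStatus::WAITING, scheduler_id);
      Task_set_state(task, TaskStatus::SCHEDULED);
      RAY_CHECK(Task_state(task) == TaskStatus::SCHEDULED);
      Task_free(task);
    }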
- int64_t LastTimeStamp() const; - - /// Set the task's last timestamp to the specified value. - /// - /// @param new_timestamp The new timestamp in millisecond to set the task's - /// time stamp to. Tracks the last time this task entered a local - /// scheduler. - /// @return Void. - void SetLastTimeStamp(int64_t new_timestamp); - - /// Get the task spec. - /// - /// @return A pointer to the immutable task spec. - TaskSpec *Spec() const; - - /// Get the number of dependencies. This comprises the immutable task - /// arguments and the mutable execution dependencies. - /// - /// @return The number of dependencies. - int64_t NumDependencies() const; - - /// Get the number of object IDs at the given dependency index. - /// - /// @param dependency_index The dependency index whose object IDs to count. - /// @return The number of object IDs at the given dependency_index. - int DependencyIdCount(int64_t dependency_index) const; - - /// Get the object ID of a given dependency index. - /// - /// @param dependency_index The index at which we should look up the object - /// ID. - /// @param id_index The index of the object ID. - ObjectID DependencyId(int64_t dependency_index, int64_t id_index) const; - - /// Compute whether the task is dependent on an object ID. - /// - /// @param object_id The object ID that the task may be dependent on. - /// @return bool This returns true if the task is dependent on the given - /// object ID and false otherwise. - bool DependsOn(ObjectID object_id) const; - - /// Returns whether the given dependency index is a static dependency (an - /// argument of the immutable task). - /// - /// @param dependency_index The requested dependency index. - /// @return bool This returns true if the requested dependency index is - /// immutable (an argument of the task). - bool IsStaticDependency(int64_t dependency_index) const; - - private: - /** A list of object IDs representing this task's dependencies at execution - * time. */ - std::vector execution_dependencies_; - /** The size of the task specification for this task. */ - int64_t task_spec_size_; - /** Last time this task was received for scheduling. */ - int64_t last_timestamp_; - /** Number of times this task was spilled back by local schedulers. */ - int spillback_count_; - /** The task specification for this task. */ - std::unique_ptr spec_; -}; - -class TaskBuilder; - -typedef UniqueID FunctionID; - -/** The task ID is a deterministic hash of the function ID that the task - * executes and the argument IDs or argument values. */ -typedef UniqueID TaskID; - -/** The actor ID is the ID of the actor that a task must run on. If the task is - * not run on an actor, then NIL_ACTOR_ID should be used. */ -typedef UniqueID ActorID; - -/** - * Compare two task IDs. - * - * @param first_id The first task ID to compare. - * @param second_id The first task ID to compare. - * @return True if the task IDs are the same and false otherwise. - */ -bool TaskID_equal(TaskID first_id, TaskID second_id); - -/** - * Compare a task ID to the nil ID. - * - * @param id The task ID to compare to nil. - * @return True if the task ID is equal to nil. - */ -bool TaskID_is_nil(TaskID id); - -/** - * Compare two actor IDs. - * - * @param first_id The first actor ID to compare. - * @param second_id The first actor ID to compare. - * @return True if the actor IDs are the same and false otherwise. - */ -bool ActorID_equal(ActorID first_id, ActorID second_id); - -/** - * Compare two function IDs. - * - * @param first_id The first function ID to compare. 
- * @param second_id The first function ID to compare. - * @return True if the function IDs are the same and false otherwise. - */ -bool FunctionID_equal(FunctionID first_id, FunctionID second_id); - -/** - * Compare a function ID to the nil ID. - * - * @param id The function ID to compare to nil. - * @return True if the function ID is equal to nil. - */ -bool FunctionID_is_nil(FunctionID id); - -/* Construct and modify task specifications. */ - -TaskBuilder *make_task_builder(void); - -void free_task_builder(TaskBuilder *builder); - -/** - * Begin constructing a task_spec. After this is called, the arguments must be - * added to the task_spec and then finish_construct_task_spec must be called. - * - * @param driver_id The ID of the driver whose job is responsible for the - * creation of this task. - * @param parent_task_id The task ID of the task that submitted this task. - * @param parent_counter A counter indicating how many tasks were submitted by - * the parent task prior to this one. - * @param actor_creation_id The actor creation ID of this task. - * @param actor_creation_dummy_object_id The dummy object for the corresponding - * actor creation task, assuming this is an actor method. - * @param actor_id The ID of the actor that this task is for. If it is not an - * actor task, then this if NIL_ACTOR_ID. - * @param actor_handle_id The ID of the actor handle that this task was - * submitted through. If it is not an actor task, or if this is the - * original handle, then this is NIL_ACTOR_ID. - * @param actor_counter A counter indicating how many tasks have been submitted - * to the same actor before this one. - * @param is_actor_checkpoint_method True if this is an actor checkpoint method - * and false otherwise. - * @param function_id The function ID of the function to execute in this task. - * @param num_args The number of arguments that this task has. - * @param num_returns The number of return values that this task has. - * @param args_value_size The total size in bytes of the arguments to this task - ignoring object ID arguments. - * @return The partially constructed task_spec. - */ -void TaskSpec_start_construct(TaskBuilder *B, - UniqueID driver_id, - TaskID parent_task_id, - int64_t parent_counter, - ActorID actor_creation_id, - ObjectID actor_creation_dummy_object_id, - ActorID actor_id, - ActorHandleID actor_handle_id, - int64_t actor_counter, - bool is_actor_checkpoint_method, - FunctionID function_id, - int64_t num_returns); - -/** - * Finish constructing a task_spec. This computes the task ID and the object IDs - * of the task return values. This must be called after all of the arguments - * have been added to the task. - * - * @param spec The task spec whose ID and return object IDs should be computed. - * @return Void. - */ -TaskSpec *TaskSpec_finish_construct(TaskBuilder *builder, int64_t *size); - -/** - * Return the function ID of the task. - * - * @param spec The task_spec in question. - * @return The function ID of the function to execute in this task. - */ -FunctionID TaskSpec_function(TaskSpec *spec); - -/** - * Return the actor ID of the task. - * - * @param spec The task_spec in question. - * @return The actor ID of the actor the task is part of. - */ -ActorID TaskSpec_actor_id(TaskSpec *spec); - -/** - * Return the actor handle ID of the task. - * - * @param spec The task_spec in question. - * @return The ID of the actor handle that the task was submitted through. 
- */ -ActorID TaskSpec_actor_handle_id(TaskSpec *spec); - -/** - * Return whether this task is for an actor. - * - * @param spec The task_spec in question. - * @return Whether the task is for an actor. - */ -bool TaskSpec_is_actor_task(TaskSpec *spec); - -/// Return whether this task is an actor creation task or not. -/// -/// \param spec The task_spec in question. -/// \return True if this task is an actor creation task and false otherwise. -bool TaskSpec_is_actor_creation_task(TaskSpec *spec); - -/// Return the actor creation ID of the task. The task must be an actor creation -/// task. -/// -/// \param spec The task_spec in question. -/// \return The actor creation ID if this is an actor creation task. -ActorID TaskSpec_actor_creation_id(TaskSpec *spec); - -/// Return the actor creation dummy object ID of the task. The task must be an -/// actor task. -/// -/// \param spec The task_spec in question. -/// \return The actor creation dummy object ID corresponding to this actor task. -ObjectID TaskSpec_actor_creation_dummy_object_id(TaskSpec *spec); - -/** - * Return the actor counter of the task. This starts at 0 and increments by 1 - * every time a new task is submitted to run on the actor. - * - * @param spec The task_spec in question. - * @return The actor counter of the task. - */ -int64_t TaskSpec_actor_counter(TaskSpec *spec); - -/** - * Return whether the task is a checkpoint method execution. - * - * @param spec The task_spec in question. - * @return Whether the task is a checkpoint method. - */ -bool TaskSpec_is_actor_checkpoint_method(TaskSpec *spec); - -/** - * Return an actor task's dummy return value. Dummy objects are used to - * encode an actor's state dependencies in the task graph. The dummy object - * is local if and only if the task that returned it has completed - * execution. - * - * @param spec The task_spec in question. - * @return The dummy object ID that the actor task will return. - */ -ObjectID TaskSpec_actor_dummy_object(TaskSpec *spec); - -/** - * Return the driver ID of the task. - * - * @param spec The task_spec in question. - * @return The driver ID of the task. - */ -UniqueID TaskSpec_driver_id(const TaskSpec *spec); - -/** - * Return the task ID of the parent task. - * - * @param spec The task_spec in question. - * @return The task ID of the parent task. - */ -TaskID TaskSpec_parent_task_id(const TaskSpec *spec); - -/** - * Return the task counter of the parent task. For example, this equals 5 if - * this task was the 6th task submitted by the parent task. - * - * @param spec The task_spec in question. - * @return The task counter of the parent task. - */ -int64_t TaskSpec_parent_counter(TaskSpec *spec); - -/** - * Return the task ID of the task. - * - * @param spec The task_spec in question. - * @return The task ID of the task. - */ -TaskID TaskSpec_task_id(const TaskSpec *spec); - -/** - * Get the number of arguments to this task. - * - * @param spec The task_spec in question. - * @return The number of arguments to this task. - */ -int64_t TaskSpec_num_args(TaskSpec *spec); - -/** - * Get the number of return values expected from this task. - * - * @param spec The task_spec in question. - * @return The number of return values expected from this task. - */ -int64_t TaskSpec_num_returns(TaskSpec *spec); - -/** - * Return true if this argument is passed by reference. - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @return True if this argument is passed by reference. 
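The by-reference/by-value split documented here decides which accessors are valid for a given argument; a sketch of dispatching on it with the functions declared in this header:

.. code-block:: cpp

    void my_read_arg(TaskSpec *spec, int64_t i) {
      if (TaskSpec_arg_by_ref(spec, i)) {
        for (int j = 0; j < TaskSpec_arg_id_count(spec, i); ++j) {
          ObjectID object_id = TaskSpec_arg_id(spec, i, j); /* by reference */
        }
      } else {
        const uint8_t *data = TaskSpec_arg_val(spec, i); /* inline value */
        int64_t length = TaskSpec_arg_length(spec, i);
      }
    }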
- */ -bool TaskSpec_arg_by_ref(TaskSpec *spec, int64_t arg_index); - -/** - * Get number of object IDs in a given argument - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @return number of object IDs in this argument - */ -int TaskSpec_arg_id_count(TaskSpec *spec, int64_t arg_index); - -/** - * Get a particular argument to this task. This assumes the argument is an - * object ID. - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @param id_index The index of the object ID in this arg. - * @return The argument at that index. - */ -ObjectID TaskSpec_arg_id(TaskSpec *spec, int64_t arg_index, int64_t id_index); - -/** - * Get a particular argument to this task. This assumes the argument is a value. - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @return The argument at that index. - */ -const uint8_t *TaskSpec_arg_val(TaskSpec *spec, int64_t arg_index); - -/** - * Get the number of bytes in a particular argument to this task. This assumes - * the argument is a value. - * - * @param spec The task_spec in question. - * @param arg_index The index of the argument in question. - * @return The number of bytes in the argument. - */ -int64_t TaskSpec_arg_length(TaskSpec *spec, int64_t arg_index); - -/** - * Set the next task argument. Note that this API only allows you to set the - * arguments in their order of appearance. - * - * @param spec The task_spec in question. - * @param object_ids The object IDs to set the argument to. - * @param num_object_ids number of IDs in this param, usually 1. - * @return The number of task arguments that have been set before this one. This - * is only used for testing. - */ -void TaskSpec_args_add_ref(TaskBuilder *spec, - ObjectID object_ids[], - int num_object_ids); - -/** - * Set the next task argument. Note that this API only allows you to set the - * arguments in their order of appearance. - * - * @param spec The task_spec in question. - * @param The value to set the argument to. - * @param The length of the value to set the argument to. - * @return The number of task arguments that have been set before this one. This - * is only used for testing. - */ -void TaskSpec_args_add_val(TaskBuilder *builder, - uint8_t *value, - int64_t length); - -/** - * Set the value associated to a resource index. - * - * @param spec Task specification. - * @param resource_name Name of the resource in the resource vector. - * @param value Value for the resource. This can be a quantity of this resource - * this task needs or a value for an attribute this task requires. - * @return Void. - */ -void TaskSpec_set_required_resource(TaskBuilder *builder, - const std::string &resource_name, - double value); - -/** - * Get a particular return object ID of a task. - * - * @param spec The task_spec in question. - * @param return_index The index of the return object ID in question. - * @return The relevant return object ID. - */ -ObjectID TaskSpec_return(TaskSpec *data, int64_t return_index); - -/** - * Get the value associated to a resource name. - * - * @param spec Task specification. - * @param resource_name Name of the resource. - * @return How many of this resource the task needs to execute. 
- */ -double TaskSpec_get_required_resource(const TaskSpec *spec, - const std::string &resource_name); - -/** - * - */ -const std::unordered_map TaskSpec_get_required_resources( - const TaskSpec *spec); - -/** - * Compute the object id associated to a put call. - * - * @param task_id The task id of the parent task that called the put. - * @param put_index The number of put calls in this task so far. - * @return The object ID for the object that was put. - */ -ObjectID task_compute_put_id(TaskID task_id, int64_t put_index); - -/** - * Print the task as a humanly readable string. - * - * @param spec The task_spec in question. - * @return The humanly readable string. - */ -std::string TaskSpec_print(TaskSpec *spec); - -/** - * Create a copy of the task spec. Must be freed with TaskSpec_free after use. - * - * @param spec The task specification that will be copied. - * @param task_spec_size The size of the task specification in bytes. - * @returns Pointer to the copy of the task specification. - */ -TaskSpec *TaskSpec_copy(TaskSpec *spec, int64_t task_spec_size); - -/** - * Free a task_spec. - * - * @param The task_spec in question. - * @return Void. - */ -void TaskSpec_free(TaskSpec *spec); - -/** - * ==== Task ==== - * Contains information about a scheduled task: The task specification, the - * task scheduling state (WAITING, SCHEDULED, QUEUED, RUNNING, DONE), and which - * local scheduler the task is scheduled on. - */ - -/** The scheduling_state can be used as a flag when we are listening - * for an event, for example TASK_WAITING | TASK_SCHEDULED. */ -enum class TaskStatus : uint { - /** The task is waiting to be scheduled. */ - WAITING = 1, - /** The task has been scheduled to a node, but has not been queued yet. */ - SCHEDULED = 2, - /** The task has been queued on a node, where it will wait for its - * dependencies to become ready and a worker to become available. */ - QUEUED = 4, - /** The task is running on a worker. */ - RUNNING = 8, - /** The task is done executing. */ - DONE = 16, - /** The task was not able to finish. */ - LOST = 32, - /** The task will be submitted for reexecution. */ - RECONSTRUCTING = 64, - /** An actor task is cached at a local scheduler and is waiting for the - * corresponding actor to be created. */ - ACTOR_CACHED = 128 -}; - -inline TaskStatus operator|(const TaskStatus &a, const TaskStatus &b) { - uint c = static_cast(a) | static_cast(b); - return static_cast(c); -} - -/** A task is an execution of a task specification. It has a state of execution - * (see scheduling_state) and the ID of the local scheduler it is scheduled on - * or running on. */ - -struct Task { - /** The scheduling state of the task. */ - TaskStatus state; - /** The ID of the local scheduler involved. */ - DBClientID local_scheduler_id; - /** The execution specification for this task. */ - std::unique_ptr execution_spec; -}; - -/** - * Allocate a new task. Must be freed with free_task after use. - * - * @param spec The task spec for the new task. - * @param state The scheduling state for the new task. - * @param local_scheduler_id The ID of the local scheduler that the task is - * scheduled on, if any. - */ -Task *Task_alloc(const TaskSpec *spec, - int64_t task_spec_size, - TaskStatus state, - DBClientID local_scheduler_id, - const std::vector &execution_dependencies); - -Task *Task_alloc(TaskExecutionSpec &execution_spec, - TaskStatus state, - DBClientID local_scheduler_id); - -/** - * Create a copy of the task. Must be freed with Task_free after use. 
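Since ``operator|`` above exists precisely so statuses can be combined into filters (see the ``test_state_bitmask`` parameter of ``task_table_test_and_update``), a one-line illustration:

.. code-block:: cpp

    /* Match tasks that are waiting or already assigned; test_and_update
     * AND-s this mask against the entry's current scheduling state. */
    TaskStatus mask = TaskStatus::WAITING | TaskStatus::SCHEDULED;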
- * - * @param other The task that will be copied. - * @returns Pointer to the copy of the task. - */ -Task *Task_copy(Task *other); - -/** Size of task structure in bytes. */ -int64_t Task_size(Task *task); - -/** The scheduling state of the task. */ -TaskStatus Task_state(Task *task); - -/** Update the schedule state of the task. */ -void Task_set_state(Task *task, TaskStatus state); - -/** Local scheduler this task has been assigned to or is running on. */ -DBClientID Task_local_scheduler(Task *task); - -/** Set the local scheduler ID for this task. */ -void Task_set_local_scheduler(Task *task, DBClientID local_scheduler_id); - -TaskExecutionSpec *Task_task_execution_spec(Task *task); - -/** Task ID of this task. */ -TaskID Task_task_id(Task *task); - -/** Free this task datastructure. */ -void Task_free(Task *task); - -#endif /* TASK_H */ diff --git a/src/common/test/db_tests.cc b/src/common/test/db_tests.cc deleted file mode 100644 index 83585ca66e0f..000000000000 --- a/src/common/test/db_tests.cc +++ /dev/null @@ -1,246 +0,0 @@ -#include "greatest.h" - -#include -#include -#include - -#include "event_loop.h" -#include "test_common.h" -#include "example_task.h" -#include "net.h" -#include "state/db.h" -#include "state/db_client_table.h" -#include "state/object_table.h" -#include "state/task_table.h" -#include "state/redis.h" -#include "task.h" - -SUITE(db_tests); - -TaskBuilder *g_task_builder = NULL; - -/* Retry 10 times with an 100ms timeout. */ -const int NUM_RETRIES = 10; -const uint64_t TIMEOUT = 50; - -const char *manager_addr = "127.0.0.1"; -int manager_port1 = 12345; -int manager_port2 = 12346; -char received_addr1[16] = {0}; -int received_port1; -char received_addr2[16] = {0}; -int received_port2; - -typedef struct { int test_number; } user_context; - -const int TEST_NUMBER = 10; - -/* Test if entries have been written to the database. */ - -void lookup_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_ids, - void *user_context) { - DBHandle *db = (DBHandle *) user_context; - RAY_CHECK(manager_ids.size() == 2); - const std::vector managers = - db_client_table_get_ip_addresses(db, manager_ids); - RAY_CHECK(parse_ip_addr_port(managers.at(0).c_str(), received_addr1, - &received_port1) == 0); - RAY_CHECK(parse_ip_addr_port(managers.at(1).c_str(), received_addr2, - &received_port2) == 0); -} - -/* Entry added to database successfully. */ -void add_done_callback(ObjectID object_id, bool success, void *user_context) {} - -/* Test if we got a timeout callback if we couldn't connect database. */ -void timeout_callback(ObjectID object_id, void *context, void *user_data) { - user_context *uc = (user_context *) context; - RAY_CHECK(uc->test_number == TEST_NUMBER); -} - -int64_t timeout_handler(event_loop *loop, int64_t id, void *context) { - event_loop_stop(loop); - return EVENT_LOOP_TIMER_DONE; -} - -TEST object_table_lookup_test(void) { - event_loop *loop = event_loop_create(); - /* This uses manager_port1. */ - std::vector db_connect_args1; - db_connect_args1.push_back("manager_address"); - db_connect_args1.push_back("127.0.0.1:12345"); - DBHandle *db1 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - manager_addr, db_connect_args1); - /* This uses manager_port2. 
*/ - std::vector db_connect_args2; - db_connect_args2.push_back("manager_address"); - db_connect_args2.push_back("127.0.0.1:12346"); - DBHandle *db2 = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - manager_addr, db_connect_args2); - db_attach(db1, loop, false); - db_attach(db2, loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = NUM_RETRIES, - .timeout = TIMEOUT, - .fail_callback = timeout_callback, - }; - object_table_add(db1, id, 0, (unsigned char *) NIL_DIGEST, &retry, - add_done_callback, NULL); - object_table_add(db2, id, 0, (unsigned char *) NIL_DIGEST, &retry, - add_done_callback, NULL); - event_loop_add_timer(loop, 200, (event_loop_timer_handler) timeout_handler, - NULL); - event_loop_run(loop); - object_table_lookup(db1, id, &retry, lookup_done_callback, db1); - event_loop_add_timer(loop, 200, (event_loop_timer_handler) timeout_handler, - NULL); - event_loop_run(loop); - ASSERT_STR_EQ(&received_addr1[0], manager_addr); - ASSERT((received_port1 == manager_port1 && received_port2 == manager_port2) || - (received_port2 == manager_port1 && received_port1 == manager_port2)); - - db_disconnect(db1); - db_disconnect(db2); - - destroy_outstanding_callbacks(loop); - event_loop_destroy(loop); - PASS(); -} - -int task_table_test_callback_called = 0; -Task *task_table_test_task; - -void task_table_test_fail_callback(UniqueID id, - void *context, - void *user_data) { - event_loop *loop = (event_loop *) user_data; - event_loop_stop(loop); -} - -int64_t task_table_delayed_add_task(event_loop *loop, - int64_t id, - void *context) { - DBHandle *db = (DBHandle *) context; - RetryInfo retry = { - .num_retries = NUM_RETRIES, - .timeout = TIMEOUT, - .fail_callback = task_table_test_fail_callback, - }; - task_table_add_task(db, Task_copy(task_table_test_task), &retry, NULL, - (void *) loop); - return EVENT_LOOP_TIMER_DONE; -} - -void task_table_test_callback(Task *callback_task, void *user_data) { - task_table_test_callback_called = 1; - RAY_CHECK(Task_state(callback_task) == TaskStatus::SCHEDULED); - RAY_CHECK(Task_size(callback_task) == Task_size(task_table_test_task)); - RAY_CHECK(Task_equals(callback_task, task_table_test_task)); - event_loop *loop = (event_loop *) user_data; - event_loop_stop(loop); -} - -TEST task_table_test(void) { - task_table_test_callback_called = 0; - event_loop *loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler", - "127.0.0.1", std::vector()); - db_attach(db, loop, false); - DBClientID local_scheduler_id = DBClientID::from_random(); - TaskExecutionSpec spec = example_task_execution_spec(1, 1); - task_table_test_task = - Task_alloc(spec, TaskStatus::SCHEDULED, local_scheduler_id); - RetryInfo retry = { - .num_retries = NUM_RETRIES, - .timeout = TIMEOUT, - .fail_callback = task_table_test_fail_callback, - }; - task_table_subscribe(db, local_scheduler_id, TaskStatus::SCHEDULED, - task_table_test_callback, (void *) loop, &retry, NULL, - (void *) loop); - event_loop_add_timer( - loop, 200, (event_loop_timer_handler) task_table_delayed_add_task, db); - event_loop_run(loop); - Task_free(task_table_test_task); - db_disconnect(db); - destroy_outstanding_callbacks(loop); - event_loop_destroy(loop); - ASSERT(task_table_test_callback_called); - PASS(); -} - -int num_test_callback_called = 0; - -void task_table_all_test_callback(Task *task, void *user_data) { - num_test_callback_called += 1; -} - -TEST task_table_all_test(void) { - event_loop *loop = event_loop_create(); - 
DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "local_scheduler", - "127.0.0.1", std::vector()); - db_attach(db, loop, false); - TaskExecutionSpec spec = example_task_execution_spec(1, 1); - /* Schedule two tasks on different local local schedulers. */ - Task *task1 = - Task_alloc(spec, TaskStatus::SCHEDULED, DBClientID::from_random()); - Task *task2 = - Task_alloc(spec, TaskStatus::SCHEDULED, DBClientID::from_random()); - RetryInfo retry = { - .num_retries = NUM_RETRIES, .timeout = TIMEOUT, .fail_callback = NULL, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::SCHEDULED, - task_table_all_test_callback, NULL, &retry, NULL, NULL); - event_loop_add_timer(loop, 50, (event_loop_timer_handler) timeout_handler, - NULL); - event_loop_run(loop); - /* TODO(pcm): Get rid of this sleep once the robust pubsub is implemented. */ - task_table_add_task(db, task1, &retry, NULL, NULL); - task_table_add_task(db, task2, &retry, NULL, NULL); - event_loop_add_timer(loop, 200, (event_loop_timer_handler) timeout_handler, - NULL); - event_loop_run(loop); - db_disconnect(db); - destroy_outstanding_callbacks(loop); - event_loop_destroy(loop); - ASSERT(num_test_callback_called == 2); - PASS(); -} - -TEST unique_client_id_test(void) { - const int num_conns = 100; - - DBClientID ids[num_conns]; - DBHandle *db; - for (int i = 0; i < num_conns; ++i) { - db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - ids[i] = get_db_client_id(db); - db_disconnect(db); - } - for (int i = 0; i < num_conns; ++i) { - for (int j = 0; j < i; ++j) { - ASSERT(!(ids[i] == ids[j])); - } - } - PASS(); -} - -SUITE(db_tests) { - RUN_REDIS_TEST(object_table_lookup_test); - RUN_REDIS_TEST(task_table_test); - RUN_REDIS_TEST(task_table_all_test); - RUN_REDIS_TEST(unique_client_id_test); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - g_task_builder = make_task_builder(); - GREATEST_MAIN_BEGIN(); - RUN_SUITE(db_tests); - GREATEST_MAIN_END(); -} diff --git a/src/common/test/example_task.h b/src/common/test/example_task.h deleted file mode 100644 index f90cab68f6d9..000000000000 --- a/src/common/test/example_task.h +++ /dev/null @@ -1,77 +0,0 @@ -#ifndef EXAMPLE_TASK_H -#define EXAMPLE_TASK_H - -#include "task.h" - -extern TaskBuilder *g_task_builder; - -const int64_t arg_value_size = 1000; - -static inline TaskExecutionSpec example_task_execution_spec_with_args( - int64_t num_args, - int64_t num_returns, - ObjectID arg_ids[]) { - TaskID parent_task_id = TaskID::from_random(); - FunctionID func_id = FunctionID::from_random(); - TaskSpec_start_construct(g_task_builder, UniqueID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, num_returns); - for (int64_t i = 0; i < num_args; ++i) { - ObjectID arg_id; - if (arg_ids == NULL) { - arg_id = ObjectID::from_random(); - } else { - arg_id = arg_ids[i]; - } - TaskSpec_args_add_ref(g_task_builder, &arg_id, 1); - } - int64_t task_spec_size; - TaskSpec *spec = TaskSpec_finish_construct(g_task_builder, &task_spec_size); - std::vector execution_dependencies; - auto execution_spec = - TaskExecutionSpec(execution_dependencies, spec, task_spec_size); - TaskSpec_free(spec); - return execution_spec; -} - -static inline TaskExecutionSpec example_task_execution_spec( - int64_t num_args, - int64_t num_returns) { - return example_task_execution_spec_with_args(num_args, num_returns, NULL); -} - -static inline Task *example_task_with_args(int64_t num_args, - int64_t 
num_returns, - TaskStatus task_state, - ObjectID arg_ids[]) { - TaskExecutionSpec spec = - example_task_execution_spec_with_args(num_args, num_returns, arg_ids); - Task *instance = Task_alloc(spec, task_state, UniqueID::nil()); - return instance; -} - -static inline Task *example_task(int64_t num_args, - int64_t num_returns, - TaskStatus task_state) { - TaskExecutionSpec spec = example_task_execution_spec(num_args, num_returns); - Task *instance = Task_alloc(spec, task_state, UniqueID::nil()); - return instance; -} - -static inline bool Task_equals(Task *task1, Task *task2) { - if (task1->state != task2->state) { - return false; - } - if (!(task1->local_scheduler_id == task2->local_scheduler_id)) { - return false; - } - auto execution_spec1 = Task_task_execution_spec(task1); - auto execution_spec2 = Task_task_execution_spec(task2); - if (execution_spec1->SpecSize() != execution_spec2->SpecSize()) { - return false; - } - return memcmp(execution_spec1->Spec(), execution_spec2->Spec(), - execution_spec1->SpecSize()) == 0; -} - -#endif /* EXAMPLE_TASK_H */ diff --git a/src/common/test/io_tests.cc b/src/common/test/io_tests.cc deleted file mode 100644 index 092ca97b7d56..000000000000 --- a/src/common/test/io_tests.cc +++ /dev/null @@ -1,114 +0,0 @@ -#include "greatest.h" - -#include -#include -#include - -#include -#include - -#include "io.h" - -SUITE(io_tests); - -TEST ipc_socket_test(void) { -#ifndef _WIN32 - const char *socket_pathname = "/tmp/test-socket"; - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - - const char *test_string = "hello world"; - const char *test_bytes = "another string"; - pid_t pid = fork(); - if (pid == 0) { - close(socket_fd); - socket_fd = connect_ipc_sock(socket_pathname); - ASSERT(socket_fd >= 0); - write_log_message(socket_fd, test_string); - write_message(socket_fd, - static_cast(CommonMessageType::LOG_MESSAGE), - strlen(test_bytes), (uint8_t *) test_bytes); - close(socket_fd); - exit(0); - } else { - int client_fd = accept_client(socket_fd); - ASSERT(client_fd >= 0); - char *message = read_log_message(client_fd); - ASSERT(message != NULL); - ASSERT_STR_EQ(test_string, message); - free(message); - int64_t type; - int64_t len; - uint8_t *bytes; - read_message(client_fd, &type, &len, &bytes); - ASSERT(static_cast(type) == - CommonMessageType::LOG_MESSAGE); - ASSERT(memcmp(test_bytes, bytes, len) == 0); - free(bytes); - close(client_fd); - close(socket_fd); - unlink(socket_pathname); - } -#endif - PASS(); -} - -TEST long_ipc_socket_test(void) { -#ifndef _WIN32 - const char *socket_pathname = "/tmp/long-test-socket"; - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - - std::stringstream test_string_ss; - for (int i = 0; i < 10000; i++) { - test_string_ss << "hello world "; - } - std::string test_string = test_string_ss.str(); - const char *test_bytes = "another string"; - pid_t pid = fork(); - if (pid == 0) { - close(socket_fd); - socket_fd = connect_ipc_sock(socket_pathname); - ASSERT(socket_fd >= 0); - write_log_message(socket_fd, test_string.c_str()); - write_message(socket_fd, - static_cast(CommonMessageType::LOG_MESSAGE), - strlen(test_bytes), (uint8_t *) test_bytes); - close(socket_fd); - exit(0); - } else { - int client_fd = accept_client(socket_fd); - ASSERT(client_fd >= 0); - char *message = read_log_message(client_fd); - ASSERT(message != NULL); - ASSERT_STR_EQ(test_string.c_str(), message); - free(message); - int64_t type; - int64_t len; - uint8_t *bytes; - read_message(client_fd, &type, 
&len, &bytes); - ASSERT(static_cast(type) == - CommonMessageType::LOG_MESSAGE); - ASSERT(memcmp(test_bytes, bytes, len) == 0); - free(bytes); - close(client_fd); - close(socket_fd); - unlink(socket_pathname); - } - -#endif - PASS(); -} - -SUITE(io_tests) { - RUN_TEST(ipc_socket_test); - RUN_TEST(long_ipc_socket_test); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - GREATEST_MAIN_BEGIN(); - RUN_SUITE(io_tests); - GREATEST_MAIN_END(); -} diff --git a/src/common/test/object_table_tests.cc b/src/common/test/object_table_tests.cc deleted file mode 100644 index 059972438606..000000000000 --- a/src/common/test/object_table_tests.cc +++ /dev/null @@ -1,919 +0,0 @@ -#include "greatest.h" - -#include "event_loop.h" -#include "example_task.h" -#include "test_common.h" -#include "common.h" -#include "state/db_client_table.h" -#include "state/object_table.h" -#include "state/redis.h" - -#include - -SUITE(object_table_tests); - -static event_loop *g_loop; -TaskBuilder *g_task_builder = NULL; - -/* ==== Test adding and looking up metadata ==== */ - -int new_object_failed = 0; -int new_object_succeeded = 0; -ObjectID new_object_id; -Task *new_object_task; -TaskSpec *new_object_task_spec; -TaskID new_object_task_id; - -void new_object_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - new_object_failed = 1; - event_loop_stop(g_loop); -} - -/* === Test adding an object with an associated task === */ - -void new_object_done_callback(ObjectID object_id, - TaskID task_id, - bool is_put, - void *user_context) { - new_object_succeeded = 1; - RAY_CHECK(object_id == new_object_id); - RAY_CHECK(task_id == new_object_task_id); - event_loop_stop(g_loop); -} - -void new_object_lookup_callback(ObjectID object_id, void *user_context) { - RAY_CHECK(object_id == new_object_id); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = new_object_fail_callback, - }; - DBHandle *db = (DBHandle *) user_context; - result_table_lookup(db, new_object_id, &retry, new_object_done_callback, - NULL); -} - -void new_object_task_callback(TaskID task_id, void *user_context) { - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = new_object_fail_callback, - }; - DBHandle *db = (DBHandle *) user_context; - result_table_add(db, new_object_id, new_object_task_id, false, &retry, - new_object_lookup_callback, (void *) db); -} - -void task_table_subscribe_done(TaskID task_id, void *user_context) { - RetryInfo retry = { - .num_retries = 5, .timeout = 100, .fail_callback = NULL, - }; - DBHandle *db = (DBHandle *) user_context; - task_table_add_task(db, Task_copy(new_object_task), &retry, - new_object_task_callback, db); -} - -TEST new_object_test(void) { - new_object_failed = 0; - new_object_succeeded = 0; - new_object_id = ObjectID::from_random(); - new_object_task = example_task(1, 1, TaskStatus::WAITING); - new_object_task_spec = Task_task_execution_spec(new_object_task)->Spec(); - new_object_task_id = TaskSpec_task_id(new_object_task_spec); - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = new_object_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, task_table_subscribe_done, db); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - 
ASSERT(new_object_succeeded); - ASSERT(!new_object_failed); - PASS(); -} - -/* === Test adding an object without an associated task === */ - -void new_object_no_task_callback(ObjectID object_id, - TaskID task_id, - bool is_put, - void *user_context) { - new_object_succeeded = 1; - RAY_CHECK(task_id.is_nil()); - event_loop_stop(g_loop); -} - -TEST new_object_no_task_test(void) { - new_object_failed = 0; - new_object_succeeded = 0; - new_object_id = ObjectID::from_random(); - new_object_task_id = TaskID::from_random(); - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = new_object_fail_callback, - }; - result_table_lookup(db, new_object_id, &retry, new_object_no_task_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(new_object_succeeded); - ASSERT(!new_object_failed); - PASS(); -} - -/* ==== Test if operations time out correctly ==== */ - -/* === Test lookup timeout === */ - -const char *lookup_timeout_context = "lookup_timeout"; -int lookup_failed = 0; - -void lookup_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *context) { - /* The done callback should not be called. */ - RAY_CHECK(0); -} - -void lookup_fail_callback(UniqueID id, void *user_context, void *user_data) { - lookup_failed = 1; - RAY_CHECK(user_context == (void *) lookup_timeout_context); - event_loop_stop(g_loop); -} - -TEST lookup_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, .timeout = 100, .fail_callback = lookup_fail_callback, - }; - object_table_lookup(db, UniqueID::nil(), &retry, lookup_done_callback, - (void *) lookup_timeout_context); - /* Disconnect the database to see if the lookup times out. */ - close(db->context->c.fd); - for (auto context : db->contexts) { - close(context->c.fd); - } - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(lookup_failed); - PASS(); -} - -/* === Test add timeout === */ - -const char *add_timeout_context = "add_timeout"; -int add_failed = 0; - -void add_done_callback(ObjectID object_id, bool success, void *user_context) { - /* The done callback should not be called. */ - RAY_CHECK(0); -} - -void add_fail_callback(UniqueID id, void *user_context, void *user_data) { - add_failed = 1; - RAY_CHECK(user_context == (void *) add_timeout_context); - event_loop_stop(g_loop); -} - -TEST add_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, .timeout = 100, .fail_callback = add_fail_callback, - }; - object_table_add(db, UniqueID::nil(), 0, (unsigned char *) NIL_DIGEST, &retry, - add_done_callback, (void *) add_timeout_context); - /* Disconnect the database to see if the lookup times out. 
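-   * (Closing the raw Redis file descriptors below severs the connection,
-   * so the pending add can never complete; the retry logic should then
-   * end in add_fail_callback.)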
*/ - close(db->context->c.fd); - for (auto context : db->contexts) { - close(context->c.fd); - } - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(add_failed); - PASS(); -} - -/* === Test subscribe timeout === */ - -int subscribe_failed = 0; - -void subscribe_done_callback(ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - /* The done callback should not be called. */ - RAY_CHECK(0); -} - -void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) { - subscribe_failed = 1; - event_loop_stop(g_loop); -} - -TEST subscribe_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = subscribe_fail_callback, - }; - object_table_subscribe_to_notifications(db, false, subscribe_done_callback, - NULL, &retry, NULL, NULL); - /* Disconnect the database to see if the lookup times out. */ - close(db->subscribe_context->c.fd); - for (auto subscribe_context : db->subscribe_contexts) { - close(subscribe_context->c.fd); - } - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_failed); - PASS(); -} - -/* ==== Test if the retry is working correctly ==== */ - -int64_t reconnect_context_callback(event_loop *loop, - int64_t timer_id, - void *context) { - DBHandle *db = (DBHandle *) context; - /* Reconnect to redis. This is not reconnecting the pub/sub channel. */ - redisAsyncFree(db->context); - redisFree(db->sync_context); - db->context = redisAsyncConnect("127.0.0.1", 6379); - db->context->data = (void *) db; - db->sync_context = redisConnect("127.0.0.1", 6379); - /* Re-attach the database to the event loop (the file descriptor changed). */ - db_attach(db, loop, true); - RAY_LOG(DEBUG) << "Reconnected to Redis"; - return EVENT_LOOP_TIMER_DONE; -} - -int64_t terminate_event_loop_callback(event_loop *loop, - int64_t timer_id, - void *context) { - event_loop_stop(loop); - return EVENT_LOOP_TIMER_DONE; -} - -/* === Test lookup retry === */ - -const char *lookup_retry_context = "lookup_retry"; -int lookup_retry_succeeded = 0; - -void lookup_retry_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* The fail callback should not be called. 
*/ - RAY_CHECK(0); -} - -/* === Test add retry === */ - -const char *add_retry_context = "add_retry"; -int add_retry_succeeded = 0; - -/* === Test add then lookup retry === */ - -void add_lookup_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_ids, - void *context) { - DBHandle *db = (DBHandle *) context; - RAY_CHECK(manager_ids.size() == 1); - const std::vector managers = - db_client_table_get_ip_addresses(db, manager_ids); - RAY_CHECK(managers.at(0) == "127.0.0.1:11235"); - lookup_retry_succeeded = 1; -} - -void add_lookup_callback(ObjectID object_id, bool success, void *user_context) { - RAY_CHECK(success); - DBHandle *db = (DBHandle *) user_context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_lookup(db, UniqueID::nil(), &retry, add_lookup_done_callback, - (void *) db); -} - -TEST add_lookup_test(void) { - g_loop = event_loop_create(); - lookup_retry_succeeded = 0; - /* Construct the arguments to db_connect. */ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11235"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, true); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_add(db, UniqueID::nil(), 0, (unsigned char *) NIL_DIGEST, &retry, - add_lookup_callback, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(lookup_retry_succeeded); - PASS(); -} - -/* === Test add, remove, then lookup === */ -void add_remove_lookup_done_callback( - ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *context) { - RAY_CHECK(context == (void *) lookup_retry_context); - RAY_CHECK(manager_vector.size() == 0); - lookup_retry_succeeded = 1; -} - -void add_remove_lookup_callback(ObjectID object_id, - bool success, - void *user_context) { - RAY_CHECK(success); - DBHandle *db = (DBHandle *) user_context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_lookup(db, UniqueID::nil(), &retry, - add_remove_lookup_done_callback, - (void *) lookup_retry_context); -} - -void add_remove_callback(ObjectID object_id, bool success, void *user_context) { - RAY_CHECK(success); - DBHandle *db = (DBHandle *) user_context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_remove(db, UniqueID::nil(), NULL, &retry, - add_remove_lookup_callback, (void *) db); -} - -TEST add_remove_lookup_test(void) { - g_loop = event_loop_create(); - lookup_retry_succeeded = 0; - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, true); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = lookup_retry_fail_callback, - }; - object_table_add(db, UniqueID::nil(), 0, (unsigned char *) NIL_DIGEST, &retry, - add_remove_callback, (void *) db); - /* Install handler for terminating the event loop. 
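-   * (The 750 ms window is assumed to give the add -> remove -> lookup
-   * callback chain enough time to finish before the loop stops.)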
*/ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(lookup_retry_succeeded); - PASS(); -} - -/* ==== Test if late succeed is working correctly ==== */ - -/* === Test lookup late succeed === */ - -const char *lookup_late_context = "lookup_late"; -int lookup_late_failed = 0; - -void lookup_late_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - RAY_CHECK(user_context == (void *) lookup_late_context); - lookup_late_failed = 1; -} - -void lookup_late_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST lookup_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 0, - .timeout = 0, - .fail_callback = lookup_late_fail_callback, - }; - object_table_lookup(db, UniqueID::nil(), &retry, lookup_late_done_callback, - (void *) lookup_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. */ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(lookup_late_failed); - PASS(); -} - -/* === Test add late succeed === */ - -const char *add_late_context = "add_late"; -int add_late_failed = 0; - -void add_late_fail_callback(UniqueID id, void *user_context, void *user_data) { - RAY_CHECK(user_context == (void *) add_late_context); - add_late_failed = 1; -} - -void add_late_done_callback(ObjectID object_id, - bool success, - void *user_context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST add_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 0, .timeout = 0, .fail_callback = add_late_fail_callback, - }; - object_table_add(db, UniqueID::nil(), 0, (unsigned char *) NIL_DIGEST, &retry, - add_late_done_callback, (void *) add_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. 
*/ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(add_late_failed); - PASS(); -} - -/* === Test subscribe late succeed === */ - -const char *subscribe_late_context = "subscribe_late"; -int subscribe_late_failed = 0; - -void subscribe_late_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - RAY_CHECK(user_context == (void *) subscribe_late_context); - subscribe_late_failed = 1; -} - -void subscribe_late_done_callback(ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *user_context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST subscribe_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 0, - .timeout = 0, - .fail_callback = subscribe_late_fail_callback, - }; - object_table_subscribe_to_notifications(db, false, NULL, NULL, &retry, - subscribe_late_done_callback, - (void *) subscribe_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. */ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_late_failed); - PASS(); -} - -/* === Test subscribe object available succeed === */ - -const char *subscribe_success_context = "subscribe_success"; -int subscribe_success_done = 0; -int subscribe_success_succeeded = 0; -ObjectID subscribe_id; - -void subscribe_success_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -void subscribe_success_done_callback( - ObjectID object_id, - bool never_created, - const std::vector &manager_vector, - void *user_context) { - RetryInfo retry = { - .num_retries = 0, .timeout = 750, .fail_callback = NULL, - }; - object_table_add((DBHandle *) user_context, subscribe_id, 0, - (unsigned char *) NIL_DIGEST, &retry, NULL, NULL); - subscribe_success_done = 1; -} - -void subscribe_success_object_available_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - RAY_CHECK(user_context == (void *) subscribe_success_context); - RAY_CHECK(object_id == subscribe_id); - RAY_CHECK(manager_vector.size() == 1); - subscribe_success_succeeded = 1; -} - -TEST subscribe_success_test(void) { - g_loop = event_loop_create(); - - /* Construct the arguments to db_connect. 
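-   * (db_connect takes any extra arguments as flat key/value pairs; here
-   * the key "manager_address" is followed by its value.)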
*/ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11236"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, false); - subscribe_id = ObjectID::from_random(); - - RetryInfo retry = { - .num_retries = 0, - .timeout = 100, - .fail_callback = subscribe_success_fail_callback, - }; - object_table_subscribe_to_notifications( - db, false, subscribe_success_object_available_callback, - (void *) subscribe_success_context, &retry, - subscribe_success_done_callback, (void *) db); - - ObjectID object_ids[1] = {subscribe_id}; - object_table_request_notifications(db, 1, object_ids, &retry); - - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - - ASSERT(subscribe_success_done); - ASSERT(subscribe_success_succeeded); - PASS(); -} - -/* Test if subscribe succeeds if the object is already present. */ -typedef struct { - const char *teststr; - int64_t data_size; -} subscribe_object_present_context_t; - -const char *subscribe_object_present_str = "subscribe_object_present"; -int subscribe_object_present_succeeded = 0; - -void subscribe_object_present_object_available_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - subscribe_object_present_context_t *ctx = - (subscribe_object_present_context_t *) user_context; - RAY_CHECK(ctx->data_size == data_size); - RAY_CHECK(strcmp(subscribe_object_present_str, ctx->teststr) == 0); - subscribe_object_present_succeeded = 1; - RAY_CHECK(manager_vector.size() == 1); -} - -void fatal_fail_callback(UniqueID id, void *user_context, void *user_data) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST subscribe_object_present_test(void) { - int64_t data_size = 0xF1F0; - subscribe_object_present_context_t myctx = {subscribe_object_present_str, - data_size}; - - g_loop = event_loop_create(); - /* Construct the arguments to db_connect. */ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11236"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = 0, .timeout = 100, .fail_callback = fatal_fail_callback, - }; - object_table_add(db, id, data_size, (unsigned char *) NIL_DIGEST, &retry, - NULL, NULL); - object_table_subscribe_to_notifications( - db, false, subscribe_object_present_object_available_callback, - (void *) &myctx, &retry, NULL, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to create do the add and subscribe. */ - event_loop_run(g_loop); - - ObjectID object_ids[1] = {id}; - object_table_request_notifications(db, 1, object_ids, &retry); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the request notifications. 
*/ - event_loop_run(g_loop); - - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_object_present_succeeded == 1); - PASS(); -} - -/* Test if subscribe is not called if object is not present. */ - -const char *subscribe_object_not_present_context = - "subscribe_object_not_present"; - -void subscribe_object_not_present_object_available_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - /* This should not be called. */ - RAY_CHECK(0); -} - -TEST subscribe_object_not_present_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = 0, .timeout = 100, .fail_callback = NULL, - }; - object_table_subscribe_to_notifications( - db, false, subscribe_object_not_present_object_available_callback, - (void *) subscribe_object_not_present_context, &retry, NULL, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the subscribe. */ - event_loop_run(g_loop); - - ObjectID object_ids[1] = {id}; - object_table_request_notifications(db, 1, object_ids, &retry); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the request notifications. */ - event_loop_run(g_loop); - - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - PASS(); -} - -/* Test if subscribe is called if object becomes available later. */ - -const char *subscribe_object_available_later_context = - "subscribe_object_available_later"; -int subscribe_object_available_later_succeeded = 0; - -void subscribe_object_available_later_object_available_callback( - ObjectID object_id, - int64_t data_size, - const std::vector &manager_vector, - void *user_context) { - subscribe_object_present_context_t *myctx = - (subscribe_object_present_context_t *) user_context; - RAY_CHECK(myctx->data_size == data_size); - RAY_CHECK(strcmp(myctx->teststr, subscribe_object_available_later_context) == - 0); - /* Make sure the callback is only called once. */ - subscribe_object_available_later_succeeded += 1; - RAY_CHECK(manager_vector.size() == 1); -} - -TEST subscribe_object_available_later_test(void) { - int64_t data_size = 0xF1F0; - subscribe_object_present_context_t *myctx = - (subscribe_object_present_context_t *) malloc( - sizeof(subscribe_object_present_context_t)); - myctx->teststr = subscribe_object_available_later_context; - myctx->data_size = data_size; - - g_loop = event_loop_create(); - /* Construct the arguments to db_connect. 
*/ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11236"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = 0, .timeout = 100, .fail_callback = NULL, - }; - object_table_subscribe_to_notifications( - db, false, subscribe_object_available_later_object_available_callback, - (void *) myctx, &retry, NULL, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the subscribe. */ - event_loop_run(g_loop); - - ObjectID object_ids[1] = {id}; - object_table_request_notifications(db, 1, object_ids, &retry); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the request notifications. */ - event_loop_run(g_loop); - - ASSERT_EQ(subscribe_object_available_later_succeeded, 0); - object_table_add(db, id, data_size, (unsigned char *) NIL_DIGEST, &retry, - NULL, NULL); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the object table add. */ - event_loop_run(g_loop); - - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT_EQ(subscribe_object_available_later_succeeded, 1); - /* Reset the global variable before exiting this unit test. */ - subscribe_object_available_later_succeeded = 0; - free(myctx); - PASS(); -} - -TEST subscribe_object_available_subscribe_all(void) { - int64_t data_size = 0xF1F0; - subscribe_object_present_context_t myctx = { - subscribe_object_available_later_context, data_size}; - g_loop = event_loop_create(); - /* Construct the arguments to db_connect. */ - std::vector db_connect_args; - db_connect_args.push_back("manager_address"); - db_connect_args.push_back("127.0.0.1:11236"); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", db_connect_args); - db_attach(db, g_loop, false); - UniqueID id = UniqueID::from_random(); - RetryInfo retry = { - .num_retries = 0, .timeout = 100, .fail_callback = NULL, - }; - object_table_subscribe_to_notifications( - db, true, subscribe_object_available_later_object_available_callback, - (void *) &myctx, &retry, NULL, (void *) db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the subscribe. */ - event_loop_run(g_loop); - - /* At this point we don't expect any object notifications received. */ - ASSERT_EQ(subscribe_object_available_later_succeeded, 0); - object_table_add(db, id, data_size, (unsigned char *) NIL_DIGEST, &retry, - NULL, NULL); - /* Install handler to terminate event loop after 750ms. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* Run the event loop to do the object table add. */ - event_loop_run(g_loop); - /* At this point we assume that object table add completed. 
*/ - - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - /* Assert that the object table add completed and notification callback fired. - */ - printf("subscribe_all object info test: callback fired: %d times\n", - subscribe_object_available_later_succeeded); - fflush(stdout); - ASSERT_EQ(subscribe_object_available_later_succeeded, 1); - /* Reset the global variable before exiting this unit test. */ - subscribe_object_available_later_succeeded = 0; - PASS(); -} - -SUITE(object_table_tests) { - RUN_REDIS_TEST(new_object_test); - RUN_REDIS_TEST(new_object_no_task_test); - // RUN_REDIS_TEST(lookup_timeout_test); - // RUN_REDIS_TEST(add_timeout_test); - // RUN_REDIS_TEST(subscribe_timeout_test); - RUN_REDIS_TEST(add_lookup_test); - RUN_REDIS_TEST(add_remove_lookup_test); - // RUN_REDIS_TEST(lookup_late_test); - // RUN_REDIS_TEST(add_late_test); - // RUN_REDIS_TEST(subscribe_late_test); - RUN_REDIS_TEST(subscribe_success_test); - RUN_REDIS_TEST(subscribe_object_not_present_test); - RUN_REDIS_TEST(subscribe_object_available_later_test); - RUN_REDIS_TEST(subscribe_object_available_subscribe_all); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - g_task_builder = make_task_builder(); - GREATEST_MAIN_BEGIN(); - RUN_SUITE(object_table_tests); - GREATEST_MAIN_END(); -} diff --git a/src/common/test/redis_tests.cc b/src/common/test/redis_tests.cc deleted file mode 100644 index 7db7ae2ee26e..000000000000 --- a/src/common/test/redis_tests.cc +++ /dev/null @@ -1,238 +0,0 @@ -#include "greatest.h" - -#include -#include - -#include - -#include "event_loop.h" -#include "state/db.h" -#include "state/redis.h" -#include "io.h" -#include "logging.h" -#include "test_common.h" - -SUITE(redis_tests); - -const char *test_set_format = "SET %s %s"; -const char *test_get_format = "GET %s"; -const char *test_key = "foo"; -const char *test_value = "bar"; -std::vector connections; - -void write_formatted_log_message(int socket_fd, const char *format, ...) 
{ - va_list ap; - - /* Get cmd size */ - va_start(ap, format); - size_t cmd_size = vsnprintf(nullptr, 0, format, ap) + 1; - va_end(ap); - - /* Print va to cmd */ - char cmd[cmd_size]; - va_start(ap, format); - vsnprintf(cmd, cmd_size, format, ap); - va_end(ap); - - write_log_message(socket_fd, cmd); -} - -int async_redis_socket_test_callback_called = 0; - -void async_redis_socket_test_callback(redisAsyncContext *ac, - void *r, - void *privdata) { - async_redis_socket_test_callback_called = 1; - redisContext *context = redisConnect("127.0.0.1", 6379); - redisReply *reply = - (redisReply *) redisCommand(context, test_get_format, test_key); - redisFree(context); - RAY_CHECK(reply != NULL); - if (strcmp(reply->str, test_value)) { - freeReplyObject(reply); - RAY_CHECK(0); - } - freeReplyObject(reply); -} - -TEST redis_socket_test(void) { - const char *socket_pathname = "/tmp/redis-test-socket"; - redisContext *context = redisConnect("127.0.0.1", 6379); - ASSERT(context != NULL); - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - - int client_fd = connect_ipc_sock(socket_pathname); - ASSERT(client_fd >= 0); - write_formatted_log_message(client_fd, test_set_format, test_key, test_value); - - int server_fd = accept_client(socket_fd); - char *cmd = read_log_message(server_fd); - close(client_fd); - close(server_fd); - close(socket_fd); - unlink(socket_pathname); - - redisReply *reply = (redisReply *) redisCommand(context, cmd, 0, 0); - freeReplyObject(reply); - reply = (redisReply *) redisCommand(context, "GET %s", test_key); - ASSERT(reply != NULL); - ASSERT_STR_EQ(reply->str, test_value); - freeReplyObject(reply); - - free(cmd); - redisFree(context); - PASS(); -} - -void redis_read_callback(event_loop *loop, int fd, void *context, int events) { - DBHandle *db = (DBHandle *) context; - char *cmd = read_log_message(fd); - redisAsyncCommand(db->context, async_redis_socket_test_callback, NULL, cmd); - free(cmd); -} - -void redis_accept_callback(event_loop *loop, - int socket_fd, - void *context, - int events) { - int accept_fd = accept_client(socket_fd); - RAY_CHECK(accept_fd >= 0); - connections.push_back(accept_fd); - event_loop_add_file(loop, accept_fd, EVENT_LOOP_READ, redis_read_callback, - context); -} - -int timeout_handler(event_loop *loop, timer_id timer_id, void *context) { - event_loop_stop(loop); - return EVENT_LOOP_TIMER_DONE; -} - -TEST async_redis_socket_test(void) { - event_loop *loop = event_loop_create(); - - /* Start IPC channel. */ - const char *socket_pathname = "/tmp/async-redis-test-socket"; - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - connections.push_back(socket_fd); - - /* Start connection to Redis. */ - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "test_process", - "127.0.0.1", std::vector()); - db_attach(db, loop, false); - - /* Send a command to the Redis process. 
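-   * (The SET command travels over the IPC socket as a log message;
-   * redis_read_callback picks it up and forwards it to Redis via
-   * redisAsyncCommand, which triggers the test callback.)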
*/ - int client_fd = connect_ipc_sock(socket_pathname); - ASSERT(client_fd >= 0); - connections.push_back(client_fd); - write_formatted_log_message(client_fd, test_set_format, test_key, test_value); - - event_loop_add_file(loop, client_fd, EVENT_LOOP_READ, redis_read_callback, - db); - event_loop_add_file(loop, socket_fd, EVENT_LOOP_READ, redis_accept_callback, - db); - event_loop_add_timer(loop, 100, timeout_handler, NULL); - event_loop_run(loop); - - ASSERT(async_redis_socket_test_callback_called); - - db_disconnect(db); - event_loop_destroy(loop); - - for (int const &p : connections) { - close(p); - } - unlink(socket_pathname); - connections.clear(); - PASS(); -} - -int logging_test_callback_called = 0; - -void logging_test_callback(redisAsyncContext *ac, void *r, void *privdata) { - logging_test_callback_called = 1; - redisContext *context = redisConnect("127.0.0.1", 6379); - redisReply *reply = (redisReply *) redisCommand(context, "KEYS %s", "log:*"); - redisFree(context); - RAY_CHECK(reply != NULL); - RAY_CHECK(reply->elements > 0); - freeReplyObject(reply); -} - -void logging_read_callback(event_loop *loop, - int fd, - void *context, - int events) { - DBHandle *conn = (DBHandle *) context; - char *cmd = read_log_message(fd); - redisAsyncCommand(conn->context, logging_test_callback, NULL, cmd, - (char *) conn->client.data(), sizeof(conn->client)); - free(cmd); -} - -void logging_accept_callback(event_loop *loop, - int socket_fd, - void *context, - int events) { - int accept_fd = accept_client(socket_fd); - RAY_CHECK(accept_fd >= 0); - connections.push_back(accept_fd); - event_loop_add_file(loop, accept_fd, EVENT_LOOP_READ, logging_read_callback, - context); -} - -TEST logging_test(void) { - event_loop *loop = event_loop_create(); - - /* Start IPC channel. */ - const char *socket_pathname = "/tmp/logging-test-socket"; - int socket_fd = bind_ipc_sock(socket_pathname, true); - ASSERT(socket_fd >= 0); - connections.push_back(socket_fd); - - /* Start connection to Redis. */ - DBHandle *conn = db_connect(std::string("127.0.0.1"), 6379, "test_process", - "127.0.0.1", std::vector()); - db_attach(conn, loop, false); - - /* Send a command to the Redis process. 
*/ - int client_fd = connect_ipc_sock(socket_pathname); - ASSERT(client_fd >= 0); - connections.push_back(client_fd); - RayLogger *logger = RayLogger_init("worker", RAY_LOG_INFO, 0, &client_fd); - RayLogger_log(logger, RAY_LOG_INFO, "TEST", "Message"); - - event_loop_add_file(loop, socket_fd, EVENT_LOOP_READ, logging_accept_callback, - conn); - event_loop_add_file(loop, client_fd, EVENT_LOOP_READ, logging_read_callback, - conn); - event_loop_add_timer(loop, 100, timeout_handler, NULL); - event_loop_run(loop); - - ASSERT(logging_test_callback_called); - - RayLogger_free(logger); - db_disconnect(conn); - event_loop_destroy(loop); - for (int const &p : connections) { - close(p); - } - unlink(socket_pathname); - connections.clear(); - PASS(); -} - -SUITE(redis_tests) { - RUN_REDIS_TEST(redis_socket_test); - RUN_REDIS_TEST(async_redis_socket_test); - RUN_REDIS_TEST(logging_test); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - GREATEST_MAIN_BEGIN(); - RUN_SUITE(redis_tests); - GREATEST_MAIN_END(); -} diff --git a/src/common/test/run_tests.sh b/src/common/test/run_tests.sh deleted file mode 100644 index 5ccb1e3f92ff..000000000000 --- a/src/common/test/run_tests.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env bash - -# This needs to be run in the build tree, which is normally ray/build - -# Cause the script to exit if a single command fails. -set -ex - -LaunchRedis() { - port=$1 - if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then - ./src/credis/redis/src/redis-server \ - --loglevel warning \ - --loadmodule ./src/credis/build/src/libmember.so \ - --loadmodule ./src/common/redis_module/libray_redis_module.so \ - --port $port & - else - ./src/common/thirdparty/redis/src/redis-server \ - --loglevel warning \ - --loadmodule ./src/common/redis_module/libray_redis_module.so \ - --port $port & - fi - sleep 1s -} - - -# Start the Redis shards. -LaunchRedis 6379 -LaunchRedis 6380 -# Register the shard location with the primary shard. -./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1 -./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380 - -if [ -z "$RAY_USE_NEW_GCS" ]; then - ./src/common/db_tests - ./src/common/io_tests - ./src/common/task_tests - ./src/common/redis_tests - ./src/common/task_table_tests - ./src/common/object_table_tests -fi - -./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown -./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown diff --git a/src/common/test/run_valgrind.sh b/src/common/test/run_valgrind.sh deleted file mode 100644 index 418a91366e13..000000000000 --- a/src/common/test/run_valgrind.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env bash - -# This needs to be run in the build tree, which is normally ray/build - -set -x - -# Cause the script to exit if a single command fails. -set -e - -if [ -z "$RAY_USE_NEW_GCS" ]; then - # Start the Redis shards. - ./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 & - ./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6380 & - sleep 1s - # Register the shard location with the primary shard. 
- ./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1 - ./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380 - - valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/db_tests - valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/io_tests - valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/task_tests - valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/redis_tests - valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/task_table_tests - valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/common/object_table_tests - ./src/common/thirdparty/redis/src/redis-cli shutdown - ./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown -fi diff --git a/src/common/test/task_table_tests.cc b/src/common/test/task_table_tests.cc deleted file mode 100644 index f94aca3b132c..000000000000 --- a/src/common/test/task_table_tests.cc +++ /dev/null @@ -1,460 +0,0 @@ -#include "greatest.h" - -#include "event_loop.h" -#include "example_task.h" -#include "test_common.h" -#include "common.h" -#include "state/object_table.h" -#include "state/redis.h" - -#include -#include - -SUITE(task_table_tests); - -event_loop *g_loop; -TaskBuilder *g_task_builder = NULL; - -/* ==== Test operations in non-failure scenario ==== */ - -/* === A lookup of a task not in the table === */ - -TaskID lookup_nil_id; -int lookup_nil_success = 0; -const char *lookup_nil_context = "lookup_nil"; - -void lookup_nil_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* The fail callback should not be called. */ - RAY_CHECK(0); -} - -void lookup_nil_success_callback(Task *task, void *context) { - lookup_nil_success = 1; - RAY_CHECK(task == NULL); - RAY_CHECK(context == (void *) lookup_nil_context); - event_loop_stop(g_loop); -} - -TEST lookup_nil_test(void) { - lookup_nil_id = TaskID::from_random(); - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 1000, - .fail_callback = lookup_nil_fail_callback, - }; - task_table_get_task(db, lookup_nil_id, &retry, lookup_nil_success_callback, - (void *) lookup_nil_context); - /* Disconnect the database to see if the lookup times out. */ - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(lookup_nil_success); - PASS(); -} - -/* === A lookup of a task after it's added returns the same spec === */ - -int add_success = 0; -int lookup_success = 0; -Task *add_lookup_task; -const char *add_lookup_context = "add_lookup"; - -void add_lookup_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* The fail callback should not be called. 
*/ - RAY_CHECK(0); -} - -void lookup_success_callback(Task *task, void *context) { - lookup_success = 1; - RAY_CHECK(Task_equals(task, add_lookup_task)); - event_loop_stop(g_loop); -} - -void add_success_callback(TaskID task_id, void *context) { - add_success = 1; - RAY_CHECK(TaskID_equal(task_id, Task_task_id(add_lookup_task))); - - DBHandle *db = (DBHandle *) context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 1000, - .fail_callback = add_lookup_fail_callback, - }; - task_table_get_task(db, task_id, &retry, lookup_success_callback, - (void *) add_lookup_context); -} - -void subscribe_success_callback(TaskID task_id, void *context) { - DBHandle *db = (DBHandle *) context; - RetryInfo retry = { - .num_retries = 5, - .timeout = 1000, - .fail_callback = add_lookup_fail_callback, - }; - task_table_add_task(db, Task_copy(add_lookup_task), &retry, - add_success_callback, (void *) db); -} - -TEST add_lookup_test(void) { - add_lookup_task = example_task(1, 1, TaskStatus::WAITING); - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 1000, - .fail_callback = add_lookup_fail_callback, - }; - /* Wait for subscription to succeed before adding the task. */ - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, subscribe_success_callback, (void *) db); - /* Disconnect the database to see if the lookup times out. */ - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(add_success); - ASSERT(lookup_success); - PASS(); -} - -/* ==== Test if operations time out correctly ==== */ - -/* === Test subscribe timeout === */ - -const char *subscribe_timeout_context = "subscribe_timeout"; -int subscribe_failed = 0; - -void subscribe_done_callback(TaskID task_id, void *user_context) { - /* The done callback should not be called. */ - RAY_CHECK(0); -} - -void subscribe_fail_callback(UniqueID id, void *user_context, void *user_data) { - subscribe_failed = 1; - RAY_CHECK(user_context == (void *) subscribe_timeout_context); - event_loop_stop(g_loop); -} - -TEST subscribe_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = subscribe_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, subscribe_done_callback, - (void *) subscribe_timeout_context); - /* Disconnect the database to see if the subscribe times out. */ - close(db->subscribe_context->c.fd); - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - close(db->subscribe_contexts[i]->c.fd); - } - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_failed); - PASS(); -} - -/* === Test publish timeout === */ - -const char *publish_timeout_context = "publish_timeout"; -int publish_failed = 0; - -void publish_done_callback(TaskID task_id, void *user_context) { - /* The done callback should not be called. 
*/ - RAY_CHECK(0); -} - -void publish_fail_callback(UniqueID id, void *user_context, void *user_data) { - publish_failed = 1; - RAY_CHECK(user_context == (void *) publish_timeout_context); - event_loop_stop(g_loop); -} - -TEST publish_timeout_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - Task *task = example_task(1, 1, TaskStatus::WAITING); - RetryInfo retry = { - .num_retries = 5, .timeout = 100, .fail_callback = publish_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, NULL, NULL); - task_table_add_task(db, task, &retry, publish_done_callback, - (void *) publish_timeout_context); - /* Disconnect the database to see if the publish times out. */ - close(db->context->c.fd); - for (size_t i = 0; i < db->contexts.size(); ++i) { - close(db->contexts[i]->c.fd); - } - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(publish_failed); - PASS(); -} - -/* ==== Test if the retry is working correctly ==== */ - -int64_t reconnect_db_callback(event_loop *loop, - int64_t timer_id, - void *context) { - DBHandle *db = (DBHandle *) context; - /* Reconnect to redis. */ - redisAsyncFree(db->subscribe_context); - db->subscribe_context = redisAsyncConnect("127.0.0.1", 6379); - db->subscribe_context->data = (void *) db; - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - redisAsyncFree(db->subscribe_contexts[i]); - db->subscribe_contexts[i] = redisAsyncConnect("127.0.0.1", 6380 + i); - db->subscribe_contexts[i]->data = (void *) db; - } - /* Re-attach the database to the event loop (the file descriptor changed). */ - db_attach(db, loop, true); - return EVENT_LOOP_TIMER_DONE; -} - -int64_t terminate_event_loop_callback(event_loop *loop, - int64_t timer_id, - void *context) { - event_loop_stop(loop); - return EVENT_LOOP_TIMER_DONE; -} - -/* === Test subscribe retry === */ - -const char *subscribe_retry_context = "subscribe_retry"; -int subscribe_retry_succeeded = 0; - -void subscribe_retry_done_callback(ObjectID object_id, void *user_context) { - RAY_CHECK(user_context == (void *) subscribe_retry_context); - subscribe_retry_succeeded = 1; -} - -void subscribe_retry_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* The fail callback should not be called. */ - RAY_CHECK(0); -} - -TEST subscribe_retry_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = subscribe_retry_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, subscribe_retry_done_callback, - (void *) subscribe_retry_context); - /* Disconnect the database to see if the subscribe times out. */ - close(db->subscribe_context->c.fd); - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - close(db->subscribe_contexts[i]->c.fd); - } - /* Install handler for reconnecting the database. */ - event_loop_add_timer(g_loop, 150, - (event_loop_timer_handler) reconnect_db_callback, db); - /* Install handler for terminating the event loop. 
*/ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_retry_succeeded); - PASS(); -} - -/* === Test publish retry === */ - -const char *publish_retry_context = "publish_retry"; -int publish_retry_succeeded = 0; - -void publish_retry_done_callback(ObjectID object_id, void *user_context) { - RAY_CHECK(user_context == (void *) publish_retry_context); - publish_retry_succeeded = 1; -} - -void publish_retry_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - /* The fail callback should not be called. */ - RAY_CHECK(0); -} - -TEST publish_retry_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - Task *task = example_task(1, 1, TaskStatus::WAITING); - RetryInfo retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = publish_retry_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, NULL, NULL); - task_table_add_task(db, task, &retry, publish_retry_done_callback, - (void *) publish_retry_context); - /* Disconnect the database to see if the publish times out. */ - close(db->subscribe_context->c.fd); - for (size_t i = 0; i < db->subscribe_contexts.size(); ++i) { - close(db->subscribe_contexts[i]->c.fd); - } - /* Install handler for reconnecting the database. */ - event_loop_add_timer(g_loop, 150, - (event_loop_timer_handler) reconnect_db_callback, db); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(publish_retry_succeeded); - PASS(); -} - -/* ==== Test if late succeed is working correctly ==== */ - -/* === Test subscribe late succeed === */ - -const char *subscribe_late_context = "subscribe_late"; -int subscribe_late_failed = 0; - -void subscribe_late_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - RAY_CHECK(user_context == (void *) subscribe_late_context); - subscribe_late_failed = 1; -} - -void subscribe_late_done_callback(TaskID task_id, void *user_context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST subscribe_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - RetryInfo retry = { - .num_retries = 0, - .timeout = 0, - .fail_callback = subscribe_late_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - &retry, subscribe_late_done_callback, - (void *) subscribe_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. 
*/ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(subscribe_late_failed); - PASS(); -} - -/* === Test publish late succeed === */ - -const char *publish_late_context = "publish_late"; -int publish_late_failed = 0; - -void publish_late_fail_callback(UniqueID id, - void *user_context, - void *user_data) { - RAY_CHECK(user_context == (void *) publish_late_context); - publish_late_failed = 1; -} - -void publish_late_done_callback(TaskID task_id, void *user_context) { - /* This function should never be called. */ - RAY_CHECK(0); -} - -TEST publish_late_test(void) { - g_loop = event_loop_create(); - DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager", - "127.0.0.1", std::vector()); - db_attach(db, g_loop, false); - Task *task = example_task(1, 1, TaskStatus::WAITING); - RetryInfo retry = { - .num_retries = 0, - .timeout = 0, - .fail_callback = publish_late_fail_callback, - }; - task_table_subscribe(db, UniqueID::nil(), TaskStatus::WAITING, NULL, NULL, - NULL, NULL, NULL); - task_table_add_task(db, task, &retry, publish_late_done_callback, - (void *) publish_late_context); - /* Install handler for terminating the event loop. */ - event_loop_add_timer(g_loop, 750, - (event_loop_timer_handler) terminate_event_loop_callback, - NULL); - /* First process timer events to make sure the timeout is processed before - * anything else. */ - aeProcessEvents(g_loop, AE_TIME_EVENTS); - event_loop_run(g_loop); - db_disconnect(db); - destroy_outstanding_callbacks(g_loop); - event_loop_destroy(g_loop); - ASSERT(publish_late_failed); - PASS(); -} - -SUITE(task_table_tests) { - RUN_REDIS_TEST(lookup_nil_test); - RUN_REDIS_TEST(add_lookup_test); - // RUN_REDIS_TEST(subscribe_timeout_test); - // RUN_REDIS_TEST(publish_timeout_test); - // RUN_REDIS_TEST(subscribe_retry_test); - // RUN_REDIS_TEST(publish_retry_test); - // RUN_REDIS_TEST(subscribe_late_test); - // RUN_REDIS_TEST(publish_late_test); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - g_task_builder = make_task_builder(); - GREATEST_MAIN_BEGIN(); - RUN_SUITE(task_table_tests); - GREATEST_MAIN_END(); -} diff --git a/src/common/test/task_tests.cc b/src/common/test/task_tests.cc deleted file mode 100644 index 2277912e7dec..000000000000 --- a/src/common/test/task_tests.cc +++ /dev/null @@ -1,212 +0,0 @@ -#include "greatest.h" - -#include -#include -#include - -#include "common.h" -#include "test_common.h" -#include "task.h" -#include "io.h" - -SUITE(task_tests); - -TEST task_test(void) { - TaskID parent_task_id = TaskID::from_random(); - FunctionID func_id = FunctionID::from_random(); - TaskBuilder *builder = make_task_builder(); - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 2); - - UniqueID arg1 = UniqueID::from_random(); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, (uint8_t *) "hello", 5); - UniqueID arg2 = UniqueID::from_random(); - TaskSpec_args_add_ref(builder, &arg2, 1); - TaskSpec_args_add_val(builder, (uint8_t *) "world", 5); - /* Finish constructing the spec. This constructs the task ID and the - * return IDs. */ - int64_t size; - TaskSpec *spec = TaskSpec_finish_construct(builder, &size); - - /* Check that the spec was constructed as expected. 
*/ - ASSERT(TaskSpec_num_args(spec) == 4); - ASSERT(TaskSpec_num_returns(spec) == 2); - ASSERT(FunctionID_equal(TaskSpec_function(spec), func_id)); - ASSERT(TaskSpec_arg_id(spec, 0, 0) == arg1); - ASSERT(memcmp(TaskSpec_arg_val(spec, 1), (uint8_t *) "hello", - TaskSpec_arg_length(spec, 1)) == 0); - ASSERT(TaskSpec_arg_id(spec, 2, 0) == arg2); - ASSERT(memcmp(TaskSpec_arg_val(spec, 3), (uint8_t *) "world", - TaskSpec_arg_length(spec, 3)) == 0); - - TaskSpec_free(spec); - free_task_builder(builder); - PASS(); -} - -TEST deterministic_ids_test(void) { - TaskBuilder *builder = make_task_builder(); - /* Define the inputs to the task construction. */ - TaskID parent_task_id = TaskID::from_random(); - FunctionID func_id = FunctionID::from_random(); - UniqueID arg1 = UniqueID::from_random(); - uint8_t *arg2 = (uint8_t *) "hello world"; - - /* Construct a first task. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size1; - TaskSpec *spec1 = TaskSpec_finish_construct(builder, &size1); - - /* Construct a second identical task. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size2; - TaskSpec *spec2 = TaskSpec_finish_construct(builder, &size2); - - /* Check that these tasks have the same task IDs and the same return IDs. */ - ASSERT(TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec2))); - ASSERT(TaskSpec_return(spec1, 0) == TaskSpec_return(spec2, 0)); - ASSERT(TaskSpec_return(spec1, 1) == TaskSpec_return(spec2, 1)); - ASSERT(TaskSpec_return(spec1, 2) == TaskSpec_return(spec2, 2)); - /* Check that the return IDs are all distinct. */ - ASSERT(!(TaskSpec_return(spec1, 0) == TaskSpec_return(spec2, 1))); - ASSERT(!(TaskSpec_return(spec1, 0) == TaskSpec_return(spec2, 2))); - ASSERT(!(TaskSpec_return(spec1, 1) == TaskSpec_return(spec2, 2))); - - /* Create more tasks that are only mildly different. */ - - /* Construct a task with a different parent task ID. */ - TaskSpec_start_construct(builder, DriverID::nil(), TaskID::from_random(), 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size3; - TaskSpec *spec3 = TaskSpec_finish_construct(builder, &size3); - - /* Construct a task with a different parent counter. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 1, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size4; - TaskSpec *spec4 = TaskSpec_finish_construct(builder, &size4); - - /* Construct a task with a different function ID. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, FunctionID::from_random(), - 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size5; - TaskSpec *spec5 = TaskSpec_finish_construct(builder, &size5); - - /* Construct a task with a different object ID argument. 
*/ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - ObjectID object_id = ObjectID::from_random(); - TaskSpec_args_add_ref(builder, &object_id, 1); - TaskSpec_args_add_val(builder, arg2, 11); - int64_t size6; - TaskSpec *spec6 = TaskSpec_finish_construct(builder, &size6); - - /* Construct a task with a different value argument. */ - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 3); - TaskSpec_args_add_ref(builder, &arg1, 1); - TaskSpec_args_add_val(builder, (uint8_t *) "hello_world", 11); - int64_t size7; - TaskSpec *spec7 = TaskSpec_finish_construct(builder, &size7); - - /* Check that the task IDs are all distinct from the original. */ - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec3))); - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec4))); - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec5))); - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec6))); - ASSERT(!TaskID_equal(TaskSpec_task_id(spec1), TaskSpec_task_id(spec7))); - - /* Check that the return object IDs are distinct from the originals. */ - TaskSpec *specs[6] = {spec1, spec3, spec4, spec5, spec6, spec7}; - for (int task_index1 = 0; task_index1 < 6; ++task_index1) { - for (int return_index1 = 0; return_index1 < 3; ++return_index1) { - for (int task_index2 = 0; task_index2 < 6; ++task_index2) { - for (int return_index2 = 0; return_index2 < 3; ++return_index2) { - if (task_index1 != task_index2 && return_index1 != return_index2) { - ASSERT(!(TaskSpec_return(specs[task_index1], return_index1) == - TaskSpec_return(specs[task_index2], return_index2))); - } - } - } - } - } - - TaskSpec_free(spec1); - TaskSpec_free(spec2); - TaskSpec_free(spec3); - TaskSpec_free(spec4); - TaskSpec_free(spec5); - TaskSpec_free(spec6); - TaskSpec_free(spec7); - free_task_builder(builder); - PASS(); -} - -TEST send_task(void) { - TaskBuilder *builder = make_task_builder(); - TaskID parent_task_id = TaskID::from_random(); - FunctionID func_id = FunctionID::from_random(); - TaskSpec_start_construct(builder, DriverID::nil(), parent_task_id, 0, - ActorID::nil(), ObjectID::nil(), ActorID::nil(), - ActorID::nil(), 0, false, func_id, 2); - ObjectID object_id = ObjectID::from_random(); - TaskSpec_args_add_ref(builder, &object_id, 1); - TaskSpec_args_add_val(builder, (uint8_t *) "Hello", 5); - TaskSpec_args_add_val(builder, (uint8_t *) "World", 5); - object_id = ObjectID::from_random(); - TaskSpec_args_add_ref(builder, &object_id, 1); - int64_t size; - TaskSpec *spec = TaskSpec_finish_construct(builder, &size); - int fd[2]; - socketpair(AF_UNIX, SOCK_STREAM, 0, fd); - write_message(fd[0], static_cast(CommonMessageType::SUBMIT_TASK), - size, (uint8_t *) spec); - int64_t type; - int64_t length; - uint8_t *message; - read_message(fd[1], &type, &length, &message); - TaskSpec *result = (TaskSpec *) message; - ASSERT(static_cast(type) == - CommonMessageType::SUBMIT_TASK); - ASSERT(memcmp(spec, result, size) == 0); - TaskSpec_free(spec); - free(result); - free_task_builder(builder); - PASS(); -} - -SUITE(task_tests) { - RUN_TEST(task_test); - RUN_TEST(deterministic_ids_test); - RUN_TEST(send_task); -} - -GREATEST_MAIN_DEFS(); - -int main(int argc, char **argv) { - GREATEST_MAIN_BEGIN(); - RUN_SUITE(task_tests); - GREATEST_MAIN_END(); -} diff --git 
a/src/common/test/test_common.h b/src/common/test/test_common.h deleted file mode 100644 index 03984e6f2249..000000000000 --- a/src/common/test/test_common.h +++ /dev/null @@ -1,91 +0,0 @@ -#ifndef TEST_COMMON_H -#define TEST_COMMON_H - -#include - -#include -#include -#include - -#include "common.h" -#include "io.h" -#include "hiredis/hiredis.h" -#include "state/redis.h" - -#ifndef _WIN32 -/* This function is actually not declared in standard POSIX, so declare it. */ -extern int usleep(useconds_t usec); -#endif - -/* I/O helper methods to retry binding to sockets. */ -static inline std::string bind_ipc_sock_retry(const char *socket_name_format, - int *fd) { - std::string socket_name; - for (int num_retries = 0; num_retries < 5; ++num_retries) { - RAY_LOG(INFO) << "trying to find plasma socket (attempt " << num_retries - << ")"; - size_t size = std::snprintf(nullptr, 0, socket_name_format, rand()) + 1; - char socket_name_c_str[size]; - std::snprintf(socket_name_c_str, size, socket_name_format, rand()); - socket_name = std::string(socket_name_c_str); - - *fd = bind_ipc_sock(socket_name.c_str(), true); - if (*fd < 0) { - /* Sleep for 100ms. */ - usleep(100000); - continue; - } - break; - } - return socket_name; -} - -static inline int bind_inet_sock_retry(int *fd) { - int port = -1; - for (int num_retries = 0; num_retries < 5; ++num_retries) { - port = 10000 + rand() % 40000; - *fd = bind_inet_sock(port, true); - if (*fd < 0) { - /* Sleep for 100ms. */ - usleep(100000); - continue; - } - break; - } - return port; -} - -/* Flush redis. */ -static inline void flushall_redis(void) { - /* Flush the primary shard. */ - redisContext *context = redisConnect("127.0.0.1", 6379); - std::vector db_shards_addresses; - std::vector db_shards_ports; - get_redis_shards(context, db_shards_addresses, db_shards_ports); - freeReplyObject(redisCommand(context, "FLUSHALL")); - /* Readd the shard locations. */ - freeReplyObject(redisCommand(context, "SET NumRedisShards %d", - db_shards_addresses.size())); - for (size_t i = 0; i < db_shards_addresses.size(); ++i) { - freeReplyObject(redisCommand(context, "RPUSH RedisShards %s:%d", - db_shards_addresses[i].c_str(), - db_shards_ports[i])); - } - redisFree(context); - - /* Flush the remaining shards. */ - for (size_t i = 0; i < db_shards_addresses.size(); ++i) { - context = redisConnect(db_shards_addresses[i].c_str(), db_shards_ports[i]); - freeReplyObject(redisCommand(context, "FLUSHALL")); - redisFree(context); - } -} - -/* Cleanup method for running tests with the greatest library. - * Runs the test, then clears the Redis database. 
*/ -#define RUN_REDIS_TEST(test) \ - flushall_redis(); \ - RUN_TEST(test); \ - flushall_redis(); - -#endif /* TEST_COMMON */ diff --git a/src/common/thirdparty/download_thirdparty.bat b/src/common/thirdparty/download_thirdparty.bat deleted file mode 100644 index 988592f83af6..000000000000 --- a/src/common/thirdparty/download_thirdparty.bat +++ /dev/null @@ -1,15 +0,0 @@ -@SetLocal - @Echo Off - @PushD "%~dp0" - git submodule update --init --jobs="%NUMBER_OF_PROCESSORS%" - @If Not Exist "python\.git" git clone "https://github.com/austinsc/python.git" - Call :GitApply "python" "%CD%/patches/windows/python-pyconfig.patch" - Call :GitApply "redis-windows" "%CD%/patches/windows/redis.patch" - @PopD -@EndLocal -@GoTo :EOF - -:GitApply - @REM Check if patch already applied by attempting to apply it in reverse; if not, then force-reapply it - git -C "%~1" apply "%~2" -R --check 2> NUL || git -C "%~1" apply "%~2" --3way 2> NUL || git -C "%~1" reset --hard && git -C "%~1" apply "%~2" --3way -@GoTo :EOF diff --git a/src/common/thirdparty/greatest.h b/src/common/thirdparty/greatest.h deleted file mode 100644 index eb34ff4263ec..000000000000 --- a/src/common/thirdparty/greatest.h +++ /dev/null @@ -1,1023 +0,0 @@ -/* - * Copyright (c) 2011-2016 Scott Vokes - * - * Permission to use, copy, modify, and/or distribute this software for any - * purpose with or without fee is hereby granted, provided that the above - * copyright notice and this permission notice appear in all copies. - * - * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef GREATEST_H -#define GREATEST_H - -#ifdef __cplusplus -extern "C" { -#endif - -/* 1.2.1 */ -#define GREATEST_VERSION_MAJOR 1 -#define GREATEST_VERSION_MINOR 2 -#define GREATEST_VERSION_PATCH 1 - -/* A unit testing system for C, contained in 1 file. - * It doesn't use dynamic allocation or depend on anything - * beyond ANSI C89. - * - * An up-to-date version can be found at: - * https://github.com/silentbicycle/greatest/ - */ - - -/********************************************************************* - * Minimal test runner template - *********************************************************************/ -#if 0 -#include "greatest.h" -TEST foo_should_foo(void) { - PASS(); -} -static void setup_cb(void *data) { - printf("setup callback for each test case\n"); -} -static void teardown_cb(void *data) { - printf("teardown callback for each test case\n"); -} -SUITE(suite) { - /* Optional setup/teardown callbacks which will be run before/after - * every test case. If using a test suite, they will be cleared when - * the suite finishes. */ - SET_SETUP(setup_cb, voidp_to_callback_data); - SET_TEARDOWN(teardown_cb, voidp_to_callback_data); - RUN_TEST(foo_should_foo); -} -/* Add definitions that need to be in the test runner's main file. */ -GREATEST_MAIN_DEFS(); -/* Set up, run suite(s) of tests, report pass/fail/skip stats. */ -int run_tests(void) { - GREATEST_INIT(); /* init. greatest internals */ - /* List of suites to run (if any). 
*/ - RUN_SUITE(suite); - /* Tests can also be run directly, without using test suites. */ - RUN_TEST(foo_should_foo); - GREATEST_PRINT_REPORT(); /* display results */ - return greatest_all_passed(); -} -/* main(), for a standalone command-line test runner. - * This replaces run_tests above, and adds command line option - * handling and exiting with a pass/fail status. */ -int main(int argc, char **argv) { - GREATEST_MAIN_BEGIN(); /* init & parse command-line args */ - RUN_SUITE(suite); - GREATEST_MAIN_END(); /* display results */ -} -#endif -/*********************************************************************/ - - -#include -#include -#include -#include - -/*********** - * Options * - ***********/ - -/* Default column width for non-verbose output. */ -#ifndef GREATEST_DEFAULT_WIDTH -#define GREATEST_DEFAULT_WIDTH 72 -#endif - -/* FILE *, for test logging. */ -#ifndef GREATEST_STDOUT -#define GREATEST_STDOUT stdout -#endif - -/* Remove GREATEST_ prefix from most commonly used symbols? */ -#ifndef GREATEST_USE_ABBREVS -#define GREATEST_USE_ABBREVS 1 -#endif - -/* Set to 0 to disable all use of setjmp/longjmp. */ -#ifndef GREATEST_USE_LONGJMP -#define GREATEST_USE_LONGJMP 1 -#endif - -#if GREATEST_USE_LONGJMP -#include -#endif - -/* Set to 0 to disable all use of time.h / clock(). */ -#ifndef GREATEST_USE_TIME -#define GREATEST_USE_TIME 1 -#endif - -#if GREATEST_USE_TIME -#include -#endif - -/* Floating point type, for ASSERT_IN_RANGE. */ -#ifndef GREATEST_FLOAT -#define GREATEST_FLOAT double -#define GREATEST_FLOAT_FMT "%g" -#endif - -/********* - * Types * - *********/ - -/* Info for the current running suite. */ -typedef struct greatest_suite_info { - unsigned int tests_run; - unsigned int passed; - unsigned int failed; - unsigned int skipped; - -#if GREATEST_USE_TIME - /* timers, pre/post running suite and individual tests */ - clock_t pre_suite; - clock_t post_suite; - clock_t pre_test; - clock_t post_test; -#endif -} greatest_suite_info; - -/* Type for a suite function. */ -typedef void (greatest_suite_cb)(void); - -/* Types for setup/teardown callbacks. If non-NULL, these will be run - * and passed the pointer to their additional data. */ -typedef void (greatest_setup_cb)(void *udata); -typedef void (greatest_teardown_cb)(void *udata); - -/* Type for an equality comparison between two pointers of the same type. - * Should return non-0 if equal, otherwise 0. - * UDATA is a closure value, passed through from ASSERT_EQUAL_T[m]. */ -typedef int greatest_equal_cb(const void *exp, const void *got, void *udata); - -/* Type for a callback that prints a value pointed to by T. - * Return value has the same meaning as printf's. - * UDATA is a closure value, passed through from ASSERT_EQUAL_T[m]. */ -typedef int greatest_printf_cb(const void *t, void *udata); - -/* Callbacks for an arbitrary type; needed for type-specific - * comparisons via GREATEST_ASSERT_EQUAL_T[m].*/ -typedef struct greatest_type_info { - greatest_equal_cb *equal; - greatest_printf_cb *print; -} greatest_type_info; - -typedef struct greatest_memory_cmp_env { - const unsigned char *exp; - const unsigned char *got; - size_t size; -} greatest_memory_cmp_env; - -/* Callbacks for string and raw memory types. */ -extern greatest_type_info greatest_type_info_string; -extern greatest_type_info greatest_type_info_memory; - -typedef enum { - GREATEST_FLAG_FIRST_FAIL = 0x01, - GREATEST_FLAG_LIST_ONLY = 0x02 -} greatest_flag_t; - -/* Struct containing all test runner state. 
*/ -typedef struct greatest_run_info { - unsigned char flags; - unsigned char verbosity; - unsigned int tests_run; /* total test count */ - - /* overall pass/fail/skip counts */ - unsigned int passed; - unsigned int failed; - unsigned int skipped; - unsigned int assertions; - - /* currently running test suite */ - greatest_suite_info suite; - - /* info to print about the most recent failure */ - const char *fail_file; - unsigned int fail_line; - const char *msg; - - /* current setup/teardown hooks and userdata */ - greatest_setup_cb *setup; - void *setup_udata; - greatest_teardown_cb *teardown; - void *teardown_udata; - - /* formatting info for ".....s...F"-style output */ - unsigned int col; - unsigned int width; - - /* only run a specific suite or test */ - const char *suite_filter; - const char *test_filter; - -#if GREATEST_USE_TIME - /* overall timers */ - clock_t begin; - clock_t end; -#endif - -#if GREATEST_USE_LONGJMP - jmp_buf jump_dest; -#endif -} greatest_run_info; - -struct greatest_report_t { - /* overall pass/fail/skip counts */ - unsigned int passed; - unsigned int failed; - unsigned int skipped; - unsigned int assertions; -}; - -/* Global var for the current testing context. - * Initialized by GREATEST_MAIN_DEFS(). */ -extern greatest_run_info greatest_info; - -/* Type for ASSERT_ENUM_EQ's ENUM_STR argument. */ -typedef const char *greatest_enum_str_fun(int value); - -/********************** - * Exported functions * - **********************/ - -/* These are used internally by greatest. */ -void greatest_do_pass(const char *name); -void greatest_do_fail(const char *name); -void greatest_do_skip(const char *name); -int greatest_pre_test(const char *name); -void greatest_post_test(const char *name, int res); -void greatest_usage(const char *name); -int greatest_do_assert_equal_t(const void *exp, const void *got, - greatest_type_info *type_info, void *udata); - -/* These are part of the public greatest API. */ -void GREATEST_SET_SETUP_CB(greatest_setup_cb *cb, void *udata); -void GREATEST_SET_TEARDOWN_CB(greatest_teardown_cb *cb, void *udata); -int greatest_all_passed(void); -void greatest_set_test_filter(const char *name); -void greatest_set_suite_filter(const char *name); -void greatest_get_report(struct greatest_report_t *report); -unsigned int greatest_get_verbosity(void); -void greatest_set_verbosity(unsigned int verbosity); -void greatest_set_flag(greatest_flag_t flag); - - -/******************** -* Language Support * -********************/ - -/* If __VA_ARGS__ (C99) is supported, allow parametric testing -* without needing to manually manage the argument struct. */ -#if __STDC_VERSION__ >= 19901L || _MSC_VER >= 1800 -#define GREATEST_VA_ARGS -#endif - - -/********** - * Macros * - **********/ - -/* Define a suite. */ -#define GREATEST_SUITE(NAME) void NAME(void); void NAME(void) - -/* Declare a suite, provided by another compilation unit. */ -#define GREATEST_SUITE_EXTERN(NAME) void NAME(void) - -/* Start defining a test function. - * The arguments are not included, to allow parametric testing. */ -#define GREATEST_TEST static enum greatest_test_res - -/* PASS/FAIL/SKIP result from a test. Used internally. */ -typedef enum greatest_test_res { - GREATEST_TEST_RES_PASS = 0, - GREATEST_TEST_RES_FAIL = -1, - GREATEST_TEST_RES_SKIP = 1 -} greatest_test_res; - -/* Run a suite. */ -#define GREATEST_RUN_SUITE(S_NAME) greatest_run_suite(S_NAME, #S_NAME) - -/* Run a test in the current suite. 
*/ -#define GREATEST_RUN_TEST(TEST) \ - do { \ - if (greatest_pre_test(#TEST) == 1) { \ - enum greatest_test_res res = GREATEST_SAVE_CONTEXT(); \ - if (res == GREATEST_TEST_RES_PASS) { \ - res = TEST(); \ - } \ - greatest_post_test(#TEST, res); \ - } else if (GREATEST_LIST_ONLY()) { \ - fprintf(GREATEST_STDOUT, " %s\n", #TEST); \ - } \ - } while (0) - -/* Ignore a test, don't warn about it being unused. */ -#define GREATEST_IGNORE_TEST(TEST) (void)TEST - -/* Run a test in the current suite with one void * argument, - * which can be a pointer to a struct with multiple arguments. */ -#define GREATEST_RUN_TEST1(TEST, ENV) \ - do { \ - if (greatest_pre_test(#TEST) == 1) { \ - int res = TEST(ENV); \ - greatest_post_test(#TEST, res); \ - } else if (GREATEST_LIST_ONLY()) { \ - fprintf(GREATEST_STDOUT, " %s\n", #TEST); \ - } \ - } while (0) - -#ifdef GREATEST_VA_ARGS -#define GREATEST_RUN_TESTp(TEST, ...) \ - do { \ - if (greatest_pre_test(#TEST) == 1) { \ - int res = TEST(__VA_ARGS__); \ - greatest_post_test(#TEST, res); \ - } else if (GREATEST_LIST_ONLY()) { \ - fprintf(GREATEST_STDOUT, " %s\n", #TEST); \ - } \ - } while (0) -#endif - - -/* Check if the test runner is in verbose mode. */ -#define GREATEST_IS_VERBOSE() ((greatest_info.verbosity) > 0) -#define GREATEST_LIST_ONLY() \ - (greatest_info.flags & GREATEST_FLAG_LIST_ONLY) -#define GREATEST_FIRST_FAIL() \ - (greatest_info.flags & GREATEST_FLAG_FIRST_FAIL) -#define GREATEST_FAILURE_ABORT() \ - (greatest_info.suite.failed > 0 && GREATEST_FIRST_FAIL()) - -/* Message-less forms of tests defined below. */ -#define GREATEST_PASS() GREATEST_PASSm(NULL) -#define GREATEST_FAIL() GREATEST_FAILm(NULL) -#define GREATEST_SKIP() GREATEST_SKIPm(NULL) -#define GREATEST_ASSERT(COND) \ - GREATEST_ASSERTm(#COND, COND) -#define GREATEST_ASSERT_OR_LONGJMP(COND) \ - GREATEST_ASSERT_OR_LONGJMPm(#COND, COND) -#define GREATEST_ASSERT_FALSE(COND) \ - GREATEST_ASSERT_FALSEm(#COND, COND) -#define GREATEST_ASSERT_EQ(EXP, GOT) \ - GREATEST_ASSERT_EQm(#EXP " != " #GOT, EXP, GOT) -#define GREATEST_ASSERT_EQ_FMT(EXP, GOT, FMT) \ - GREATEST_ASSERT_EQ_FMTm(#EXP " != " #GOT, EXP, GOT, FMT) -#define GREATEST_ASSERT_IN_RANGE(EXP, GOT, TOL) \ - GREATEST_ASSERT_IN_RANGEm(#EXP " != " #GOT " +/- " #TOL, EXP, GOT, TOL) -#define GREATEST_ASSERT_EQUAL_T(EXP, GOT, TYPE_INFO, UDATA) \ - GREATEST_ASSERT_EQUAL_Tm(#EXP " != " #GOT, EXP, GOT, TYPE_INFO, UDATA) -#define GREATEST_ASSERT_STR_EQ(EXP, GOT) \ - GREATEST_ASSERT_STR_EQm(#EXP " != " #GOT, EXP, GOT) -#define GREATEST_ASSERT_STRN_EQ(EXP, GOT, SIZE) \ - GREATEST_ASSERT_STRN_EQm(#EXP " != " #GOT, EXP, GOT, SIZE) -#define GREATEST_ASSERT_MEM_EQ(EXP, GOT, SIZE) \ - GREATEST_ASSERT_MEM_EQm(#EXP " != " #GOT, EXP, GOT, SIZE) -#define GREATEST_ASSERT_ENUM_EQ(EXP, GOT, ENUM_STR) \ - GREATEST_ASSERT_ENUM_EQm(#EXP " != " #GOT, EXP, GOT, ENUM_STR) - -/* The following forms take an additional message argument first, - * to be displayed by the test runner. */ - -/* Fail if a condition is not true, with message. */ -#define GREATEST_ASSERTm(MSG, COND) \ - do { \ - greatest_info.assertions++; \ - if (!(COND)) { GREATEST_FAILm(MSG); } \ - } while (0) - -/* Fail if a condition is not true, longjmping out of test. */ -#define GREATEST_ASSERT_OR_LONGJMPm(MSG, COND) \ - do { \ - greatest_info.assertions++; \ - if (!(COND)) { GREATEST_FAIL_WITH_LONGJMPm(MSG); } \ - } while (0) - -/* Fail if a condition is not false, with message. 
*/ -#define GREATEST_ASSERT_FALSEm(MSG, COND) \ - do { \ - greatest_info.assertions++; \ - if ((COND)) { GREATEST_FAILm(MSG); } \ - } while (0) - -/* Fail if EXP != GOT (equality comparison by ==). */ -#define GREATEST_ASSERT_EQm(MSG, EXP, GOT) \ - do { \ - greatest_info.assertions++; \ - if ((EXP) != (GOT)) { GREATEST_FAILm(MSG); } \ - } while (0) - -/* Fail if EXP != GOT (equality comparison by ==). - * Warning: EXP and GOT will be evaluated more than once on failure. */ -#define GREATEST_ASSERT_EQ_FMTm(MSG, EXP, GOT, FMT) \ - do { \ - const char *greatest_FMT = ( FMT ); \ - greatest_info.assertions++; \ - if ((EXP) != (GOT)) { \ - fprintf(GREATEST_STDOUT, "\nExpected: "); \ - fprintf(GREATEST_STDOUT, greatest_FMT, EXP); \ - fprintf(GREATEST_STDOUT, "\n Got: "); \ - fprintf(GREATEST_STDOUT, greatest_FMT, GOT); \ - fprintf(GREATEST_STDOUT, "\n"); \ - GREATEST_FAILm(MSG); \ - } \ - } while (0) - -/* Fail if EXP is not equal to GOT, printing enum IDs. */ -#define GREATEST_ASSERT_ENUM_EQm(MSG, EXP, GOT, ENUM_STR) \ - do { \ - int greatest_EXP = (int)(EXP); \ - int greatest_GOT = (int)(GOT); \ - greatest_enum_str_fun *greatest_ENUM_STR = ENUM_STR; \ - if (greatest_EXP != greatest_GOT) { \ - fprintf(GREATEST_STDOUT, "\nExpected: %s", \ - greatest_ENUM_STR(greatest_EXP)); \ - fprintf(GREATEST_STDOUT, "\n Got: %s\n", \ - greatest_ENUM_STR(greatest_GOT)); \ - GREATEST_FAILm(MSG); \ - } \ - } while (0) \ - -/* Fail if GOT not in range of EXP +|- TOL. */ -#define GREATEST_ASSERT_IN_RANGEm(MSG, EXP, GOT, TOL) \ - do { \ - GREATEST_FLOAT greatest_EXP = (EXP); \ - GREATEST_FLOAT greatest_GOT = (GOT); \ - GREATEST_FLOAT greatest_TOL = (TOL); \ - greatest_info.assertions++; \ - if ((greatest_EXP > greatest_GOT && \ - greatest_EXP - greatest_GOT > greatest_TOL) || \ - (greatest_EXP < greatest_GOT && \ - greatest_GOT - greatest_EXP > greatest_TOL)) { \ - fprintf(GREATEST_STDOUT, \ - "\nExpected: " GREATEST_FLOAT_FMT \ - " +/- " GREATEST_FLOAT_FMT \ - "\n Got: " GREATEST_FLOAT_FMT \ - "\n", \ - greatest_EXP, greatest_TOL, greatest_GOT); \ - GREATEST_FAILm(MSG); \ - } \ - } while (0) - -/* Fail if EXP is not equal to GOT, according to strcmp. */ -#define GREATEST_ASSERT_STR_EQm(MSG, EXP, GOT) \ - do { \ - GREATEST_ASSERT_EQUAL_Tm(MSG, EXP, GOT, \ - &greatest_type_info_string, NULL); \ - } while (0) \ - -/* Fail if EXP is not equal to GOT, according to strcmp. */ -#define GREATEST_ASSERT_STRN_EQm(MSG, EXP, GOT, SIZE) \ - do { \ - size_t size = SIZE; \ - GREATEST_ASSERT_EQUAL_Tm(MSG, EXP, GOT, \ - &greatest_type_info_string, &size); \ - } while (0) \ - -/* Fail if EXP is not equal to GOT, according to memcmp. */ -#define GREATEST_ASSERT_MEM_EQm(MSG, EXP, GOT, SIZE) \ - do { \ - greatest_memory_cmp_env env; \ - env.exp = (const unsigned char *)EXP; \ - env.got = (const unsigned char *)GOT; \ - env.size = SIZE; \ - GREATEST_ASSERT_EQUAL_Tm(MSG, env.exp, env.got, \ - &greatest_type_info_memory, &env); \ - } while (0) \ - -/* Fail if EXP is not equal to GOT, according to a comparison - * callback in TYPE_INFO. If they are not equal, optionally use a - * print callback in TYPE_INFO to print them. */ -#define GREATEST_ASSERT_EQUAL_Tm(MSG, EXP, GOT, TYPE_INFO, UDATA) \ - do { \ - greatest_type_info *type_info = (TYPE_INFO); \ - greatest_info.assertions++; \ - if (!greatest_do_assert_equal_t(EXP, GOT, \ - type_info, UDATA)) { \ - if (type_info == NULL || type_info->equal == NULL) { \ - GREATEST_FAILm("type_info->equal callback missing!"); \ - } else { \ - GREATEST_FAILm(MSG); \ - } \ - } \ - } while (0) \ - -/* Pass. 
*/ -#define GREATEST_PASSm(MSG) \ - do { \ - greatest_info.msg = MSG; \ - return GREATEST_TEST_RES_PASS; \ - } while (0) - -/* Fail. */ -#define GREATEST_FAILm(MSG) \ - do { \ - greatest_info.fail_file = __FILE__; \ - greatest_info.fail_line = __LINE__; \ - greatest_info.msg = MSG; \ - return GREATEST_TEST_RES_FAIL; \ - } while (0) - -/* Optional GREATEST_FAILm variant that longjmps. */ -#if GREATEST_USE_LONGJMP -#define GREATEST_FAIL_WITH_LONGJMP() GREATEST_FAIL_WITH_LONGJMPm(NULL) -#define GREATEST_FAIL_WITH_LONGJMPm(MSG) \ - do { \ - greatest_info.fail_file = __FILE__; \ - greatest_info.fail_line = __LINE__; \ - greatest_info.msg = MSG; \ - longjmp(greatest_info.jump_dest, GREATEST_TEST_RES_FAIL); \ - } while (0) -#endif - -/* Skip the current test. */ -#define GREATEST_SKIPm(MSG) \ - do { \ - greatest_info.msg = MSG; \ - return GREATEST_TEST_RES_SKIP; \ - } while (0) - -/* Check the result of a subfunction using ASSERT, etc. */ -#define GREATEST_CHECK_CALL(RES) \ - do { \ - enum greatest_test_res greatest_RES = RES; \ - if (greatest_RES != GREATEST_TEST_RES_PASS) { \ - return greatest_RES; \ - } \ - } while (0) \ - -#if GREATEST_USE_TIME -#define GREATEST_SET_TIME(NAME) \ - NAME = clock(); \ - if (NAME == (clock_t) -1) { \ - fprintf(GREATEST_STDOUT, \ - "clock error: %s\n", #NAME); \ - exit(EXIT_FAILURE); \ - } - -#define GREATEST_CLOCK_DIFF(C1, C2) \ - fprintf(GREATEST_STDOUT, " (%lu ticks, %.3f sec)", \ - (long unsigned int) (C2) - (long unsigned int)(C1), \ - (double)((C2) - (C1)) / (1.0 * (double)CLOCKS_PER_SEC)) -#else -#define GREATEST_SET_TIME(UNUSED) -#define GREATEST_CLOCK_DIFF(UNUSED1, UNUSED2) -#endif - -#if GREATEST_USE_LONGJMP -#define GREATEST_SAVE_CONTEXT() \ - /* setjmp returns 0 (GREATEST_TEST_RES_PASS) on first call */ \ - /* so the test runs, then RES_FAIL from FAIL_WITH_LONGJMP. */ \ - ((enum greatest_test_res)(setjmp(greatest_info.jump_dest))) -#else -#define GREATEST_SAVE_CONTEXT() \ - /*a no-op, since setjmp/longjmp aren't being used */ \ - GREATEST_TEST_RES_PASS -#endif - -/* Include several function definitions in the main test file. */ -#define GREATEST_MAIN_DEFS() \ - \ -/* Is FILTER a subset of NAME? 
*/ \ -static int greatest_name_match(const char *name, \ - const char *filter) { \ - size_t offset = 0; \ - size_t filter_len = strlen(filter); \ - while (name[offset] != '\0') { \ - if (name[offset] == filter[0]) { \ - if (0 == strncmp(&name[offset], filter, filter_len)) { \ - return 1; \ - } \ - } \ - offset++; \ - } \ - \ - return 0; \ -} \ - \ -int greatest_pre_test(const char *name) { \ - if (!GREATEST_LIST_ONLY() \ - && (!GREATEST_FIRST_FAIL() || greatest_info.suite.failed == 0) \ - && (greatest_info.test_filter == NULL || \ - greatest_name_match(name, greatest_info.test_filter))) { \ - GREATEST_SET_TIME(greatest_info.suite.pre_test); \ - if (greatest_info.setup) { \ - greatest_info.setup(greatest_info.setup_udata); \ - } \ - return 1; /* test should be run */ \ - } else { \ - return 0; /* skipped */ \ - } \ -} \ - \ -void greatest_post_test(const char *name, int res) { \ - GREATEST_SET_TIME(greatest_info.suite.post_test); \ - if (greatest_info.teardown) { \ - void *udata = greatest_info.teardown_udata; \ - greatest_info.teardown(udata); \ - } \ - \ - if (res <= GREATEST_TEST_RES_FAIL) { \ - greatest_do_fail(name); \ - } else if (res >= GREATEST_TEST_RES_SKIP) { \ - greatest_do_skip(name); \ - } else if (res == GREATEST_TEST_RES_PASS) { \ - greatest_do_pass(name); \ - } \ - greatest_info.suite.tests_run++; \ - greatest_info.col++; \ - if (GREATEST_IS_VERBOSE()) { \ - GREATEST_CLOCK_DIFF(greatest_info.suite.pre_test, \ - greatest_info.suite.post_test); \ - fprintf(GREATEST_STDOUT, "\n"); \ - } else if (greatest_info.col % greatest_info.width == 0) { \ - fprintf(GREATEST_STDOUT, "\n"); \ - greatest_info.col = 0; \ - } \ - if (GREATEST_STDOUT == stdout) fflush(stdout); \ -} \ - \ -static void report_suite(void) { \ - if (greatest_info.suite.tests_run > 0) { \ - fprintf(GREATEST_STDOUT, \ - "\n%u test%s - %u passed, %u failed, %u skipped", \ - greatest_info.suite.tests_run, \ - greatest_info.suite.tests_run == 1 ? "" : "s", \ - greatest_info.suite.passed, \ - greatest_info.suite.failed, \ - greatest_info.suite.skipped); \ - GREATEST_CLOCK_DIFF(greatest_info.suite.pre_suite, \ - greatest_info.suite.post_suite); \ - fprintf(GREATEST_STDOUT, "\n"); \ - } \ -} \ - \ -static void update_counts_and_reset_suite(void) { \ - greatest_info.setup = NULL; \ - greatest_info.setup_udata = NULL; \ - greatest_info.teardown = NULL; \ - greatest_info.teardown_udata = NULL; \ - greatest_info.passed += greatest_info.suite.passed; \ - greatest_info.failed += greatest_info.suite.failed; \ - greatest_info.skipped += greatest_info.suite.skipped; \ - greatest_info.tests_run += greatest_info.suite.tests_run; \ - memset(&greatest_info.suite, 0, sizeof(greatest_info.suite)); \ - greatest_info.col = 0; \ -} \ - \ -static void greatest_run_suite(greatest_suite_cb *suite_cb, \ - const char *suite_name) { \ - if (greatest_info.suite_filter && \ - !greatest_name_match(suite_name, greatest_info.suite_filter)) { \ - return; \ - } \ - update_counts_and_reset_suite(); \ - if (GREATEST_FIRST_FAIL() && greatest_info.failed > 0) { return; } \ - fprintf(GREATEST_STDOUT, "\n* Suite %s:\n", suite_name); \ - GREATEST_SET_TIME(greatest_info.suite.pre_suite); \ - suite_cb(); \ - GREATEST_SET_TIME(greatest_info.suite.post_suite); \ - report_suite(); \ -} \ - \ -void greatest_do_pass(const char *name) { \ - if (GREATEST_IS_VERBOSE()) { \ - fprintf(GREATEST_STDOUT, "PASS %s: %s", \ - name, greatest_info.msg ? 
greatest_info.msg : ""); \ - } else { \ - fprintf(GREATEST_STDOUT, "."); \ - } \ - greatest_info.suite.passed++; \ -} \ - \ -void greatest_do_fail(const char *name) { \ - if (GREATEST_IS_VERBOSE()) { \ - fprintf(GREATEST_STDOUT, \ - "FAIL %s: %s (%s:%u)", \ - name, greatest_info.msg ? greatest_info.msg : "", \ - greatest_info.fail_file, greatest_info.fail_line); \ - } else { \ - fprintf(GREATEST_STDOUT, "F"); \ - greatest_info.col++; \ - /* add linebreak if in line of '.'s */ \ - if (greatest_info.col != 0) { \ - fprintf(GREATEST_STDOUT, "\n"); \ - greatest_info.col = 0; \ - } \ - fprintf(GREATEST_STDOUT, "FAIL %s: %s (%s:%u)\n", \ - name, \ - greatest_info.msg ? greatest_info.msg : "", \ - greatest_info.fail_file, greatest_info.fail_line); \ - } \ - greatest_info.suite.failed++; \ -} \ - \ -void greatest_do_skip(const char *name) { \ - if (GREATEST_IS_VERBOSE()) { \ - fprintf(GREATEST_STDOUT, "SKIP %s: %s", \ - name, \ - greatest_info.msg ? \ - greatest_info.msg : "" ); \ - } else { \ - fprintf(GREATEST_STDOUT, "s"); \ - } \ - greatest_info.suite.skipped++; \ -} \ - \ -int greatest_do_assert_equal_t(const void *exp, const void *got, \ - greatest_type_info *type_info, void *udata) { \ - int eq = 0; \ - if (type_info == NULL || type_info->equal == NULL) { \ - return 0; \ - } \ - eq = type_info->equal(exp, got, udata); \ - if (!eq) { \ - if (type_info->print != NULL) { \ - fprintf(GREATEST_STDOUT, "\nExpected: "); \ - (void)type_info->print(exp, udata); \ - fprintf(GREATEST_STDOUT, "\n Got: "); \ - (void)type_info->print(got, udata); \ - fprintf(GREATEST_STDOUT, "\n"); \ - } else { \ - fprintf(GREATEST_STDOUT, \ - "GREATEST_ASSERT_EQUAL_T failure at %s:%u\n", \ - greatest_info.fail_file, \ - greatest_info.fail_line); \ - } \ - } \ - return eq; \ -} \ - \ -void greatest_usage(const char *name) { \ - fprintf(GREATEST_STDOUT, \ - "Usage: %s [-hlfv] [-s SUITE] [-t TEST]\n" \ - " -h, --help print this Help\n" \ - " -l List suites and their tests, then exit\n" \ - " -f Stop runner after first failure\n" \ - " -v Verbose output\n" \ - " -s SUITE only run suites containing string SUITE\n" \ - " -t TEST only run tests containing string TEST\n", \ - name); \ -} \ - \ -static void greatest_parse_args(int argc, char **argv) { \ - int i = 0; \ - for (i = 1; i < argc; i++) { \ - if (0 == strncmp("-t", argv[i], 2)) { \ - if (argc <= i + 1) { \ - greatest_usage(argv[0]); \ - exit(EXIT_FAILURE); \ - } \ - greatest_info.test_filter = argv[i+1]; \ - i++; \ - } else if (0 == strncmp("-s", argv[i], 2)) { \ - if (argc <= i + 1) { \ - greatest_usage(argv[0]); \ - exit(EXIT_FAILURE); \ - } \ - greatest_info.suite_filter = argv[i+1]; \ - i++; \ - } else if (0 == strncmp("-f", argv[i], 2)) { \ - greatest_info.flags |= GREATEST_FLAG_FIRST_FAIL; \ - } else if (0 == strncmp("-v", argv[i], 2)) { \ - greatest_info.verbosity++; \ - } else if (0 == strncmp("-l", argv[i], 2)) { \ - greatest_info.flags |= GREATEST_FLAG_LIST_ONLY; \ - } else if (0 == strncmp("-h", argv[i], 2) || \ - 0 == strncmp("--help", argv[i], 6)) { \ - greatest_usage(argv[0]); \ - exit(EXIT_SUCCESS); \ - } else if (0 == strncmp("--", argv[i], 2)) { \ - break; \ - } else { \ - fprintf(GREATEST_STDOUT, \ - "Unknown argument '%s'\n", argv[i]); \ - greatest_usage(argv[0]); \ - exit(EXIT_FAILURE); \ - } \ - } \ -} \ - \ -int greatest_all_passed(void) { return (greatest_info.failed == 0); } \ - \ -void greatest_set_test_filter(const char *name) { \ - greatest_info.test_filter = name; \ -} \ - \ -void greatest_set_suite_filter(const char *name) { \ - 
greatest_info.suite_filter = name; \ -} \ - \ -void greatest_get_report(struct greatest_report_t *report) { \ - if (report) { \ - report->passed = greatest_info.passed; \ - report->failed = greatest_info.failed; \ - report->skipped = greatest_info.skipped; \ - report->assertions = greatest_info.assertions; \ - } \ -} \ - \ -unsigned int greatest_get_verbosity(void) { \ - return greatest_info.verbosity; \ -} \ - \ -void greatest_set_verbosity(unsigned int verbosity) { \ - greatest_info.verbosity = (unsigned char)verbosity; \ -} \ - \ -void greatest_set_flag(greatest_flag_t flag) { \ - greatest_info.flags |= flag; \ -} \ - \ -void GREATEST_SET_SETUP_CB(greatest_setup_cb *cb, void *udata) { \ - greatest_info.setup = cb; \ - greatest_info.setup_udata = udata; \ -} \ - \ -void GREATEST_SET_TEARDOWN_CB(greatest_teardown_cb *cb, \ - void *udata) { \ - greatest_info.teardown = cb; \ - greatest_info.teardown_udata = udata; \ -} \ - \ -static int greatest_string_equal_cb(const void *exp, const void *got, \ - void *udata) { \ - size_t *size = (size_t *)udata; \ - return (size != NULL \ - ? (0 == strncmp((const char *)exp, (const char *)got, *size)) \ - : (0 == strcmp((const char *)exp, (const char *)got))); \ -} \ - \ -static int greatest_string_printf_cb(const void *t, void *udata) { \ - (void)udata; /* note: does not check \0 termination. */ \ - return fprintf(GREATEST_STDOUT, "%s", (const char *)t); \ -} \ - \ -greatest_type_info greatest_type_info_string = { \ - greatest_string_equal_cb, \ - greatest_string_printf_cb, \ -}; \ - \ -static int greatest_memory_equal_cb(const void *exp, const void *got, \ - void *udata) { \ - greatest_memory_cmp_env *env = (greatest_memory_cmp_env *)udata; \ - return (0 == memcmp(exp, got, env->size)); \ -} \ - \ -static int greatest_memory_printf_cb(const void *t, void *udata) { \ - greatest_memory_cmp_env *env = (greatest_memory_cmp_env *)udata; \ - unsigned char *buf = (unsigned char *)t, diff_mark = ' '; \ - FILE *out = GREATEST_STDOUT; \ - size_t i, line_i, line_len = 0; \ - int len = 0; /* format hexdump with differences highlighted */ \ - for (i = 0; i < env->size; i+= line_len) { \ - diff_mark = ' '; \ - line_len = env->size - i; \ - if (line_len > 16) { line_len = 16; } \ - for (line_i = i; line_i < i + line_len; line_i++) { \ - if (env->exp[line_i] != env->got[line_i]) diff_mark = 'X'; \ - } \ - len += fprintf(out, "\n%04x %c ", (unsigned int)i, diff_mark); \ - for (line_i = i; line_i < i + line_len; line_i++) { \ - int m = env->exp[line_i] == env->got[line_i]; /* match? */ \ - len += fprintf(out, "%02x%c", buf[line_i], m ? ' ' : '<'); \ - } \ - for (line_i = 0; line_i < 16 - line_len; line_i++) { \ - len += fprintf(out, " "); \ - } \ - fprintf(out, " "); \ - for (line_i = i; line_i < i + line_len; line_i++) { \ - unsigned char c = buf[line_i]; \ - len += fprintf(out, "%c", isprint(c) ? c : '.'); \ - } \ - } \ - len += fprintf(out, "\n"); \ - return len; \ -} \ - \ -greatest_type_info greatest_type_info_memory = { \ - greatest_memory_equal_cb, \ - greatest_memory_printf_cb, \ -}; \ - \ -greatest_run_info greatest_info - -/* Init internals. */ -#define GREATEST_INIT() \ - do { \ - /* Suppress unused function warning if features aren't used */ \ - (void)greatest_run_suite; \ - (void)greatest_parse_args; \ - \ - memset(&greatest_info, 0, sizeof(greatest_info)); \ - greatest_info.width = GREATEST_DEFAULT_WIDTH; \ - GREATEST_SET_TIME(greatest_info.begin); \ - } while (0) \ - -/* Handle command-line arguments, etc. 
*/ -#define GREATEST_MAIN_BEGIN() \ - do { \ - GREATEST_INIT(); \ - greatest_parse_args(argc, argv); \ - } while (0) - -/* Report passes, failures, skipped tests, the number of - * assertions, and the overall run time. */ -#define GREATEST_PRINT_REPORT() \ - do { \ - if (!GREATEST_LIST_ONLY()) { \ - update_counts_and_reset_suite(); \ - GREATEST_SET_TIME(greatest_info.end); \ - fprintf(GREATEST_STDOUT, \ - "\nTotal: %u test%s", \ - greatest_info.tests_run, \ - greatest_info.tests_run == 1 ? "" : "s"); \ - GREATEST_CLOCK_DIFF(greatest_info.begin, \ - greatest_info.end); \ - fprintf(GREATEST_STDOUT, ", %u assertion%s\n", \ - greatest_info.assertions, \ - greatest_info.assertions == 1 ? "" : "s"); \ - fprintf(GREATEST_STDOUT, \ - "Pass: %u, fail: %u, skip: %u.\n", \ - greatest_info.passed, \ - greatest_info.failed, greatest_info.skipped); \ - } \ - } while (0) - -/* Report results, exit with exit status based on results. */ -#define GREATEST_MAIN_END() \ - do { \ - GREATEST_PRINT_REPORT(); \ - return (greatest_all_passed() ? EXIT_SUCCESS : EXIT_FAILURE); \ - } while (0) - -/* Make abbreviations without the GREATEST_ prefix for the - * most commonly used symbols. */ -#if GREATEST_USE_ABBREVS -#define TEST GREATEST_TEST -#define SUITE GREATEST_SUITE -#define SUITE_EXTERN GREATEST_SUITE_EXTERN -#define RUN_TEST GREATEST_RUN_TEST -#define RUN_TEST1 GREATEST_RUN_TEST1 -#define RUN_SUITE GREATEST_RUN_SUITE -#define IGNORE_TEST GREATEST_IGNORE_TEST -#define ASSERT GREATEST_ASSERT -#define ASSERTm GREATEST_ASSERTm -#define ASSERT_FALSE GREATEST_ASSERT_FALSE -#define ASSERT_EQ GREATEST_ASSERT_EQ -#define ASSERT_EQ_FMT GREATEST_ASSERT_EQ_FMT -#define ASSERT_IN_RANGE GREATEST_ASSERT_IN_RANGE -#define ASSERT_EQUAL_T GREATEST_ASSERT_EQUAL_T -#define ASSERT_STR_EQ GREATEST_ASSERT_STR_EQ -#define ASSERT_STRN_EQ GREATEST_ASSERT_STRN_EQ -#define ASSERT_MEM_EQ GREATEST_ASSERT_MEM_EQ -#define ASSERT_ENUM_EQ GREATEST_ASSERT_ENUM_EQ -#define ASSERT_FALSEm GREATEST_ASSERT_FALSEm -#define ASSERT_EQm GREATEST_ASSERT_EQm -#define ASSERT_EQ_FMTm GREATEST_ASSERT_EQ_FMTm -#define ASSERT_IN_RANGEm GREATEST_ASSERT_IN_RANGEm -#define ASSERT_EQUAL_Tm GREATEST_ASSERT_EQUAL_Tm -#define ASSERT_STR_EQm GREATEST_ASSERT_STR_EQm -#define ASSERT_STRN_EQm GREATEST_ASSERT_STRN_EQm -#define ASSERT_MEM_EQm GREATEST_ASSERT_MEM_EQm -#define ASSERT_ENUM_EQm GREATEST_ASSERT_ENUM_EQm -#define PASS GREATEST_PASS -#define FAIL GREATEST_FAIL -#define SKIP GREATEST_SKIP -#define PASSm GREATEST_PASSm -#define FAILm GREATEST_FAILm -#define SKIPm GREATEST_SKIPm -#define SET_SETUP GREATEST_SET_SETUP_CB -#define SET_TEARDOWN GREATEST_SET_TEARDOWN_CB -#define CHECK_CALL GREATEST_CHECK_CALL - -#ifdef GREATEST_VA_ARGS -#define RUN_TESTp GREATEST_RUN_TESTp -#endif - -#if GREATEST_USE_LONGJMP -#define ASSERT_OR_LONGJMP GREATEST_ASSERT_OR_LONGJMP -#define ASSERT_OR_LONGJMPm GREATEST_ASSERT_OR_LONGJMPm -#define FAIL_WITH_LONGJMP GREATEST_FAIL_WITH_LONGJMP -#define FAIL_WITH_LONGJMPm GREATEST_FAIL_WITH_LONGJMPm -#endif - -#endif /* USE_ABBREVS */ - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/src/common/thirdparty/patches/.gitattributes b/src/common/thirdparty/patches/.gitattributes deleted file mode 100644 index 9812ceb1ffd9..000000000000 --- a/src/common/thirdparty/patches/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -*.patch text eol=lf diff --git a/src/common/thirdparty/patches/windows/python-pyconfig.patch b/src/common/thirdparty/patches/windows/python-pyconfig.patch deleted file mode 100644 index 4280dee77470..000000000000 --- 
a/src/common/thirdparty/patches/windows/python-pyconfig.patch +++ /dev/null @@ -1,25 +0,0 @@ -diff --git a/inc/Windows/pyconfig.h b/inc/Windows/pyconfig.h -index 1cfc59b..d4861cb ---- a/inc/Windows/pyconfig.h -+++ b/inc/Windows/pyconfig.h -@@ -1,6 +1,11 @@ - #ifndef Py_CONFIG_H - #define Py_CONFIG_H - -+#ifdef _MSC_VER -+#pragma push_macro("_DEBUG") -+#undef _DEBUG -+#endif -+ - /* pyconfig.h. NOT Generated automatically by configure. - - This is a manually maintained version used for the Watcom, -@@ -756,4 +761,8 @@ Py_NO_ENABLE_SHARED to find out. Also support MS_NO_COREDLL for b/w compat */ - least significant byte first */ - #define DOUBLE_IS_LITTLE_ENDIAN_IEEE754 1 - -+#ifdef _MSC_VER -+#pragma pop_macro("_DEBUG") -+#endif -+ - #endif /* !Py_CONFIG_H */ diff --git a/src/common/thirdparty/patches/windows/redis.patch b/src/common/thirdparty/patches/windows/redis.patch deleted file mode 100644 index 5ed2df5105cf..000000000000 --- a/src/common/thirdparty/patches/windows/redis.patch +++ /dev/null @@ -1,772 +0,0 @@ -diff --git a/msvs/RedisServer.vcxproj b/msvs/RedisServer.vcxproj -index 115ce90..68afb44 ---- a/msvs/RedisServer.vcxproj -+++ b/msvs/RedisServer.vcxproj -@@ -24,26 +24,26 @@ - - - -- Application -+ StaticLibrary - true -- v120 -+ v140_xp - false - - -- Application -+ StaticLibrary - true -- v120 -+ v140_xp - false - - -- Application -+ StaticLibrary - false -- v120 -+ v140_xp - - -- Application -+ StaticLibrary - false -- v120 -+ v140_xp - - - -@@ -61,41 +61,23 @@ - - - -- -+ - false - redis-server - false -- -- -- false -- redis-server -- false -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -- -- -- false -- redis-server -- false -- Build -- -- -- false -- redis-server -- false -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\$(MSBuildProjectName)\ - - - -- USE_JEMALLOC;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;_DEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions) -- $(SolutionDir)..\deps\lua\src;$(SolutionDir)..\deps\hiredis;$(SolutionDir)..\deps\jemalloc-win\include -- MultiThreadedDebug -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;_DEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions) -+ $(ProjectDir)..\deps\lua\src;$(ProjectDir)..\deps\hiredis;$(ProjectDir)..\deps\jemalloc-win\include - Level3 - ProgramDatabase - Disabled - 4996;4146 -- true -+ false -+ true - - - true -@@ -109,14 +91,14 @@ - - - -- USE_JEMALLOC;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;_DEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501 -- $(SolutionDir)..\deps\lua\src;$(SolutionDir)..\deps\hiredis;$(SolutionDir)..\deps\jemalloc-win\include -- MultiThreadedDebug -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;_DEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501 -+ $(ProjectDir)..\deps\lua\src;$(ProjectDir)..\deps\hiredis;$(ProjectDir)..\deps\jemalloc-win\include - Level3 - ProgramDatabase - Disabled - 4996;4146 -- true -+ false -+ true - - - true -@@ -130,14 +112,13 @@ - - - -- USE_JEMALLOC;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;NDEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions) -- $(SolutionDir)..\deps\lua\src;$(SolutionDir)..\deps\hiredis;$(SolutionDir)..\deps\jemalloc-win\include -- MultiThreaded -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;LACKS_STDLIB_H;NDEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions) -+ 
$(ProjectDir)..\deps\lua\src;$(ProjectDir)..\deps\hiredis;$(ProjectDir)..\deps\jemalloc-win\include - Level3 - ProgramDatabase - 4996;4146 -- true - Full -+ true - - - true -@@ -162,13 +143,12 @@ - - - -- USE_JEMALLOC;_OFF_T_DEFINED;_WIN32;LACKS_STDLIB_H;NDEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501 -- $(SolutionDir)..\deps\lua\src;$(SolutionDir)..\deps\hiredis;$(SolutionDir)..\deps\jemalloc-win\include -- MultiThreaded -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;_WIN32;LACKS_STDLIB_H;NDEBUG;_CONSOLE;__x86_64__;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501 -+ $(ProjectDir)..\deps\lua\src;$(ProjectDir)..\deps\hiredis;$(ProjectDir)..\deps\jemalloc-win\include - Level3 - ProgramDatabase - 4996;4146 -- true -+ true - - - true -@@ -271,9 +251,6 @@ - - - -- -- {8b897e33-6428-4254-8335-4911d179bad1} -- - - {8c07f811-c81c-432c-b334-1ae6faecf951} - -diff --git a/msvs/hiredis/hiredis.vcxproj b/msvs/hiredis/hiredis.vcxproj -index 0622958..efaedae ---- a/msvs/hiredis/hiredis.vcxproj -+++ b/msvs/hiredis/hiredis.vcxproj -@@ -28,27 +28,25 @@ - StaticLibrary - true - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - true - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - false -- true - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - false -- true - MultiByte -- v120 -+ v140_xp - - - -@@ -66,30 +64,20 @@ - - - -- -+ - hiredis -- -- -- hiredis -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -- -- -- hiredis -- -- -- hiredis -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -+ $(ProjectDir)..\$(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\$(MSBuildProjectName)\ - - - - NotUsing - Level3 - Disabled -- _OFF_T_DEFINED;WIN32;_LIB;_DEBUG;%(PreprocessorDefinitions) -- MultiThreadedDebug -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;_LIB;_DEBUG;%(PreprocessorDefinitions) - 4996 -+ false -+ true - - - Windows -@@ -101,9 +89,10 @@ - NotUsing - Level3 - Disabled -- _OFF_T_DEFINED;WIN32;_LIB;_DEBUG;%(PreprocessorDefinitions) -- MultiThreadedDebug -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;_LIB;_DEBUG;%(PreprocessorDefinitions) - 4996 -+ false -+ true - - - Windows -@@ -117,10 +106,9 @@ - Full - true - true -- _OFF_T_DEFINED;WIN32;_LIB;%(PreprocessorDefinitions) -- MultiThreaded -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;_LIB;%(PreprocessorDefinitions) - 4996 -- true -+ true - - - Windows -@@ -136,10 +124,9 @@ - Full - true - true -- _OFF_T_DEFINED;WIN32;_LIB;%(PreprocessorDefinitions) -- MultiThreaded -+ _WIN32_WINNT=0x0502;_OFF_T_DEFINED;WIN32;_LIB;%(PreprocessorDefinitions) - 4996 -- true -+ true - - - Windows -diff --git a/msvs/lua/lua/lua.vcxproj b/msvs/lua/lua/lua.vcxproj -index b187130..adef07b ---- a/msvs/lua/lua/lua.vcxproj -+++ b/msvs/lua/lua/lua.vcxproj -@@ -30,28 +30,28 @@ - true - false - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - true - false - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - false - false - MultiByte -- v120 -+ v140_xp - - - StaticLibrary - false - false - MultiByte -- v120 -+ v140_xp - - - -@@ -69,25 +69,16 @@ - - - -- -+ - true -- .lib -- -- -- true -- .lib -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ - -- -+ - false -- .lib - -- -- false -+ - .lib -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\$(MSBuildProjectName)\ - - - -@@ -95,8 +86,9 @@ - Disabled - 
_OFF_T_DEFINED;WIN32;_DEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions);LUA_ANSI;ENABLE_CJSON_GLOBAL - NotUsing -- MultiThreadedDebug - 4244;4018 -+ false -+ true - - - true -@@ -110,8 +102,9 @@ - Disabled - _OFF_T_DEFINED;WIN32;_DEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501;LUA_ANSI;ENABLE_CJSON_GLOBAL - NotUsing -- MultiThreadedDebug - 4244;4018 -+ false -+ true - - - true -@@ -124,10 +117,10 @@ - Level3 - _OFF_T_DEFINED;WIN32;NDEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions);LUA_ANSI;ENABLE_CJSON_GLOBAL - NotUsing -- MultiThreaded - 4244;4018 - Full - true -+ true - - - true -@@ -140,8 +133,8 @@ - Level3 - _OFF_T_DEFINED;WIN32;NDEBUG;_LIB;_CRT_SECURE_NO_WARNINGS;%(PreprocessorDefinitions);_WIN32_WINNT=0x0501;LUA_ANSI;ENABLE_CJSON_GLOBAL - NotUsing -- MultiThreaded - 4244;4018 -+ true - - - true -diff --git a/src/Win32_Interop/Win32_ANSI.c b/src/Win32_Interop/Win32_ANSI.c -index 404b84f..e7c55d2 ---- a/src/Win32_Interop/Win32_ANSI.c -+++ b/src/Win32_Interop/Win32_ANSI.c -@@ -737,7 +737,7 @@ void ANSI_printf(char *format, ...) { - memset(buffer, 0, cBufLen); - - va_start(args, format); -- retVal = vsprintf_s(buffer, cBufLen, format, args); -+ retVal = vsnprintf(buffer, cBufLen - 1, format, args); - va_end(args); - - if (retVal > 0) { -diff --git a/src/Win32_Interop/Win32_EventLog.cpp b/src/Win32_Interop/Win32_EventLog.cpp -index 1856540..3db4ddd ---- a/src/Win32_Interop/Win32_EventLog.cpp -+++ b/src/Win32_Interop/Win32_EventLog.cpp -@@ -30,7 +30,6 @@ using namespace std; - - #include "Win32_EventLog.h" - #include "Win32_SmartHandle.h" --#include "EventLog.h" - - static bool eventLogEnabled = true; - static string eventLogIdentity = "redis"; -@@ -129,17 +128,17 @@ void RedisEventLog::LogMessage(LPCSTR msg, const WORD type) { - DWORD eventID; - switch (type) { - case EVENTLOG_ERROR_TYPE: -- eventID = MSG_ERROR_1; -+ eventID = 0x2; - break; - case EVENTLOG_WARNING_TYPE: -- eventID = MSG_WARNING_1; -+ eventID = 0x1; - break; - case EVENTLOG_INFORMATION_TYPE: -- eventID = MSG_INFO_1; -+ eventID = 0x0; - break; - default: - std::cerr << "Unrecognized type: " << type << "\n"; -- eventID = MSG_INFO_1; -+ eventID = 0x0; - break; - } - -diff --git a/src/Win32_Interop/Win32_FDAPI.cpp b/src/Win32_Interop/Win32_FDAPI.cpp -index 3df9af1..f60e3d4 ---- a/src/Win32_Interop/Win32_FDAPI.cpp -+++ b/src/Win32_Interop/Win32_FDAPI.cpp -@@ -46,11 +46,13 @@ fdapi_access access = NULL; - fdapi_bind bind = NULL; - fdapi_connect connect = NULL; - fdapi_fcntl fcntl = NULL; -+fdapi_ioctl ioctl = NULL; - fdapi_fstat fdapi_fstat64 = NULL; - fdapi_fsync fsync = NULL; - fdapi_ftruncate ftruncate = NULL; - fdapi_freeaddrinfo freeaddrinfo = NULL; - fdapi_getaddrinfo getaddrinfo = NULL; -+fdapi_gethostbyname gethostbyname = NULL; - fdapi_getpeername getpeername = NULL; - fdapi_getsockname getsockname = NULL; - fdapi_getsockopt getsockopt = NULL; -@@ -67,7 +69,9 @@ fdapi_open open = NULL; - fdapi_pipe pipe = NULL; - fdapi_poll poll = NULL; - fdapi_read read = NULL; -+fdapi_recv recv = NULL; - fdapi_select select = NULL; -+fdapi_send send = NULL; - fdapi_setsockopt setsockopt = NULL; - fdapi_socket socket = NULL; - fdapi_write write = NULL; -@@ -622,6 +626,23 @@ int FDAPI_fcntl(int rfd, int cmd, int flags = 0 ) { - return -1; - } - -+int FDAPI_ioctl(int rfd, int cmd, char *buf) { -+ try { -+ SocketInfo* socket_info = RFDMap::getInstance().lookupSocketInfo(rfd); -+ if (socket_info != NULL && socket_info->socket != INVALID_SOCKET) { -+ if 
(f_ioctlsocket(socket_info->socket, cmd, (u_long *)buf) != SOCKET_ERROR) { -+ return 0; -+ } else { -+ errno = f_WSAGetLastError(); -+ return -1; -+ } -+ } -+ } CATCH_AND_REPORT(); -+ -+ errno = EBADF; -+ return -1; -+} -+ - int FDAPI_poll(struct pollfd *fds, nfds_t nfds, int timeout) { - try { - struct pollfd* pollCopy = new struct pollfd[nfds]; -@@ -777,6 +798,42 @@ ssize_t FDAPI_read(int rfd, void *buf, size_t count) { - return -1; - } - -+ssize_t FDAPI_recv(int rfd, void *buf, size_t count, int flags) { -+ try { -+ SOCKET socket = RFDMap::getInstance().lookupSocket(rfd); -+ if (socket != INVALID_SOCKET) { -+ int retval = f_recv(socket, (char*) buf, (unsigned int) count, flags); -+ if (retval == -1) { -+ errno = GetLastError(); -+ if (errno == WSAEWOULDBLOCK) { -+ errno = EAGAIN; -+ } -+ } -+ return retval; -+ } -+ } CATCH_AND_REPORT(); -+ errno = EBADF; -+ return -1; -+} -+ -+ssize_t FDAPI_send(int rfd, const void *buf, size_t count, int flags) { -+ try { -+ SOCKET socket = RFDMap::getInstance().lookupSocket(rfd); -+ if (socket != INVALID_SOCKET) { -+ int retval = f_send(socket, (const char*) buf, (unsigned int) count, flags); -+ if (retval == -1) { -+ errno = GetLastError(); -+ if (errno == WSAEWOULDBLOCK) { -+ errno = EAGAIN; -+ } -+ } -+ return retval; -+ } -+ } CATCH_AND_REPORT(); -+ errno = EBADF; -+ return -1; -+} -+ - ssize_t FDAPI_write(int rfd, const void *buf, size_t count) { - try { - SOCKET socket = RFDMap::getInstance().lookupSocket(rfd); -@@ -1195,12 +1252,14 @@ private: - bind = FDAPI_bind; - connect = FDAPI_connect; - fcntl = FDAPI_fcntl; -+ ioctl = FDAPI_ioctl; - fdapi_fstat64 = (fdapi_fstat) FDAPI_fstat64; - freeaddrinfo = FDAPI_freeaddrinfo; - fsync = FDAPI_fsync; - ftruncate = FDAPI_ftruncate; - getaddrinfo = FDAPI_getaddrinfo; - getsockopt = FDAPI_getsockopt; -+ gethostbyname = FDAPI_gethostbyname; - getpeername = FDAPI_getpeername; - getsockname = FDAPI_getsockname; - htonl = FDAPI_htonl; -@@ -1216,9 +1275,11 @@ private: - pipe = FDAPI_pipe; - poll = FDAPI_poll; - read = FDAPI_read; -+ recv = FDAPI_recv; - select = FDAPI_select; - setsockopt = FDAPI_setsockopt; - socket = FDAPI_socket; -+ send = FDAPI_send; - write = FDAPI_write; - } - -diff --git a/src/Win32_Interop/Win32_FDAPI.h b/src/Win32_Interop/Win32_FDAPI.h -index 8fae9c7..6e09596 ---- a/src/Win32_Interop/Win32_FDAPI.h -+++ b/src/Win32_Interop/Win32_FDAPI.h -@@ -116,9 +116,12 @@ typedef int (*fdapi_open)(const char * _Filename, int _OpenFlag, int flags); - typedef int (*fdapi_accept)(int sockfd, struct sockaddr *addr, socklen_t *addrlen); - typedef int (*fdapi_setsockopt)(int sockfd, int level, int optname,const void *optval, socklen_t optlen); - typedef int (*fdapi_fcntl)(int fd, int cmd, int flags); -+typedef int (*fdapi_ioctl)(int fd, int cmd, char *buf); - typedef int (*fdapi_poll)(struct pollfd *fds, nfds_t nfds, int timeout); - typedef int (*fdapi_getsockopt)(int sockfd, int level, int optname, void *optval, socklen_t *optlen); - typedef int (*fdapi_connect)(int sockfd, const struct sockaddr *addr, size_t addrlen); -+typedef ssize_t (*fdapi_recv)(int fd, void *buf, size_t count, int flags); -+typedef ssize_t (*fdapi_send)(int rfd, void const *buf, size_t count, int flags); - typedef ssize_t (*fdapi_read)(int fd, void *buf, size_t count); - typedef ssize_t (*fdapi_write)(int fd, const void *buf, size_t count); - typedef int (*fdapi_fsync)(int fd); -@@ -128,6 +131,7 @@ typedef int (*fdapi_bind)(int sockfd, const struct sockaddr *addr, socklen_t add - typedef u_short (*fdapi_htons)(u_short hostshort); - 
typedef u_long (*fdapi_htonl)(u_long hostlong); - typedef u_short (*fdapi_ntohs)(u_short netshort); -+typedef struct hostent* (*fdapi_gethostbyname)(const char *name); - typedef int (*fdapi_getpeername)(int sockfd, struct sockaddr *addr, socklen_t * addrlen); - typedef int (*fdapi_getsockname)(int sockfd, struct sockaddr* addrsock, int* addrlen ); - typedef void (*fdapi_freeaddrinfo)(struct addrinfo *ai); -@@ -159,12 +163,14 @@ extern fdapi_access access; - extern fdapi_bind bind; - extern fdapi_connect connect; - extern fdapi_fcntl fcntl; -+extern fdapi_ioctl ioctl; - extern fdapi_fstat fdapi_fstat64; - extern fdapi_freeaddrinfo freeaddrinfo; - extern fdapi_fsync fsync; - extern fdapi_ftruncate ftruncate; - extern fdapi_getaddrinfo getaddrinfo; - extern fdapi_getsockopt getsockopt; -+extern fdapi_gethostbyname gethostbyname; - extern fdapi_getpeername getpeername; - extern fdapi_getsockname getsockname; - extern fdapi_htonl htonl; -@@ -180,7 +186,9 @@ extern fdapi_open open; - extern fdapi_pipe pipe; - extern fdapi_poll poll; - extern fdapi_read read; -+extern fdapi_recv recv; - extern fdapi_select select; -+extern fdapi_send send; - extern fdapi_setsockopt setsockopt; - extern fdapi_socket socket; - extern fdapi_write write; -diff --git a/src/Win32_Interop/Win32_Interop.vcxproj b/src/Win32_Interop/Win32_Interop.vcxproj -index 93fc44b..b75d89b ---- a/src/Win32_Interop/Win32_Interop.vcxproj -+++ b/src/Win32_Interop/Win32_Interop.vcxproj -@@ -74,35 +74,6 @@ - - - -- -- -- Document -- md resources --mc.exe -A -b -c -h . -r resources EventLog.mc --rc.exe -foresources/EventLog.res resources/EventLog.rc --link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll -- -- md resources --mc.exe -A -b -c -h . -r resources EventLog.mc --rc.exe -foresources/EventLog.res resources/EventLog.rc --link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll -- -- EventLog.h -- EventLog.h -- md resources --mc.exe -A -b -c -h . -r resources EventLog.mc --rc.exe -foresources/EventLog.res resources/EventLog.rc --link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll -- -- md resources --mc.exe -A -b -c -h . 
-r resources EventLog.mc --rc.exe -foresources/EventLog.res resources/EventLog.rc --link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll -- -- EventLog.h -- EventLog.h -- -- - - {8C07F811-C81C-432C-B334-1AE6FAECF951} - Win32Proj -@@ -113,27 +84,25 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - - StaticLibrary - true -- v120 -+ v140_xp - Unicode - - - StaticLibrary - true -- v120 -+ v140_xp - Unicode - - - StaticLibrary - false -- v120 -- true -+ v140_xp - Unicode - - - StaticLibrary - false -- v120 -- true -+ v140_xp - Unicode - - -@@ -152,13 +121,9 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - - - -- -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -- -- -- $(SolutionDir)$(Platform)\$(Configuration)\ -- $(Platform)\$(Configuration)\ -+ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\ -+ $(SolutionDir)build\$(Platform)\$(Configuration)\$(MSBuildProjectName)\ - - - -@@ -166,9 +131,10 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - - Level3 - Disabled -- USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1 -+ _WIN32_WINNT=0x0502;USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;_NO_CRT_STDIO_INLINE;_CRT_SECURE_NO_DEPRECATE;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1 - $(ProjectDir)..\..\deps\lua\src;$(ProjectDir)..\..\deps\jemalloc-win\include -- MultiThreadedDebug -+ false -+ true - - - Windows -@@ -186,9 +152,10 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - - Level3 - Disabled -- USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1;_WIN32_WINNT=0x0501 -+ _WIN32_WINNT=0x0502;USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;_NO_CRT_STDIO_INLINE;_CRT_SECURE_NO_DEPRECATE;WIN32;_DEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1;_WIN32_WINNT=0x0501 - $(ProjectDir)..\..\deps\lua\src;$(ProjectDir)..\..\deps\jemalloc-win\include -- MultiThreadedDebug -+ false -+ true - - - Windows -@@ -211,10 +178,9 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - Full - true - true -- USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1 -+ _WIN32_WINNT=0x0502;USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;_NO_CRT_STDIO_INLINE;_CRT_SECURE_NO_DEPRECATE;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1 - $(ProjectDir)..\..\deps\lua\src;$(ProjectDir)..\..\deps\jemalloc-win\include -- MultiThreaded -- true -+ true - - - Windows -@@ -235,9 +201,9 @@ link.exe -dll -noentry resources/EventLog.res -out:$(TargetDir)EventLog.dll - MaxSpeed - true - true -- USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1;_WIN32_WINNT=0x0501 -+ _WIN32_WINNT=0x0502;USE_STATIC;USE_JEMALLOC;_OFF_T_DEFINED;_NO_CRT_STDIO_INLINE;_CRT_SECURE_NO_DEPRECATE;WIN32;NDEBUG;_LIB;%(PreprocessorDefinitions);LACKS_STDLIB_H;_CRT_SECURE_NO_WARNINGS;PSAPI_VERSION=1;_WIN32_WINNT=0x0501 - $(ProjectDir)..\..\deps\lua\src;$(ProjectDir)..\..\deps\jemalloc-win\include -- MultiThreaded -+ true - - - Windows -diff --git a/src/Win32_Interop/Win32_service.cpp 
b/src/Win32_Interop/Win32_service.cpp -index 488538e..1c33f53 ---- a/src/Win32_Interop/Win32_service.cpp -+++ b/src/Win32_Interop/Win32_service.cpp -@@ -59,7 +59,6 @@ this should precede the other arguments passed to redis. For instance: - #include - #include - #include --#include - #include - #include "Win32_EventLog.h" - #include -diff --git a/src/ziplist.c b/src/ziplist.c -index 24b0a7c..29d445d ---- a/src/ziplist.c -+++ b/src/ziplist.c -@@ -920,7 +920,7 @@ void ziplistRepr(unsigned char *zl) { - entry = zipEntry(p); - printf( - "{" -- "addr 0x%08lx, " /* TODO" verify 0x%08lx */ -+ "addr %p, " - "index %2d, " - "offset %5ld, " - "rl: %5u, " -@@ -929,9 +929,9 @@ void ziplistRepr(unsigned char *zl) { - "pls: %2u, " - "payload %5u" - "} ", -- (PORT_ULONG)p, -+ (void *)p, - index, -- (PORT_ULONG)(p-zl), -+ (long)(p-zl), - entry.headersize+entry.len, - entry.headersize, - entry.prevrawlen, diff --git a/src/global_scheduler/CMakeLists.txt b/src/global_scheduler/CMakeLists.txt deleted file mode 100644 index fec7ec2810d9..000000000000 --- a/src/global_scheduler/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -cmake_minimum_required(VERSION 3.4) - -project(global_scheduler) - -include_directories(${CMAKE_CURRENT_LIST_DIR}) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall") - -add_executable(global_scheduler global_scheduler.cc global_scheduler_algorithm.cc) - -# Make sure ${HIREDIS_LIB} is ready before linking. -add_dependencies(global_scheduler hiredis common) - -target_link_libraries(global_scheduler common ${HIREDIS_LIB} ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY} pthread) diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc deleted file mode 100644 index d964401ae720..000000000000 --- a/src/global_scheduler/global_scheduler.cc +++ /dev/null @@ -1,492 +0,0 @@ -#include -#include -#include - -#include "common.h" -#include "event_loop.h" -#include "global_scheduler.h" -#include "global_scheduler_algorithm.h" -#include "net.h" -#include "ray/util/util.h" -#include "state/db_client_table.h" -#include "state/local_scheduler_table.h" -#include "state/object_table.h" -#include "state/table.h" -#include "state/task_table.h" - -/** - * Retry the task assignment. If the local scheduler that the task is assigned - * to is no longer active, do not retry the assignment. - * TODO(rkn): We currently only retry the method if the global scheduler - * publishes a task to a local scheduler before the local scheduler has - * subscribed to the channel. If we enforce that ordering, we can remove this - * retry method. - * - * @param id The task ID. - * @param user_context The global scheduler state. - * @param user_data The Task that failed to be assigned. - * @return Void. - */ -void assign_task_to_local_scheduler_retry(UniqueID id, - void *user_context, - void *user_data) { - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - Task *task = (Task *) user_data; - RAY_CHECK(Task_state(task) == TaskStatus::SCHEDULED); - - // If the local scheduler has died since we requested the task assignment, do - // not retry again. - DBClientID local_scheduler_id = Task_local_scheduler(task); - auto it = state->local_schedulers.find(local_scheduler_id); - if (it == state->local_schedulers.end()) { - return; - } - - // The local scheduler is still alive. The failure is most likely due to the - // task assignment getting published before the local scheduler subscribed to - // the channel. Retry the assignment.
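// Note that the liveness check above also bounds the retries: if the local
// scheduler disappears between attempts, the fail_callback chain stops
// instead of looping forever; otherwise each retry re-issues the same task
// table update, as below.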
- auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = assign_task_to_local_scheduler_retry, - }; - task_table_update(state->db, Task_copy(task), &retryInfo, NULL, user_context); -} - -/** - * Assign the given task to the local scheduler, update Redis and scheduler data - * structures. - * - * @param state Global scheduler state. - * @param task Task to be assigned to the local scheduler. - * @param local_scheduler_id DB client ID for the local scheduler. - * @return Void. - */ -void assign_task_to_local_scheduler(GlobalSchedulerState *state, - Task *task, - DBClientID local_scheduler_id) { - TaskSpec *spec = Task_task_execution_spec(task)->Spec(); - RAY_LOG(DEBUG) << "assigning task to local_scheduler_id = " - << local_scheduler_id; - Task_set_state(task, TaskStatus::SCHEDULED); - Task_set_local_scheduler(task, local_scheduler_id); - RAY_LOG(DEBUG) << "Issuing a task table update for task = " - << Task_task_id(task); - - auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = assign_task_to_local_scheduler_retry, - }; - task_table_update(state->db, Task_copy(task), &retryInfo, NULL, state); - - /* Update the object table info to reflect the fact that the results of this - * task will be created on the machine that the task was assigned to. This can - * be used to improve locality-aware scheduling. */ - for (int64_t i = 0; i < TaskSpec_num_returns(spec); ++i) { - ObjectID return_id = TaskSpec_return(spec, i); - if (state->scheduler_object_info_table.find(return_id) == - state->scheduler_object_info_table.end()) { - SchedulerObjectInfo &obj_info_entry = - state->scheduler_object_info_table[return_id]; - /* The value -1 indicates that the size of the object is not known yet. */ - obj_info_entry.data_size = -1; - } - RAY_CHECK(state->local_scheduler_plasma_map.count(local_scheduler_id) == 1); - state->scheduler_object_info_table[return_id].object_locations.push_back( - state->local_scheduler_plasma_map[local_scheduler_id]); - } - - /* TODO(rkn): We should probably pass around local_scheduler struct pointers - * instead of db_client_id objects. */ - /* Update the local scheduler info. */ - auto it = state->local_schedulers.find(local_scheduler_id); - RAY_CHECK(it != state->local_schedulers.end()); - - LocalScheduler &local_scheduler = it->second; - local_scheduler.num_tasks_sent += 1; - local_scheduler.num_recent_tasks_sent += 1; - // Resource accounting update for this local scheduler. - for (auto const &resource_pair : TaskSpec_get_required_resources(spec)) { - std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - // The local scheduler must have this resource because otherwise we wouldn't - // be assigning the task to this local scheduler. - RAY_CHECK(local_scheduler.info.dynamic_resources.count(resource_name) == - 1 || - resource_quantity == 0); - // Subtract task's resource from the cached dynamic resource capacity for - // this local scheduler. This will be overwritten on the next heartbeat. 
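// Clamping at zero below keeps the cached capacity from going negative when
// several tasks are assigned between heartbeats; the authoritative values
// arrive with the next heartbeat and replace this cache.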
- local_scheduler.info.dynamic_resources[resource_name] = - MAX(0, local_scheduler.info.dynamic_resources[resource_name] - - resource_quantity); - } -} - -GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop, - const char *node_ip_address, - const char *redis_primary_addr, - int redis_primary_port) { - GlobalSchedulerState *state = new GlobalSchedulerState(); - state->loop = loop; - state->db = db_connect(std::string(redis_primary_addr), redis_primary_port, - "global_scheduler", node_ip_address, - std::vector()); - db_attach(state->db, loop, false); - state->policy_state = GlobalSchedulerPolicyState_init(); - return state; -} - -void GlobalSchedulerState_free(GlobalSchedulerState *state) { - db_disconnect(state->db); - state->local_schedulers.clear(); - GlobalSchedulerPolicyState_free(state->policy_state); - /* Delete the plasma to local scheduler association map. */ - state->plasma_local_scheduler_map.clear(); - - /* Delete the local scheduler to plasma association map. */ - state->local_scheduler_plasma_map.clear(); - - /* Free the scheduler object info table. */ - state->scheduler_object_info_table.clear(); - /* Free the array of unschedulable tasks. */ - int64_t num_pending_tasks = state->pending_tasks.size(); - if (num_pending_tasks > 0) { - RAY_LOG(WARNING) << "There are " << num_pending_tasks - << " remaining tasks in the pending tasks array."; - } - for (int i = 0; i < num_pending_tasks; ++i) { - Task *pending_task = state->pending_tasks[i]; - Task_free(pending_task); - } - state->pending_tasks.clear(); - - /* Destroy the event loop. */ - destroy_outstanding_callbacks(state->loop); - event_loop_destroy(state->loop); - state->loop = NULL; - - /* Free the global scheduler state. */ - delete state; -} - -/* We need this code so we can clean up when we get a SIGTERM signal. */ - -GlobalSchedulerState *g_state; - -void signal_handler(int signal) { - if (signal == SIGTERM) { - GlobalSchedulerState_free(g_state); - exit(0); - } -} - -/* End of the cleanup code. */ - -void process_task_waiting(Task *waiting_task, void *user_context) { - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - RAY_LOG(DEBUG) << "Task waiting callback is called."; - bool successfully_assigned = - handle_task_waiting(state, state->policy_state, waiting_task); - /* If the task was not successfully submitted to a local scheduler, add the - * task to the array of pending tasks. The global scheduler will periodically - * resubmit the tasks in this array. */ - if (!successfully_assigned) { - Task *task_copy = Task_copy(waiting_task); - state->pending_tasks.push_back(task_copy); - } -} - -void add_local_scheduler(GlobalSchedulerState *state, - DBClientID db_client_id, - const char *manager_address) { - /* Add plasma_manager ip:port -> local_scheduler_db_client_id association to - * state. */ - state->plasma_local_scheduler_map[std::string(manager_address)] = - db_client_id; - - /* Add local_scheduler_db_client_id -> plasma_manager ip:port association to - * state. */ - state->local_scheduler_plasma_map[db_client_id] = - std::string(manager_address); - - /* Add new local scheduler to the state. */ - LocalScheduler &local_scheduler = state->local_schedulers[db_client_id]; - local_scheduler.id = db_client_id; - local_scheduler.num_heartbeats_missed = 0; - local_scheduler.num_tasks_sent = 0; - local_scheduler.num_recent_tasks_sent = 0; - local_scheduler.info.task_queue_length = 0; - local_scheduler.info.available_workers = 0; - - /* Allow the scheduling algorithm to process this event. 
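 * (With the default policy this is currently a no-op; see
 * handle_new_local_scheduler in global_scheduler_algorithm.cc.)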
*/ - handle_new_local_scheduler(state, state->policy_state, db_client_id); -} - -std::unordered_map::iterator remove_local_scheduler( - GlobalSchedulerState *state, - std::unordered_map::iterator it) { - RAY_CHECK(it != state->local_schedulers.end()); - DBClientID local_scheduler_id = it->first; - it = state->local_schedulers.erase(it); - - /* Remove the local scheduler from the mappings. This code only makes sense if - * there is a one-to-one mapping between local schedulers and plasma managers. - */ - std::string manager_address = - state->local_scheduler_plasma_map[local_scheduler_id]; - state->local_scheduler_plasma_map.erase(local_scheduler_id); - state->plasma_local_scheduler_map.erase(manager_address); - - handle_local_scheduler_removed(state, state->policy_state, - local_scheduler_id); - return it; -} - -/** - * Process a notification about a new DB client connecting to Redis. - * - * @param db_client The DB client that was added or removed. - * @param user_context The global scheduler state. - * @return Void. - */ -void process_new_db_client(DBClient *db_client, void *user_context) { - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - RAY_LOG(DEBUG) << "db client table callback for db client = " - << db_client->id; - if (strncmp(db_client->client_type.c_str(), "local_scheduler", - strlen("local_scheduler")) == 0) { - bool local_scheduler_present = - (state->local_schedulers.find(db_client->id) != - state->local_schedulers.end()); - if (db_client->is_alive) { - /* This is a notification for an insert. We may receive duplicate - * notifications since we read the entire table before processing - * notifications. Filter out local schedulers that we already added. */ - if (!local_scheduler_present) { - add_local_scheduler(state, db_client->id, - db_client->manager_address.c_str()); - } - } else { - if (local_scheduler_present) { - remove_local_scheduler(state, - state->local_schedulers.find(db_client->id)); - } - } - } -} - -/** - * Process notification about the new object information. - * - * @param object_id ID of the object that the notification is about. - * @param data_size The object size. - * @param manager_ids The vector of Plasma Manager client IDs. - * @param user_context The user context. - * @return Void. - */ -void object_table_subscribe_callback(ObjectID object_id, - int64_t data_size, - const std::vector &manager_ids, - void *user_context) { - /* Extract global scheduler state from the callback context. */ - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - RAY_LOG(DEBUG) << "object table subscribe callback for OBJECT = " - << object_id; - - const std::vector managers = - db_client_table_get_ip_addresses(state->db, manager_ids); - RAY_LOG(DEBUG) << "\tManagers<" << managers.size() << ">:"; - for (size_t i = 0; i < managers.size(); i++) { - RAY_LOG(DEBUG) << "\t\t" << managers[i]; - } - - if (state->scheduler_object_info_table.find(object_id) == - state->scheduler_object_info_table.end()) { - /* Construct a new object info hash table entry.
*/ - SchedulerObjectInfo &obj_info_entry = - state->scheduler_object_info_table[object_id]; - obj_info_entry.data_size = data_size; - - RAY_LOG(DEBUG) << "New object added to object_info_table with id = " - << object_id; - RAY_LOG(DEBUG) << "\tmanager locations:"; - for (size_t i = 0; i < managers.size(); i++) { - RAY_LOG(DEBUG) << "\t\t" << managers[i]; - } - } - - SchedulerObjectInfo &obj_info_entry = - state->scheduler_object_info_table[object_id]; - - /* In all cases, replace the object location vector on each callback. */ - obj_info_entry.object_locations.clear(); - for (size_t i = 0; i < managers.size(); i++) { - obj_info_entry.object_locations.push_back(managers[i]); - } -} - -void local_scheduler_table_handler(DBClientID client_id, - LocalSchedulerInfo info, - void *user_context) { - /* Extract global scheduler state from the callback context. */ - GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; - ARROW_UNUSED(state); - RAY_LOG(DEBUG) << "Local scheduler heartbeat from db_client_id " << client_id; - RAY_LOG(DEBUG) << "total workers = " << info.total_num_workers - << ", task queue length = " << info.task_queue_length - << ", available workers = " << info.available_workers; - - /* Update the local scheduler info struct. */ - auto it = state->local_schedulers.find(client_id); - if (it != state->local_schedulers.end()) { - if (info.is_dead) { - /* The local scheduler is exiting. Increase the number of heartbeats - * missed to the timeout threshold. This will trigger removal of the - * local scheduler the next time the timeout handler fires. */ - it->second.num_heartbeats_missed = - RayConfig::instance().num_heartbeats_timeout(); - } else { - /* Reset the number of tasks sent since the last heartbeat. */ - LocalScheduler &local_scheduler = it->second; - local_scheduler.num_heartbeats_missed = 0; - local_scheduler.num_recent_tasks_sent = 0; - local_scheduler.info = info; - } - } else { - RAY_LOG(WARNING) << "client_id didn't match any cached local scheduler " - << "entries"; - } -} - -int task_cleanup_handler(event_loop *loop, timer_id id, void *context) { - GlobalSchedulerState *state = (GlobalSchedulerState *) context; - /* Loop over the pending tasks in reverse order and resubmit them. */ - auto it = state->pending_tasks.end(); - while (it != state->pending_tasks.begin()) { - it--; - Task *pending_task = *it; - /* Pretend that the task has been resubmitted. */ - bool successfully_assigned = - handle_task_waiting(state, state->policy_state, pending_task); - if (successfully_assigned) { - /* The task was successfully assigned, so remove it from this list and - * free it. This uses the fact that pending_tasks is a vector and so erase - * returns an iterator to the next element in the vector. */ - it = state->pending_tasks.erase(it); - Task_free(pending_task); - } - } - - return GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS; -} - -int heartbeat_timeout_handler(event_loop *loop, timer_id id, void *context) { - GlobalSchedulerState *state = (GlobalSchedulerState *) context; - /* Check for local schedulers that have missed a number of heartbeats. If any - * local schedulers have died, notify others so that the state can be cleaned - * up. */ - /* TODO(swang): If the local scheduler hasn't actually died, then it should - * clean up its state and exit upon receiving this notification. 
*/ - auto it = state->local_schedulers.begin(); - while (it != state->local_schedulers.end()) { - if (it->second.num_heartbeats_missed >= - RayConfig::instance().num_heartbeats_timeout()) { - RAY_LOG(WARNING) << "Missed too many heartbeats from local scheduler, " - << "marking as dead."; - /* Notify others by updating the global state. */ - db_client_table_remove(state->db, it->second.id, NULL, NULL, NULL); - /* Remove the scheduler from the local state. The call to - * remove_local_scheduler modifies the container in place and returns the - * next iterator. */ - it = remove_local_scheduler(state, it); - } else { - it->second.num_heartbeats_missed += 1; - it++; - } - } - - /* Reset the timer. */ - return RayConfig::instance().heartbeat_timeout_milliseconds(); -} - -void start_server(const char *node_ip_address, - const char *redis_primary_addr, - int redis_primary_port) { - event_loop *loop = event_loop_create(); - g_state = GlobalSchedulerState_init(loop, node_ip_address, redis_primary_addr, - redis_primary_port); - /* TODO(rkn): subscribe to notifications from the object table. */ - /* Subscribe to notifications about new local schedulers. TODO(rkn): this - * needs to also get all of the clients that registered with the database - * before this call to subscribe. */ - db_client_table_subscribe(g_state->db, process_new_db_client, - (void *) g_state, NULL, NULL, NULL); - /* Subscribe to notifications about waiting tasks. If a local scheduler - * submits tasks to the global scheduler before the global scheduler - * successfully subscribes, then the local scheduler that submitted the tasks - * will retry. */ - task_table_subscribe(g_state->db, UniqueID::nil(), TaskStatus::WAITING, - process_task_waiting, (void *) g_state, NULL, NULL, - NULL); - - object_table_subscribe_to_notifications(g_state->db, true, - object_table_subscribe_callback, - g_state, NULL, NULL, NULL); - /* Subscribe to notifications from local schedulers. These notifications serve - * as heartbeats and contain information about the load on the local - * schedulers. */ - local_scheduler_table_subscribe(g_state->db, local_scheduler_table_handler, - g_state, NULL); - /* Start a timer that periodically checks if there are queued tasks that can - * be scheduled. Currently this is only used to handle the special case in - * which a task is waiting and no node meets its static resource requirements. - * If a new node joins the cluster that does have enough resources, then this - * timer should notice and schedule the task. */ - event_loop_add_timer(loop, GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS, - task_cleanup_handler, g_state); - event_loop_add_timer(loop, - RayConfig::instance().heartbeat_timeout_milliseconds(), - heartbeat_timeout_handler, g_state); - /* Start the event loop. */ - event_loop_run(loop); -} - -int main(int argc, char *argv[]) { - InitShutdownRAII ray_log_shutdown_raii( - ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0], - ray::RayLogLevel::INFO, /*log_dir=*/""); - ray::RayLog::InstallFailureSignalHandler(); - signal(SIGTERM, signal_handler); - /* IP address and port of the primary redis instance. */ - char *redis_primary_addr_port = NULL; - /* The IP address of the node that this global scheduler is running on.
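 * It is supplied with the -h switch. A typical invocation (the addresses are
 * illustrative) is: global_scheduler -h 10.0.0.2 -r 127.0.0.1:6379.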
*/ - char *node_ip_address = NULL; - int c; - while ((c = getopt(argc, argv, "h:r:")) != -1) { - switch (c) { - case 'r': - redis_primary_addr_port = optarg; - break; - case 'h': - node_ip_address = optarg; - break; - default: - RAY_LOG(FATAL) << "unknown option " << c; - } - } - - char redis_primary_addr[16]; - int redis_primary_port = -1; - if (!redis_primary_addr_port || - parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr, - &redis_primary_port) == -1) { - RAY_LOG(FATAL) << "specify the primary redis address like 127.0.0.1:6379 " - << "with the -r switch"; - } - if (!node_ip_address) { - RAY_LOG(FATAL) << "specify the node IP address with the -h switch"; - } - start_server(node_ip_address, redis_primary_addr, redis_primary_port); -} diff --git a/src/global_scheduler/global_scheduler.h b/src/global_scheduler/global_scheduler.h deleted file mode 100644 index e1610c555088..000000000000 --- a/src/global_scheduler/global_scheduler.h +++ /dev/null @@ -1,94 +0,0 @@ -#ifndef GLOBAL_SCHEDULER_H -#define GLOBAL_SCHEDULER_H - -#include "task.h" - -#include - -#include "ray/gcs/client.h" -#include "state/db.h" -#include "state/local_scheduler_table.h" - -/* The frequency with which the global scheduler checks if there are any tasks - * that haven't been scheduled yet. */ -#define GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS 100 - -/** Contains all information that is associated with a local scheduler. */ -typedef struct { - /** The ID of the local scheduler in Redis. */ - DBClientID id; - /** The number of heartbeat intervals that have passed since we last heard - * from this local scheduler. */ - int64_t num_heartbeats_missed; - /** The number of tasks sent from the global scheduler to this local - * scheduler. */ - int64_t num_tasks_sent; - /** The number of tasks sent from the global scheduler to this local scheduler - * since the last heartbeat arrived. */ - int64_t num_recent_tasks_sent; - /** The latest information about the local scheduler capacity. This is updated - * every time a new local scheduler heartbeat arrives. */ - LocalSchedulerInfo info; -} LocalScheduler; - -typedef class GlobalSchedulerPolicyState GlobalSchedulerPolicyState; - -/** - * This defines a hash table used to cache information about different objects. - */ -typedef struct { - /** The size in bytes of the object. */ - int64_t data_size; - /** A vector of object locations for this object. */ - std::vector object_locations; -} SchedulerObjectInfo; - -/** - * Global scheduler state structure. - */ -typedef struct { - /** The global scheduler event loop. */ - event_loop *loop; - /** The global state store database. */ - DBHandle *db; - /** A hash table mapping local scheduler ID to the local schedulers that are - * connected to Redis. */ - std::unordered_map local_schedulers; - /** The state managed by the scheduling policy. */ - GlobalSchedulerPolicyState *policy_state; - /** The plasma_manager ip:port -> local_scheduler_db_client_id association. */ - std::unordered_map plasma_local_scheduler_map; - /** The local_scheduler_db_client_id -> plasma_manager ip:port association. */ - std::unordered_map local_scheduler_plasma_map; - /** Objects cached by this global scheduler instance. */ - std::unordered_map scheduler_object_info_table; - /** An array of tasks that haven't been scheduled yet. */ - std::vector pending_tasks; -} GlobalSchedulerState; - -/** - * This is a helper method to look up the local scheduler struct that - * corresponds to a particular local_scheduler_id. 
- * - * @param state The state of the global scheduler. - * @param local_scheduler_id The ID of the local scheduler. - * @return The corresponding local scheduler struct. If the global scheduler is - * not aware of the local scheduler, then this will be NULL. - */ -LocalScheduler *get_local_scheduler(GlobalSchedulerState *state, - DBClientID local_scheduler_id); - -/** - * Assign the given task to the local scheduler, update Redis and scheduler data - * structures. - * - * @param state Global scheduler state. - * @param task Task to be assigned to the local scheduler. - * @param local_scheduler_id DB client ID for the local scheduler. - * @return Void. - */ -void assign_task_to_local_scheduler(GlobalSchedulerState *state, - Task *task, - DBClientID local_scheduler_id); - -#endif /* GLOBAL_SCHEDULER_H */ diff --git a/src/global_scheduler/global_scheduler_algorithm.cc b/src/global_scheduler/global_scheduler_algorithm.cc deleted file mode 100644 index 7ca1b86be914..000000000000 --- a/src/global_scheduler/global_scheduler_algorithm.cc +++ /dev/null @@ -1,257 +0,0 @@ -#include - -#include "task.h" -#include "state/task_table.h" - -#include "global_scheduler_algorithm.h" - -GlobalSchedulerPolicyState *GlobalSchedulerPolicyState_init(void) { - GlobalSchedulerPolicyState *policy_state = new GlobalSchedulerPolicyState(); - return policy_state; -} - -void GlobalSchedulerPolicyState_free(GlobalSchedulerPolicyState *policy_state) { - delete policy_state; -} - -/** - * Checks if the given local scheduler satisfies the task's hard constraints. - * - * @param scheduler Local scheduler. - * @param spec Task specification. - * @return True if all of the task's resource constraints are satisfied. False - * otherwise. - */ -bool constraints_satisfied_hard(const LocalScheduler *scheduler, - const TaskSpec *spec) { - if (scheduler->info.static_resources.count("CPU") == 1 && - scheduler->info.static_resources.at("CPU") == 0) { - // Don't give tasks to local schedulers that have 0 CPUs. This can be an - // issue for actor creation tasks that require 0 CPUs (but the subsequent - // actor methods require some CPUs). - return false; - } - - for (auto const &resource_pair : TaskSpec_get_required_resources(spec)) { - std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - - // Continue on if the task doesn't actually require this resource. - if (resource_quantity == 0) { - continue; - } - - // Check if the local scheduler has this resource. - if (scheduler->info.static_resources.count(resource_name) == 0) { - return false; - } - - // Check if the local scheduler has enough of the resource. - if (scheduler->info.static_resources.at(resource_name) < - resource_quantity) { - return false; - } - } - return true; -} - -int64_t locally_available_data_size(const GlobalSchedulerState *state, - DBClientID local_scheduler_id, - TaskSpec *task_spec) { - /* This function will compute the total size of all the object dependencies - * for the given task that are already locally available to the specified - * local scheduler. */ - int64_t task_data_size = 0; - - RAY_CHECK(state->local_scheduler_plasma_map.count(local_scheduler_id) == 1); - - const std::string &plasma_manager = - state->local_scheduler_plasma_map.at(local_scheduler_id); - - /* TODO(rkn): Note that if the same object ID appears as multiple arguments, - * then it will be overcounted.
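 * (Deduplicating the argument object IDs, e.g. with a local
 * std::unordered_set, before summing would avoid the overcount.)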
*/ - for (int64_t i = 0; i < TaskSpec_num_args(task_spec); ++i) { - int count = TaskSpec_arg_id_count(task_spec, i); - for (int j = 0; j < count; ++j) { - ObjectID object_id = TaskSpec_arg_id(task_spec, i, j); - - if (state->scheduler_object_info_table.count(object_id) == 0) { - /* If this global scheduler is not aware of this object ID, then ignore - * it. */ - continue; - } - - const SchedulerObjectInfo &object_size_info = - state->scheduler_object_info_table.at(object_id); - - if (std::find(object_size_info.object_locations.begin(), - object_size_info.object_locations.end(), plasma_manager) == - object_size_info.object_locations.end()) { - /* This local scheduler does not have access to this object, so don't - * count this object. */ - continue; - } - - /* Look at the size of the object. */ - int64_t object_size = object_size_info.data_size; - if (object_size == -1) { - /* This means that this global scheduler does not know the object size - * yet, so assume that the object is one megabyte. TODO(rkn): Maybe we - * should instead use the average object size. */ - object_size = 1000000; - } - - /* If we get here, then this local scheduler has access to this object, so - * count the contribution of this object. */ - task_data_size += object_size; - } - } - - return task_data_size; -} - -double calculate_cost_pending(const GlobalSchedulerState *state, - const LocalScheduler *scheduler, - TaskSpec *task_spec) { - /* Calculate how much data is already present on this machine. TODO(rkn): Note - * that this information is not being used yet. Fix this. */ - locally_available_data_size(state, scheduler->id, task_spec); - /* TODO(rkn): This logic does not load balance properly when the different - * machines have different sizes. Fix this. */ - double cost_pending = scheduler->num_recent_tasks_sent + - scheduler->info.task_queue_length - - scheduler->info.available_workers; - return cost_pending; -} - -bool handle_task_waiting_random(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - Task *task) { - TaskSpec *task_spec = Task_task_execution_spec(task)->Spec(); - RAY_CHECK(task_spec != NULL) - << "task wait handler encountered a task with NULL spec"; - - std::vector feasible_nodes; - - for (const auto &it : state->local_schedulers) { - // Local scheduler map iterator yields pairs. - const LocalScheduler &local_scheduler = it.second; - if (!constraints_satisfied_hard(&local_scheduler, task_spec)) { - continue; - } - // Add this local scheduler as a candidate for random selection. - feasible_nodes.push_back(it.first); - } - - if (feasible_nodes.size() == 0) { - RAY_LOG(ERROR) << "Infeasible task. No nodes satisfy hard constraints for " - << "task = " << Task_task_id(task); - return false; - } - - // Randomly select the local scheduler. TODO(atumanov): replace with - // std::discrete_distribution. - std::uniform_int_distribution<> dis(0, feasible_nodes.size() - 1); - DBClientID local_scheduler_id = - feasible_nodes[dis(policy_state->getRandomGenerator())]; - RAY_CHECK(!local_scheduler_id.is_nil()) - << "Task is feasible, but doesn't have a local scheduler assigned."; - // A local scheduler ID was found, so assign the task.
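// A weighted variant of the selection above (a sketch only, not part of the
// original code) could use the std::discrete_distribution mentioned in the
// TODO, with hypothetical per-node weights such as available workers:
//   std::vector<double> weights = ...;  // one entry per feasible node
//   std::discrete_distribution<size_t> pick(weights.begin(), weights.end());
//   local_scheduler_id = feasible_nodes[pick(policy_state->getRandomGenerator())];
// Uniform random selection is the special case of equal weights.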
- assign_task_to_local_scheduler(state, task, local_scheduler_id); - return true; -} - -bool handle_task_waiting_cost(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - Task *task) { - TaskSpec *task_spec = Task_task_execution_spec(task)->Spec(); - int64_t curtime = current_time_ms(); - - RAY_CHECK(task_spec != NULL) - << "task wait handler encountered a task with NULL spec"; - - // For tasks already seen by the global scheduler (spillback > 1), - // adjust scheduled task counts for the source local scheduler. - if (task->execution_spec->SpillbackCount() > 1) { - auto it = state->local_schedulers.find(task->local_scheduler_id); - // Task's previous local scheduler must be present and known. - RAY_CHECK(it != state->local_schedulers.end()); - LocalScheduler &src_local_scheduler = it->second; - src_local_scheduler.num_recent_tasks_sent -= 1; - } - - bool task_feasible = false; - - // Go through all the nodes, calculate the score for each, pick max score. - double best_local_scheduler_score = INT32_MIN; - RAY_CHECK(best_local_scheduler_score < 0) - << "We might have a floating point underflow"; - RAY_LOG(INFO) << "ct[" << curtime << "] task from " - << task->local_scheduler_id << " spillback " - << task->execution_spec->SpillbackCount(); - - // The best node to send this task. - DBClientID best_local_scheduler_id = DBClientID::nil(); - - for (auto it = state->local_schedulers.begin(); - it != state->local_schedulers.end(); it++) { - // For each local scheduler, calculate its score. Check hard constraints - // first. - LocalScheduler *scheduler = &(it->second); - if (!constraints_satisfied_hard(scheduler, task_spec)) { - continue; - } - // Skip the local scheduler the task came from. - if (task->local_scheduler_id == scheduler->id) { - continue; - } - task_feasible = true; - // This node satisfies the hard capacity constraint. Calculate its score. - double score = -1 * calculate_cost_pending(state, scheduler, task_spec); - RAY_LOG(INFO) << "ct[" << curtime << "][" << scheduler->id << "][q" - << scheduler->info.task_queue_length << "][w" - << scheduler->info.available_workers << "]: score " << score - << " bestscore " << best_local_scheduler_score; - if (score >= best_local_scheduler_score) { - best_local_scheduler_score = score; - best_local_scheduler_id = scheduler->id; - } - } - - if (!task_feasible) { - RAY_LOG(ERROR) << "Infeasible task. No nodes satisfy hard constraints for " - << "task = " << Task_task_id(task); - // TODO(atumanov): propagate this error to the task's driver and/or - // cache the task in case new local schedulers satisfy it in the future. - return false; - } - RAY_CHECK(!best_local_scheduler_id.is_nil()) - << "Task is feasible, but doesn't have a local scheduler assigned."; - // A local scheduler ID was found, so assign the task.
*/ -} - -void handle_local_scheduler_removed(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - DBClientID db_client_id) { - /* Do nothing for now. */ -} diff --git a/src/global_scheduler/global_scheduler_algorithm.h b/src/global_scheduler/global_scheduler_algorithm.h deleted file mode 100644 index 69be67d97477..000000000000 --- a/src/global_scheduler/global_scheduler_algorithm.h +++ /dev/null @@ -1,126 +0,0 @@ -#ifndef GLOBAL_SCHEDULER_ALGORITHM_H -#define GLOBAL_SCHEDULER_ALGORITHM_H - -#include -#include - -#include "common.h" -#include "global_scheduler.h" -#include "task.h" - -/* ==== The scheduling algorithm ==== - * - * This file contains declarations for all functions and data structures that - * need to be provided if you want to implement a new algorithm for the global - * scheduler. - * - */ - -enum class GlobalSchedulerAlgorithm { - SCHED_ALGORITHM_ROUND_ROBIN = 1, - SCHED_ALGORITHM_TRANSFER_AWARE = 2, - SCHED_ALGORITHM_MAX -}; - -/// The class encapsulating state managed by the global scheduling policy. -class GlobalSchedulerPolicyState { - public: - GlobalSchedulerPolicyState(int64_t round_robin_index) - : round_robin_index_(round_robin_index), - gen_(std::chrono::high_resolution_clock::now() - .time_since_epoch() - .count()) {} - - GlobalSchedulerPolicyState() - : round_robin_index_(0), - gen_(std::chrono::high_resolution_clock::now() - .time_since_epoch() - .count()) {} - - /// Return the policy's random number generator. - /// - /// @return The policy's random number generator. - std::mt19937_64 &getRandomGenerator() { return gen_; } - - /// Return the round robin index maintained by policy state. - /// - /// @return The round robin index. - int64_t getRoundRobinIndex() const { return round_robin_index_; } - - private: - /// The index of the next local scheduler to assign a task to. - int64_t round_robin_index_; - /// Internally maintained random number generator. - std::mt19937_64 gen_; -}; - -/** - * Create the state of the global scheduler policy. This state must be freed by - * the caller. - * - * @return The state of the scheduling policy. - */ -GlobalSchedulerPolicyState *GlobalSchedulerPolicyState_init(void); - -/** - * Free the global scheduler policy state. - * - * @param policy_state The policy state to free. - * @return Void. - */ -void GlobalSchedulerPolicyState_free(GlobalSchedulerPolicyState *policy_state); - -/** - * Main new task handling function in the global scheduler. - * - * @param state Global scheduler state. - * @param policy_state State specific to the scheduling policy. - * @param task New task to be scheduled. - * @return True if the task was assigned to a local scheduler and false - * otherwise. - */ -bool handle_task_waiting(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - Task *task); - -/** - * Handle the fact that a new object is available. - * - * @param state The global scheduler state. - * @param policy_state The state managed by the scheduling policy. - * @param object_id The ID of the object that is now available. - * @return Void. - */ -void handle_object_available(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - ObjectID object_id); - -/** - * Handle a heartbeat message from a local scheduler. TODO(rkn): this is a - * placeholder for now. - * - * @param state The global scheduler state. - * @param policy_state The state managed by the scheduling policy. - * @return Void.
- */ -void handle_local_scheduler_heartbeat(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state); - -/** - * Handle the presence of a new local scheduler. Currently, this just adds the - * local scheduler to a queue of local schedulers. - * - * @param state The global scheduler state. - * @param policy_state The state managed by the scheduling policy. - * @param db_client_id The DB client ID of the new local scheduler. - * @return Void. - */ -void handle_new_local_scheduler(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - DBClientID db_client_id); - -void handle_local_scheduler_removed(GlobalSchedulerState *state, - GlobalSchedulerPolicyState *policy_state, - DBClientID db_client_id); - -#endif /* GLOBAL_SCHEDULER_ALGORITHM_H */ diff --git a/src/local_scheduler/CMakeLists.txt b/src/local_scheduler/CMakeLists.txt deleted file mode 100644 index 7033c4f2306c..000000000000 --- a/src/local_scheduler/CMakeLists.txt +++ /dev/null @@ -1,104 +0,0 @@ -cmake_minimum_required(VERSION 3.4) - -project(local_scheduler) - -add_definitions(-fPIC) - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - include_directories("${PYTHON_INCLUDE_DIRS}") - include_directories("${NUMPY_INCLUDE_DIR}") -endif() - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall") - -if(UNIX AND NOT APPLE) - link_libraries(rt) -endif() - -include_directories("${CMAKE_CURRENT_LIST_DIR}/") -include_directories("${CMAKE_CURRENT_LIST_DIR}/../") -# TODO(pcm): get rid of this: -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - include_directories("${CMAKE_CURRENT_LIST_DIR}/../plasma/") -endif() - -include_directories("${ARROW_INCLUDE_DIR}") -include_directories("${CMAKE_CURRENT_LIST_DIR}/../common/format/") - -# Compile flatbuffers - -set(LOCAL_SCHEDULER_FBS_SRC "${CMAKE_CURRENT_LIST_DIR}/format/local_scheduler.fbs") -set(OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/format/) - -set(LOCAL_SCHEDULER_FBS_OUTPUT_FILES - "${OUTPUT_DIR}/local_scheduler_generated.h") - -add_custom_command( - OUTPUT ${LOCAL_SCHEDULER_FBS_OUTPUT_FILES} - COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} ${LOCAL_SCHEDULER_FBS_SRC} --gen-object-api --scoped-enums - DEPENDS ${FBS_DEPENDS} - COMMENT "Running flatc compiler on ${LOCAL_SCHEDULER_FBS_SRC}" - VERBATIM) - -add_custom_target(gen_local_scheduler_fbs DEPENDS ${LOCAL_SCHEDULER_FBS_OUTPUT_FILES}) - -add_dependencies(gen_local_scheduler_fbs arrow) - -add_library(local_scheduler_client STATIC local_scheduler_client.cc) - -# local_scheduler_shared.h includes ray/gcs/client.h which requires gen_gcs_fbs & gen_node_manager_fbs.
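# Listing the codegen targets as dependencies ensures the generated headers
# exist before any translation unit that includes them is compiled.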
-add_dependencies(local_scheduler_client common hiredis gen_local_scheduler_fbs ${COMMON_FBS_OUTPUT_FILES} gen_gcs_fbs gen_node_manager_fbs) - -add_executable(local_scheduler local_scheduler.cc local_scheduler_algorithm.cc) -add_dependencies(local_scheduler hiredis) -target_link_libraries(local_scheduler local_scheduler_client common ${HIREDIS_LIB} ${PLASMA_STATIC_LIB} ray_static ${ARROW_STATIC_LIB} -lpthread ${Boost_SYSTEM_LIBRARY}) - -add_executable(local_scheduler_tests test/local_scheduler_tests.cc local_scheduler.cc local_scheduler_algorithm.cc) -add_dependencies(local_scheduler_tests hiredis) -target_link_libraries(local_scheduler_tests local_scheduler_client common ${HIREDIS_LIB} ${PLASMA_STATIC_LIB} ray_static ${ARROW_STATIC_LIB} -lpthread ${Boost_SYSTEM_LIBRARY}) -target_compile_options(local_scheduler_tests PUBLIC "-DLOCAL_SCHEDULER_TEST") - -macro(get_local_scheduler_library LANG VAR) - set(${VAR} "local_scheduler_library_${LANG}") -endmacro() - -macro(set_local_scheduler_library LANG) - get_local_scheduler_library(${LANG} LOCAL_SCHEDULER_LIBRARY_${LANG}) - set(LOCAL_SCHEDULER_LIBRARY_LANG ${LOCAL_SCHEDULER_LIBRARY_${LANG}}) - include_directories("${CMAKE_CURRENT_LIST_DIR}/../common/lib/${LANG}/") - - file(GLOB LOCAL_SCHEDULER_LIBRARY_${LANG}_SRC - lib/${LANG}/*.cc - ${CMAKE_CURRENT_LIST_DIR}/../common/lib/${LANG}/*.cc) - add_library(${LOCAL_SCHEDULER_LIBRARY_LANG} SHARED - ${LOCAL_SCHEDULER_LIBRARY_${LANG}_SRC}) - - if(APPLE) - if ("${LANG}" STREQUAL "python") - SET_TARGET_PROPERTIES(${LOCAL_SCHEDULER_LIBRARY_LANG} PROPERTIES SUFFIX .so) - endif() - target_link_libraries(${LOCAL_SCHEDULER_LIBRARY_LANG} "-undefined dynamic_lookup" local_scheduler_client common ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY}) - else(APPLE) - target_link_libraries(${LOCAL_SCHEDULER_LIBRARY_LANG} local_scheduler_client common ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY}) - endif(APPLE) - - add_dependencies(${LOCAL_SCHEDULER_LIBRARY_LANG} gen_local_scheduler_fbs) - - install(TARGETS ${LOCAL_SCHEDULER_LIBRARY_LANG} DESTINATION ${CMAKE_SOURCE_DIR}/local_scheduler) -endmacro() - -if ("${CMAKE_RAY_LANG_PYTHON}" STREQUAL "YES") - set_local_scheduler_library("python") -endif() - -if ("${CMAKE_RAY_LANG_JAVA}" STREQUAL "YES") - add_compile_options("-I$ENV{JAVA_HOME}/include/") - if(WIN32) - add_compile_options("-I$ENV{JAVA_HOME}/include/win32") - elseif(APPLE) - add_compile_options("-I$ENV{JAVA_HOME}/include/darwin") - else() # linux - add_compile_options("-I$ENV{JAVA_HOME}/include/linux") - endif() - set_local_scheduler_library("java") -endif() diff --git a/src/local_scheduler/build/.gitkeep b/src/local_scheduler/build/.gitkeep deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/local_scheduler/format/local_scheduler.fbs b/src/local_scheduler/format/local_scheduler.fbs deleted file mode 100644 index a23bb28f05f3..000000000000 --- a/src/local_scheduler/format/local_scheduler.fbs +++ /dev/null @@ -1,130 +0,0 @@ -// Local scheduler protocol specification -namespace ray.local_scheduler.protocol; - -enum MessageType:int { - // Task is submitted to the local scheduler. This is sent from a worker to a - // local scheduler. - SubmitTask = 1, - // Notify the local scheduler that a task has finished. This is sent from a - // worker to a local scheduler. - TaskDone, - // Log a message to the event table. This is sent from a worker to a local - // scheduler. 
- EventLogMessage, - // Send an initial connection message to the local scheduler. This is sent - // from a worker or driver to a local scheduler. - RegisterClientRequest, - // Send a reply confirming the successful registration of a worker or driver. - // This is sent from the local scheduler to a worker or driver. - RegisterClientReply, - // Notify the local scheduler that this client disconnected unexpectedly. - // This is sent from a worker to a local scheduler. - DisconnectClient, - // Notify the local scheduler that this client is disconnecting gracefully. - // This is sent from a worker to a local scheduler. - IntentionalDisconnectClient, - // Get a new task from the local scheduler. This is sent from a worker to a - // local scheduler. - GetTask, - // Tell a worker to execute a task. This is sent from a local scheduler to a - // worker. - ExecuteTask, - // Reconstruct or fetch possibly lost objects. This is sent from a worker to - // a local scheduler. - ReconstructObjects, - // For a worker that was blocked on some object(s), tell the local scheduler - // that the worker is now unblocked. This is sent from a worker to a local - // scheduler. - NotifyUnblocked, - // Add a result table entry for an object put. - PutObject, - // A request to get the task frontier for an actor, called by the actor when - // saving a checkpoint. - GetActorFrontierRequest, - // The ActorFrontier response to a GetActorFrontierRequest. The local - // scheduler returns the actor's per-handle task counts and execution - // dependencies, which can later be used as the argument to SetActorFrontier - // when resuming from the checkpoint. - GetActorFrontierReply, - // A request to set the task frontier for an actor, called when resuming from - // a checkpoint. The local scheduler will update the actor's per-handle task - // counts and execution dependencies, discard any tasks that already executed - // before the checkpoint, and make any tasks on the frontier runnable by - // making their execution dependencies available. - SetActorFrontier -} - -table SubmitTaskRequest { - execution_dependencies: [string]; - task_spec: string; -} - -// This message is sent from the local scheduler to a worker. -table GetTaskReply { - // A string of bytes representing the task specification. - task_spec: string; - // The IDs of the GPUs that the worker is allowed to use for this task. - gpu_ids: [int]; -} - -table EventLogMessage { - key: string; - value: string; - timestamp: double; -} - -// This struct is used to register a new worker with the local scheduler. -// It is shipped as part of local_scheduler_connect. -table RegisterClientRequest { - // True if the client is a worker and false if the client is a driver. - is_worker: bool; - // The ID of the worker or driver. - client_id: string; - // The process ID of this worker. - worker_pid: long; - // The driver ID. This is non-nil if the client is a driver. - driver_id: string; -} - -table DisconnectClient { -} - -table ReconstructObjects { - // List of object IDs of the objects that we want to reconstruct or fetch. - object_ids: [string]; - // Do we only want to fetch the objects or also reconstruct them? - fetch_only: bool; -} - -table PutObject { - // Task ID of the task that performed the put. - task_id: string; - // Object ID of the object that is being put. - object_id: string; -} - -// The ActorFrontier is used to represent the current frontier of tasks that -// the local scheduler has marked as runnable for a particular actor. 
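To make the ``RegisterClientRequest`` table above concrete, here is a hedged sketch of how a worker could serialize its registration on connect. ``write_message`` is this codebase's framing helper (from ``io.h``); the actual call site lives in ``local_scheduler_client.cc``, which is not shown in this hunk, so treat the wiring as an assumption:

.. code-block:: cpp

    #include <string>
    #include <sys/types.h>

    #include "flatbuffers/flatbuffers.h"
    #include "format/local_scheduler_generated.h"
    #include "io.h"  // declares write_message in this tree

    namespace protocol = ray::local_scheduler::protocol;

    // Fields follow the table declaration order above: is_worker, client_id,
    // worker_pid, driver_id (left empty here for a worker).
    int send_register_request(int sock, const std::string &client_id, pid_t pid) {
      flatbuffers::FlatBufferBuilder fbb;
      auto message = protocol::CreateRegisterClientRequest(
          fbb, /*is_worker=*/true, fbb.CreateString(client_id),
          /*worker_pid=*/pid, /*driver_id=*/fbb.CreateString(""));
      fbb.Finish(message);
      return write_message(
          sock,
          static_cast<int64_t>(protocol::MessageType::RegisterClientRequest),
          fbb.GetSize(), fbb.GetBufferPointer());
    }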
It is
-// used to save the point in an actor's lifetime at which a checkpoint was
-// taken, so that the same frontier of tasks can be made runnable again if the
-// actor is resumed from that checkpoint.
-table ActorFrontier {
-  // Actor ID of the actor whose frontier is described.
-  actor_id: string;
-  // A list of handle IDs, representing the callers of the actor that have
-  // submitted a runnable task to the local scheduler. A nil ID represents the
-  // creator of the actor.
-  handle_ids: [string];
-  // A list representing the number of tasks executed so far, per handle. Each
-  // count in task_counters corresponds to the handle at the same index in
-  // handle_ids.
-  task_counters: [long];
-  // A list representing the execution dependency for the next runnable task,
-  // per handle. Each execution dependency in frontier_dependencies corresponds
-  // to the handle at the same index in handle_ids.
-  frontier_dependencies: [string];
-}
-
-table GetActorFrontierRequest {
-  actor_id: string;
-}
diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc
deleted file mode 100644
index d2c50c3fbb1d..000000000000
--- a/src/local_scheduler/local_scheduler.cc
+++ /dev/null
@@ -1,1555 +0,0 @@
-#include
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-#include "common.h"
-#include "common_protocol.h"
-#include "event_loop.h"
-#include "format/local_scheduler_generated.h"
-#include "io.h"
-#include "local_scheduler.h"
-#include "local_scheduler_algorithm.h"
-#include "local_scheduler_shared.h"
-#include "logging.h"
-#include "net.h"
-#include "ray/util/util.h"
-#include "state/actor_notification_table.h"
-#include "state/db.h"
-#include "state/db_client_table.h"
-#include "state/driver_table.h"
-#include "state/error_table.h"
-#include "state/object_table.h"
-#include "state/task_table.h"
-
-using MessageType = ray::local_scheduler::protocol::MessageType;
-
-/**
- * A helper function for printing available and requested resource information.
- *
- * @param state Local scheduler state.
- * @param spec Task specification object.
- * @return Void.
- */
-void print_resource_info(const LocalSchedulerState *state,
-                         const TaskSpec *spec) {
-#if RAY_COMMON_LOG_LEVEL <= RAY_COMMON_DEBUG
-  // Print information about available and requested resources.
-  std::cout << "Static Resources: " << std::endl;
-  for (auto const &resource_pair : state->static_resources) {
-    std::cout << " " << resource_pair.first << ": " << resource_pair.second
-              << std::endl;
-  }
-  std::cout << "Dynamic Resources: " << std::endl;
-  for (auto const &resource_pair : state->dynamic_resources) {
-    std::cout << " " << resource_pair.first << ": " << resource_pair.second
-              << std::endl;
-  }
-  if (spec) {
-    std::cout << "Task Required Resources: " << std::endl;
-    for (auto const &resource_pair : TaskSpec_get_required_resources(spec)) {
-      std::cout << " " << resource_pair.first << ": " << resource_pair.second
-                << std::endl;
-    }
-  }
-#endif
-}
-
-int force_kill_worker(event_loop *loop, timer_id id, void *context) {
-  LocalSchedulerClient *worker = (LocalSchedulerClient *) context;
-  kill(worker->pid, SIGKILL);
-  close(worker->sock);
-  delete worker;
-  return EVENT_LOOP_TIMER_DONE;
-}
-
-void kill_worker(LocalSchedulerState *state,
-                 LocalSchedulerClient *worker,
-                 bool cleanup,
-                 bool suppress_warning) {
-  /* Erase the local scheduler's reference to the worker.
*/ - auto it = std::find(state->workers.begin(), state->workers.end(), worker); - RAY_CHECK(it != state->workers.end()); - state->workers.erase(it); - - /* Make sure that we removed the worker. */ - it = std::find(state->workers.begin(), state->workers.end(), worker); - RAY_CHECK(it == state->workers.end()); - - /* Release any resources held by the worker. It's important to do this before - * calling handle_worker_removed and handle_actor_worker_disconnect because - * freeing up resources here will allow the scheduling algorithm to dispatch - * more tasks. */ - release_resources(state, worker, worker->resources_in_use); - - /* Erase the algorithm state's reference to the worker. */ - if (worker->actor_id.is_nil()) { - handle_worker_removed(state, state->algorithm_state, worker); - } else { - /* Let the scheduling algorithm process the absence of this worker. */ - handle_actor_worker_disconnect(state, state->algorithm_state, worker, - cleanup); - } - - /* Remove the client socket from the event loop so that we don't process the - * SIGPIPE when the worker is killed. */ - event_loop_remove_file(state->loop, worker->sock); - - /* If the worker has registered a process ID with us and it's a child - * process, use it to send a kill signal. */ - bool free_worker = true; - if (worker->is_child && worker->pid != 0) { - /* If worker is a driver, we should not enter this condition because - * worker->pid should be 0. */ - if (cleanup) { - /* If we're exiting the local scheduler anyway, it's okay to force kill - * the worker immediately. Wait for the process to exit. */ - kill(worker->pid, SIGKILL); - waitpid(worker->pid, NULL, 0); - close(worker->sock); - } else { - /* If we're just cleaning up a single worker, allow it some time to clean - * up its state before force killing. The client socket will be closed - * and the worker struct will be freed after the timeout. */ - kill(worker->pid, SIGTERM); - event_loop_add_timer( - state->loop, RayConfig::instance().kill_worker_timeout_milliseconds(), - force_kill_worker, (void *) worker); - free_worker = false; - } - RAY_LOG(DEBUG) << "Killed worker with pid " << worker->pid; - } - - /* If this worker is still running a task and we aren't cleaning up, push an - * error message to the driver responsible for the task. */ - if (worker->task_in_progress != NULL && !cleanup && !suppress_warning) { - TaskSpec *spec = Task_task_execution_spec(worker->task_in_progress)->Spec(); - - std::ostringstream error_message; - error_message << "The worker with ID " << worker->client_id << " died or " - << "was killed while executing the task with ID " - << TaskSpec_task_id(spec); - push_error(state->db, TaskSpec_driver_id(spec), ErrorIndex::WORKER_DIED, - error_message.str()); - } - - /* Clean up the task in progress. */ - if (worker->task_in_progress) { - /* Update the task table to reflect that the task failed to complete. */ - if (state->db != NULL) { - Task_set_state(worker->task_in_progress, TaskStatus::LOST); - task_table_update(state->db, worker->task_in_progress, NULL, NULL, NULL); - } else { - Task_free(worker->task_in_progress); - } - } - - RAY_LOG(DEBUG) << "Killed worker with pid " << worker->pid; - if (free_worker) { - /* Clean up the client socket after killing the worker so that the worker - * can't receive the SIGPIPE before exiting. */ - close(worker->sock); - delete worker; - } -} - -void LocalSchedulerState_free(LocalSchedulerState *state) { - /* Reset the SIGTERM handler to default behavior, so we try to clean up the - * local scheduler at most once. 
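``kill_worker`` above uses a two-phase shutdown: SIGTERM first, then a SIGKILL scheduled on the event loop after ``kill_worker_timeout_milliseconds``. A self-contained sketch of the same idiom, with a blocking poll standing in for the event-loop timer (the real code never blocks like this):

.. code-block:: cpp

    #include <signal.h>
    #include <sys/types.h>
    #include <sys/wait.h>
    #include <unistd.h>

    // Ask a child to exit, then force-kill it if it is still alive after a
    // grace period; finally reap it so it does not become a zombie.
    void terminate_child(pid_t pid, unsigned grace_seconds) {
      kill(pid, SIGTERM);                    // polite request
      for (unsigned i = 0; i < grace_seconds * 10; ++i) {
        int status;
        if (waitpid(pid, &status, WNOHANG) == pid) {
          return;                            // exited on its own
        }
        usleep(100 * 1000);                  // 100 ms poll interval
      }
      kill(pid, SIGKILL);                    // grace period expired
      waitpid(pid, nullptr, 0);              // reap the child
    }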
If a SIGTERM is caught afterwards, there is - * the possibility of orphan worker processes. */ - signal(SIGTERM, SIG_DFL); - /* Send a null heartbeat that tells the global scheduler that we are dead to - * avoid waiting for the heartbeat timeout. */ - if (state->db != NULL) { - local_scheduler_table_disconnect(state->db); - } - - /* Kill any child processes that didn't register as a worker yet. */ - for (auto const &worker_pid : state->child_pids) { - kill(worker_pid, SIGKILL); - waitpid(worker_pid, NULL, 0); - RAY_LOG(INFO) << "Killed worker pid " << worker_pid - << " which hadn't started yet."; - } - - /* Kill any registered workers. */ - /* TODO(swang): It's possible that the local scheduler will exit before all - * of its task table updates make it to redis. */ - while (state->workers.size() > 0) { - /* Note that kill_worker modifies the container state->workers, so it is - * important to do this loop in a way that does not use invalidated - * iterators. */ - kill_worker(state, state->workers.back(), true, false); - } - - /* Disconnect from plasma. */ - ARROW_CHECK_OK(state->plasma_conn->Disconnect()); - delete state->plasma_conn; - state->plasma_conn = NULL; - - /* Clean up the database connection. NOTE(swang): The global scheduler is - * responsible for deleting our entry from the db_client table, so do not - * delete it here. */ - if (state->db != NULL) { - DBHandle_free(state->db); - } - - /* Free the command for starting new workers. */ - if (state->config.start_worker_command != NULL) { - int i = 0; - const char *arg = state->config.start_worker_command[i]; - while (arg != NULL) { - free((void *) arg); - ++i; - arg = state->config.start_worker_command[i]; - } - free(state->config.start_worker_command); - state->config.start_worker_command = NULL; - } - - /* Free the algorithm state. */ - SchedulingAlgorithmState_free(state->algorithm_state); - state->algorithm_state = NULL; - - event_loop *loop = state->loop; - - /* Free the scheduler state. */ - delete state; - - /* Destroy the event loop. */ - destroy_outstanding_callbacks(loop); - event_loop_destroy(loop); -} - -void start_worker(LocalSchedulerState *state) { - /* We can't start a worker if we don't have the path to the worker script. */ - if (state->config.start_worker_command == NULL) { - RAY_LOG(DEBUG) << "No valid command to start worker provided. Cannot start " - << "worker."; - return; - } - /* Launch the process to create the worker. */ - pid_t pid = fork(); - if (pid != 0) { - state->child_pids.push_back(pid); - RAY_LOG(DEBUG) << "Started worker with pid " << pid; - return; - } - - /* Reset the SIGCHLD handler so that it doesn't influence the worker. */ - signal(SIGCHLD, SIG_DFL); - - std::vector command_vector; - for (int i = 0; state->config.start_worker_command[i] != NULL; i++) { - command_vector.push_back(state->config.start_worker_command[i]); - } - - /* Add a NULL pointer to the end. */ - command_vector.push_back(NULL); - - /* Try to execute the worker command. Exit if we're not successful. */ - execvp(command_vector[0], (char *const *) command_vector.data()); - - LocalSchedulerState_free(state); - RAY_LOG(FATAL) << "Failed to start worker"; -} - -/** - * Parse the command to start a worker. This takes in the command string, - * splits it into tokens on the space characters, and allocates an array of the - * tokens, terminated by a NULL pointer. - * - * @param command The command string to start a worker. - * @return A pointer to an array of strings, the tokens in the command string. 
- * The last element is a NULL pointer. - */ -const char **parse_command(const char *command) { - /* Count the number of tokens. */ - char *command_copy = strdup(command); - const char *delimiter = " "; - char *token = NULL; - int num_args = 0; - token = strtok(command_copy, delimiter); - while (token != NULL) { - ++num_args; - token = strtok(NULL, delimiter); - } - free(command_copy); - - /* Allocate a NULL-terminated array for the tokens. */ - const char **command_args = - (const char **) malloc((num_args + 1) * sizeof(const char *)); - command_args[num_args] = NULL; - - /* Fill in the token array. */ - command_copy = strdup(command); - token = strtok(command_copy, delimiter); - int i = 0; - while (token != NULL) { - command_args[i] = strdup(token); - ++i; - token = strtok(NULL, delimiter); - } - free(command_copy); - - RAY_CHECK(num_args == i); - return command_args; -} - -LocalSchedulerState *LocalSchedulerState_init( - const char *node_ip_address, - event_loop *loop, - const char *redis_primary_addr, - int redis_primary_port, - const char *local_scheduler_socket_name, - const char *plasma_store_socket_name, - const char *plasma_manager_socket_name, - const char *plasma_manager_address, - bool global_scheduler_exists, - const std::unordered_map &static_resource_conf, - const char *start_worker_command, - int num_workers) { - LocalSchedulerState *state = new LocalSchedulerState(); - /* Set the configuration struct for the local scheduler. */ - if (start_worker_command != NULL) { - state->config.start_worker_command = parse_command(start_worker_command); - } else { - state->config.start_worker_command = NULL; - } - if (start_worker_command == NULL) { - RAY_LOG(WARNING) << "No valid command to start a worker provided, local " - << "scheduler will not start any workers."; - } - state->config.global_scheduler_exists = global_scheduler_exists; - - state->loop = loop; - - /* Connect to Redis if a Redis address is provided. */ - if (redis_primary_addr != NULL) { - /* Construct db_connect_args */ - std::vector db_connect_args; - db_connect_args.push_back("local_scheduler_socket_name"); - db_connect_args.push_back(local_scheduler_socket_name); - for (auto const &resource_pair : static_resource_conf) { - // TODO(rkn): This could cause issues if a resource name collides with - // another field name "manager_address". - db_connect_args.push_back(resource_pair.first); - db_connect_args.push_back(std::to_string(resource_pair.second)); - } - - if (plasma_manager_address != NULL) { - db_connect_args.push_back("manager_address"); - db_connect_args.push_back(plasma_manager_address); - } - - state->db = db_connect(std::string(redis_primary_addr), redis_primary_port, - "local_scheduler", node_ip_address, db_connect_args); - db_attach(state->db, loop, false); - } else { - state->db = NULL; - } - /* Connect to Plasma. This method will retry if Plasma hasn't started yet. */ - state->plasma_conn = new plasma::PlasmaClient(); - if (plasma_manager_socket_name != NULL) { - ARROW_CHECK_OK(state->plasma_conn->Connect( - plasma_store_socket_name, plasma_manager_socket_name, - plasma::kPlasmaDefaultReleaseDelay)); - } else { - ARROW_CHECK_OK(state->plasma_conn->Connect( - plasma_store_socket_name, "", plasma::kPlasmaDefaultReleaseDelay)); - } - /* Subscribe to notifications about sealed objects. */ - int plasma_fd; - ARROW_CHECK_OK(state->plasma_conn->Subscribe(&plasma_fd)); - /* Add the callback that processes the notification to the event loop. 
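For reference, a sketch of how ``parse_command`` above is meant to be consumed, mirroring ``start_worker`` and the cleanup in ``LocalSchedulerState_free``: the returned argv is NULL-terminated, and every token was strdup'd, so the parent must free each entry and then the array itself:

.. code-block:: cpp

    #include <cstdlib>
    #include <unistd.h>

    // Declared earlier in this file.
    const char **parse_command(const char *command);

    // Tokenize a worker command, exec it in a child, and free the copy in
    // the parent.
    void spawn(const char *command) {
      const char **argv = parse_command(command);
      pid_t pid = fork();
      if (pid == 0) {
        execvp(argv[0], (char *const *) argv);
        _exit(1);  // only reached if exec fails
      }
      for (int i = 0; argv[i] != NULL; ++i) {
        free((void *) argv[i]);
      }
      free(argv);
    }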
*/ - event_loop_add_file(loop, plasma_fd, EVENT_LOOP_READ, - process_plasma_notification, state); - /* Add scheduler state. */ - state->algorithm_state = SchedulingAlgorithmState_init(); - - /* Initialize resource vectors. */ - state->static_resources = static_resource_conf; - state->dynamic_resources = static_resource_conf; - /* Initialize available GPUs. */ - if (state->static_resources.count("GPU") == 1) { - for (int i = 0; i < state->static_resources["GPU"]; ++i) { - state->available_gpus.push_back(i); - } - } - /* Print some debug information about resource configuration. */ - print_resource_info(state, NULL); - - /* Start the initial set of workers. */ - for (int i = 0; i < num_workers; ++i) { - start_worker(state); - } - - /* Initialize the time at which the previous heartbeat was sent. */ - state->previous_heartbeat_time = current_time_ms(); - - return state; -} - -/* TODO(atumanov): vectorize resource counts on input. */ -bool check_dynamic_resources( - LocalSchedulerState *state, - const std::unordered_map &resources) { - for (auto const &resource_pair : resources) { - std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - if (state->dynamic_resources[resource_name] < resource_quantity) { - return false; - } - } - return true; -} - -void resource_sanity_checks(LocalSchedulerState *state, - LocalSchedulerClient *worker) { - // Check the resources in use by the worker. - for (auto const &resource_pair : worker->resources_in_use) { - const std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - - RAY_CHECK(state->dynamic_resources[resource_name] <= - state->static_resources[resource_name]); - if (resource_name != std::string("CPU")) { - RAY_CHECK(state->dynamic_resources[resource_name] >= 0); - } - - RAY_CHECK(resource_quantity >= 0); - RAY_CHECK(resource_quantity <= state->static_resources[resource_name]); - } -} - -/* TODO(atumanov): just pass the required resource vector of doubles. */ -void acquire_resources( - LocalSchedulerState *state, - LocalSchedulerClient *worker, - const std::unordered_map &resources) { - // Loop over each required resource type and acquire the appropriate quantity. - for (auto const &resource_pair : resources) { - const std::string resource_name = resource_pair.first; - double resource_quantity = resource_pair.second; - - // Do some special handling for GPU resources. - if (resource_name == std::string("GPU")) { - if (resource_quantity != 0) { - // Make sure that the worker isn't using any GPUs already. - RAY_CHECK(worker->gpus_in_use.size() == 0); - RAY_CHECK(state->available_gpus.size() >= resource_quantity); - // Reserve GPUs for the worker. - for (int i = 0; i < resource_quantity; i++) { - worker->gpus_in_use.push_back(state->available_gpus.back()); - state->available_gpus.pop_back(); - } - } - } - - // Do bookkeeping for general resource types. - if (resource_name != std::string("CPU")) { - RAY_CHECK(state->dynamic_resources[resource_name] >= resource_quantity); - } - state->dynamic_resources[resource_name] -= resource_quantity; - worker->resources_in_use[resource_name] += resource_quantity; - } - - // Do some sanity checks. 
-  resource_sanity_checks(state, worker);
-}
-
-void release_resources(
-    LocalSchedulerState *state,
-    LocalSchedulerClient *worker,
-    const std::unordered_map<std::string, double> &resources) {
-  for (auto const &resource_pair : resources) {
-    const std::string resource_name = resource_pair.first;
-    double resource_quantity = resource_pair.second;
-
-    // Do some special handling for GPU resources.
-    if (resource_name == std::string("GPU")) {
-      if (resource_quantity != 0) {
-        RAY_CHECK(resource_quantity == worker->gpus_in_use.size());
-        // Move the GPU IDs the worker was using back to the local scheduler.
-        for (auto const &gpu_id : worker->gpus_in_use) {
-          state->available_gpus.push_back(gpu_id);
-        }
-        worker->gpus_in_use.clear();
-      }
-    }
-
-    // Do bookkeeping for general resource types.
-    state->dynamic_resources[resource_name] += resource_quantity;
-    worker->resources_in_use[resource_name] -= resource_quantity;
-  }
-
-  // Do some sanity checks.
-  resource_sanity_checks(state, worker);
-}
-
-bool is_driver_alive(LocalSchedulerState *state, WorkerID driver_id) {
-  return state->removed_drivers.count(driver_id) == 0;
-}
-
-void assign_task_to_worker(LocalSchedulerState *state,
-                           TaskExecutionSpec &execution_spec,
-                           LocalSchedulerClient *worker) {
-  int64_t task_spec_size = execution_spec.SpecSize();
-  TaskSpec *spec = execution_spec.Spec();
-  // Acquire the necessary resources for running this task.
-  const std::unordered_map<std::string, double> required_resources =
-      TaskSpec_get_required_resources(spec);
-  acquire_resources(state, worker, required_resources);
-  // Check that actor tasks don't have non-CPU requirements. Any necessary
-  // non-CPU resources (in particular, GPUs) should already have been acquired
-  // by the actor worker.
-  if (!worker->actor_id.is_nil()) {
-    RAY_CHECK(required_resources.size() == 1);
-    RAY_CHECK(required_resources.count("CPU") == 1);
-  }
-
-  RAY_CHECK(worker->actor_id == TaskSpec_actor_id(spec));
-  /* Make sure the driver for this task is still alive. */
-  WorkerID driver_id = TaskSpec_driver_id(spec);
-  RAY_CHECK(is_driver_alive(state, driver_id));
-
-  /* Construct a flatbuffer object to send to the worker. */
-  flatbuffers::FlatBufferBuilder fbb;
-  auto message = ray::local_scheduler::protocol::CreateGetTaskReply(
-      fbb, fbb.CreateString((char *) spec, task_spec_size),
-      fbb.CreateVector(worker->gpus_in_use));
-  fbb.Finish(message);
-
-  if (write_message(worker->sock,
-                    static_cast<int64_t>(MessageType::ExecuteTask),
-                    fbb.GetSize(), (uint8_t *) fbb.GetBufferPointer()) < 0) {
-    if (errno == EPIPE || errno == EBADF) {
-      /* Something went wrong, so kill the worker. */
-      kill_worker(state, worker, false, false);
-      RAY_LOG(WARNING) << "Failed to give task to worker on fd " << worker->sock
-                       << ". The client may have hung up.";
-    } else {
-      RAY_LOG(FATAL) << "Failed to give task to client on fd " << worker->sock;
-    }
-  }
-
-  Task *task =
-      Task_alloc(execution_spec, TaskStatus::RUNNING,
-                 state->db ? get_db_client_id(state->db) : DBClientID::nil());
-  /* Record which task this worker is executing. This will be freed in
-   * process_message when the worker sends a GetTask message to the local
-   * scheduler. */
-  worker->task_in_progress = Task_copy(task);
-  /* Update the global task table. */
-  if (state->db != NULL) {
-    task_table_update(state->db, task, NULL, NULL, NULL);
-  } else {
-    Task_free(task);
-  }
-}
-
-// This is used to allow task_table_update to fail.
-void allow_task_table_update_failure(UniqueID id, - void *user_context, - void *user_data) {} - -void finish_task(LocalSchedulerState *state, LocalSchedulerClient *worker) { - if (worker->task_in_progress != NULL) { - TaskSpec *spec = Task_task_execution_spec(worker->task_in_progress)->Spec(); - // Return dynamic resources back for the task in progress. - if (TaskSpec_is_actor_creation_task(spec)) { - // Resources required by the actor creation task are acquired for the - // actor's lifetime, so don't return anything here. TODO(rkn): Should the - // actor creation task require 1 CPU in addition to any resources acquired - // for the lifetime of the actor? If not, then the local scheduler may - // schedule an arbitrary number of actor creation tasks concurrently (if - // they don't acquire any resources for their entire lifetime). In - // practice this will usually be rate-limited by the rate at which we can - // create new workers. - - ActorID actor_creation_id = TaskSpec_actor_creation_id(spec); - WorkerID driver_id = TaskSpec_driver_id(spec); - - // The driver must be alive because if the driver had been removed, then - // this worker would have been killed (because it was executing a task for - // the driver). - RAY_CHECK(is_driver_alive(state, driver_id)); - - // Update the worker struct with this actor ID. - RAY_CHECK(worker->actor_id.is_nil()); - worker->actor_id = actor_creation_id; - // Extract the initial execution dependency from the actor creation task. - RAY_CHECK(TaskSpec_num_returns(spec) == 1); - ObjectID initial_execution_dependency = TaskSpec_return(spec, 0); - // Let the scheduling algorithm process the presence of this new worker. - handle_convert_worker_to_actor(state, state->algorithm_state, - actor_creation_id, - initial_execution_dependency, worker); - // Publish the actor creation notification. The corresponding callback - // handle_actor_creation_callback will update state->actor_mapping. - publish_actor_creation_notification( - state->db, actor_creation_id, driver_id, get_db_client_id(state->db)); - } else if (worker->actor_id.is_nil()) { - // Return dynamic resources back for the task in progress. - RAY_CHECK(worker->resources_in_use["CPU"] == - TaskSpec_get_required_resource(spec, "CPU")); - // Return GPU resources. - RAY_CHECK(worker->gpus_in_use.size() == - TaskSpec_get_required_resource(spec, "GPU")); - release_resources(state, worker, worker->resources_in_use); - } else { - // Actor tasks should only specify CPU requirements. - RAY_CHECK(0 == TaskSpec_get_required_resource(spec, "GPU")); - std::unordered_map cpu_resources; - cpu_resources["CPU"] = TaskSpec_get_required_resource(spec, "CPU"); - release_resources(state, worker, cpu_resources); - } - /* If we're connected to Redis, update tables. */ - if (state->db != NULL) { - /* Update control state tables. */ - TaskStatus task_state = TaskStatus::DONE; - Task_set_state(worker->task_in_progress, task_state); - auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = allow_task_table_update_failure, - }; - - // We allow this call to fail in case the driver has been removed and the - // task table entries have already been cleaned up by the monitor. - task_table_update(state->db, worker->task_in_progress, &retryInfo, NULL, - NULL); - } else { - Task_free(worker->task_in_progress); - } - /* The call to task_table_update takes ownership of the - * task_in_progress, so we set the pointer to NULL so it is not used. 
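A toy distillation of the bookkeeping contract kept by ``check_dynamic_resources``, ``acquire_resources``, and ``release_resources`` above. The class is hypothetical, but the invariant matches the code: every quantity moved out of the node-wide pool is recorded against a worker, so release is the exact inverse of acquire:

.. code-block:: cpp

    #include <cassert>
    #include <string>
    #include <unordered_map>

    // Hypothetical ledger; dynamic_resources and resources_in_use in the
    // code above play the roles of `pool` and `in_use`.
    struct ResourceLedger {
      std::unordered_map<std::string, double> pool;
      std::unordered_map<std::string, double> in_use;

      void acquire(const std::string &name, double qty) {
        pool[name] -= qty;
        in_use[name] += qty;
      }
      void release(const std::string &name, double qty) {
        assert(in_use[name] >= qty);  // cannot return more than was taken
        pool[name] += qty;
        in_use[name] -= qty;
      }
    };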
*/ - worker->task_in_progress = NULL; - } -} - -void process_plasma_notification(event_loop *loop, - int client_sock, - void *context, - int events) { - LocalSchedulerState *state = (LocalSchedulerState *) context; - /* Read the notification from Plasma. */ - uint8_t *notification = read_message_async(loop, client_sock); - if (!notification) { - /* The store has closed the socket. */ - LocalSchedulerState_free(state); - RAY_LOG(FATAL) << "Lost connection to the plasma store, local scheduler is " - << "exiting!"; - } - auto object_info = flatbuffers::GetRoot(notification); - ObjectID object_id = from_flatbuf(*object_info->object_id()); - if (object_info->is_deletion()) { - handle_object_removed(state, object_id); - } else { - handle_object_available(state, state->algorithm_state, object_id); - } - free(notification); -} - -void reconstruct_task_update_callback(Task *task, - void *user_context, - bool updated) { - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - if (!updated) { - /* The test-and-set failed. The task is either: (1) not finished yet, (2) - * lost, but not yet updated, or (3) already being reconstructed. */ - DBClientID current_local_scheduler_id = Task_local_scheduler(task); - if (!current_local_scheduler_id.is_nil()) { - DBClient current_local_scheduler = - db_client_table_cache_get(state->db, current_local_scheduler_id); - if (!current_local_scheduler.is_alive) { - /* (2) The current local scheduler for the task is dead. The task is - * lost, but the task table hasn't received the update yet. Retry the - * test-and-set. */ - task_table_test_and_update(state->db, Task_task_id(task), - current_local_scheduler_id, Task_state(task), - TaskStatus::RECONSTRUCTING, NULL, - reconstruct_task_update_callback, state); - } - } - /* The test-and-set failed, so it is not safe to resubmit the task for - * execution. Suppress the request. */ - return; - } - - /* Otherwise, the test-and-set succeeded, so resubmit the task for execution - * to ensure that reconstruction will happen. */ - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - if (TaskSpec_actor_id(spec).is_nil()) { - handle_task_submitted(state, state->algorithm_state, *execution_spec); - } else { - handle_actor_task_submitted(state, state->algorithm_state, *execution_spec); - } - - /* Recursively reconstruct the task's inputs, if necessary. */ - int64_t num_dependencies = execution_spec->NumDependencies(); - for (int64_t i = 0; i < num_dependencies; ++i) { - int count = execution_spec->DependencyIdCount(i); - for (int64_t j = 0; j < count; ++j) { - ObjectID dependency_id = execution_spec->DependencyId(i, j); - reconstruct_object(state, dependency_id); - } - } -} - -void reconstruct_put_task_update_callback(Task *task, - void *user_context, - bool updated) { - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - if (!updated) { - /* The test-and-set failed. The task is either: (1) not finished yet, (2) - * lost, but not yet updated, or (3) already being reconstructed. */ - DBClientID current_local_scheduler_id = Task_local_scheduler(task); - if (!current_local_scheduler_id.is_nil()) { - DBClient current_local_scheduler = - db_client_table_cache_get(state->db, current_local_scheduler_id); - if (!current_local_scheduler.is_alive) { - /* (2) The current local scheduler for the task is dead. The task is - * lost, but the task table hasn't received the update yet. Retry the - * test-and-set. 
*/ - task_table_test_and_update(state->db, Task_task_id(task), - current_local_scheduler_id, Task_state(task), - TaskStatus::RECONSTRUCTING, NULL, - reconstruct_put_task_update_callback, state); - } else if (Task_state(task) == TaskStatus::RUNNING) { - /* (1) The task is still executing on a live node. The object created - * by `ray.put` was not able to be reconstructed, and the workload will - * likely hang. Push an error to the appropriate driver. */ - TaskSpec *spec = Task_task_execution_spec(task)->Spec(); - - std::ostringstream error_message; - error_message << "The task with ID " << TaskSpec_task_id(spec) - << " is still executing and so the object created by " - << "ray.put could not be reconstructed."; - push_error(state->db, TaskSpec_driver_id(spec), - ErrorIndex::PUT_RECONSTRUCTION, error_message.str()); - } - } else { - /* (1) The task is still executing and it is the driver task. We cannot - * restart the driver task, so the workload will hang. Push an error to - * the appropriate driver. */ - TaskSpec *spec = Task_task_execution_spec(task)->Spec(); - - std::ostringstream error_message; - error_message << "The task with ID " << TaskSpec_task_id(spec) - << " is a driver task and so the object created by ray.put " - << "could not be reconstructed."; - push_error(state->db, TaskSpec_driver_id(spec), - ErrorIndex::PUT_RECONSTRUCTION, error_message.str()); - } - } else { - /* The update to TaskStatus::RECONSTRUCTING succeeded, so continue with - * reconstruction as usual. */ - reconstruct_task_update_callback(task, user_context, updated); - } -} - -void reconstruct_evicted_result_lookup_callback(ObjectID reconstruct_object_id, - TaskID task_id, - bool is_put, - void *user_context) { - RAY_CHECK(!task_id.is_nil()) - << "No task information found for object during reconstruction"; - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - - task_table_test_and_update_callback done_callback; - if (is_put) { - /* If the evicted object was created through ray.put and the originating - * task - * is still executing, it's very likely that the workload will hang and the - * worker needs to be restarted. Else, the reconstruction behavior is the - * same as for other evicted objects */ - done_callback = reconstruct_put_task_update_callback; - } else { - done_callback = reconstruct_task_update_callback; - } - /* If there are no other instances of the task running, it's safe for us to - * claim responsibility for reconstruction. */ - task_table_test_and_update(state->db, task_id, DBClientID::nil(), - (TaskStatus::DONE | TaskStatus::LOST), - TaskStatus::RECONSTRUCTING, NULL, done_callback, - state); -} - -void reconstruct_failed_result_lookup_callback(ObjectID reconstruct_object_id, - TaskID task_id, - bool is_put, - void *user_context) { - if (task_id.is_nil()) { - /* NOTE(swang): For some reason, the result table update sometimes happens - * after this lookup returns, possibly due to concurrent clients. In most - * cases, this is okay because the initial execution is probably still - * pending, so for now, we log a warning and suppress reconstruction. */ - RAY_LOG(WARNING) << "No task information found for object during " - << "reconstruction (no object entry yet)"; - return; - } - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - /* If the task failed to finish, it's safe for us to claim responsibility for - * reconstruction. 
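The ``task_table_test_and_update`` call above is effectively a compare-and-swap over task state: reconstruction is claimed only if the task is currently in one of the tested states, so at most one scheduler takes responsibility. A local model of those semantics (the real call is an asynchronous Redis operation, and these enum values are illustrative only):

.. code-block:: cpp

    #include <cstdint>

    // Illustrative states; the real TaskStatus lives elsewhere in this tree.
    enum class TaskStatus : uint32_t {
      RUNNING = 1, DONE = 2, LOST = 4, RECONSTRUCTING = 8,
    };

    // Returns true iff the transition was claimed. A caller mirrors the code
    // above with test_mask = DONE | LOST and target = RECONSTRUCTING.
    inline bool test_and_update(TaskStatus &state, uint32_t test_mask,
                                TaskStatus target) {
      if ((static_cast<uint32_t>(state) & test_mask) == 0) {
        return false;  // test failed: another scheduler owns the task
      }
      state = target;
      return true;
    }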
*/ - task_table_test_and_update(state->db, task_id, DBClientID::nil(), - TaskStatus::LOST, TaskStatus::RECONSTRUCTING, NULL, - reconstruct_task_update_callback, state); -} - -void reconstruct_object_lookup_callback( - ObjectID reconstruct_object_id, - bool never_created, - const std::vector &manager_ids, - void *user_context) { - RAY_LOG(DEBUG) << "Manager count was " << manager_ids.size(); - /* Only continue reconstruction if we find that the object doesn't exist on - * any nodes. NOTE: This codepath is not responsible for checking if the - * object table entry is up-to-date. */ - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - /* Look up the task that created the object in the result table. */ - if (never_created) { - /* If the object has not been created yet, we reconstruct the object if and - * only if the task that created the object failed to complete. */ - result_table_lookup(state->db, reconstruct_object_id, NULL, - reconstruct_failed_result_lookup_callback, - (void *) state); - } else { - /* If the object has been created, filter out the dead plasma managers that - * have it. */ - size_t num_live_managers = 0; - for (auto manager_id : manager_ids) { - DBClient manager = db_client_table_cache_get(state->db, manager_id); - if (manager.is_alive) { - num_live_managers++; - } - } - /* If the object was created, but all plasma managers that had the object - * either evicted it or failed, we reconstruct the object if and only if - * there are no other instances of the task running. */ - if (num_live_managers == 0) { - result_table_lookup(state->db, reconstruct_object_id, NULL, - reconstruct_evicted_result_lookup_callback, - (void *) state); - } - } -} - -void reconstruct_object(LocalSchedulerState *state, - ObjectID reconstruct_object_id) { - RAY_LOG(DEBUG) << "Starting reconstruction"; - /* If the object is locally available, no need to reconstruct. */ - if (object_locally_available(state->algorithm_state, reconstruct_object_id)) { - return; - } - /* Determine if reconstruction is necessary by checking if the object exists - * on a node. */ - RAY_CHECK(state->db != NULL); - object_table_lookup(state->db, reconstruct_object_id, NULL, - reconstruct_object_lookup_callback, (void *) state); -} - -void handle_client_register( - LocalSchedulerState *state, - LocalSchedulerClient *worker, - const ray::local_scheduler::protocol::RegisterClientRequest *message) { - /* Make sure this worker hasn't already registered. */ - RAY_CHECK(!worker->registered); - worker->registered = true; - worker->is_worker = message->is_worker(); - RAY_CHECK(worker->client_id.is_nil()); - worker->client_id = from_flatbuf(*message->client_id()); - - /* Register the worker or driver. */ - if (worker->is_worker) { - /* Update the actor mapping with the actor ID of the worker (if an actor is - * running on the worker). */ - worker->pid = message->worker_pid(); - /* Register worker process id with the scheduler. */ - /* Determine if this worker is one of our child processes. */ - RAY_LOG(DEBUG) << "PID is " << worker->pid; - auto it = std::find(state->child_pids.begin(), state->child_pids.end(), - worker->pid); - if (it != state->child_pids.end()) { - /* If this worker is one of our child processes, mark it as a child so - * that we know that we can wait for the process to exit during - * cleanup. */ - worker->is_child = true; - state->child_pids.erase(it); - RAY_LOG(DEBUG) << "Found matching child pid " << worker->pid; - } - } else { - /* Register the driver. Currently we don't do anything here. 
*/
-  }
-}
-
-void handle_driver_removed_callback(WorkerID driver_id, void *user_context) {
-  LocalSchedulerState *state = (LocalSchedulerState *) user_context;
-
-  /* Kill any actors that were created by the removed driver, and kill any
-   * workers that are currently running tasks from the dead driver. */
-  auto it = state->workers.begin();
-  while (it != state->workers.end()) {
-    /* Increment the iterator by one before calling kill_worker, because
-     * kill_worker will invalidate the iterator. Note that this requires
-     * knowledge of the particular container that we are iterating over (in this
-     * case it is a list). */
-    auto next_it = it;
-    next_it++;
-
-    ActorID actor_id = (*it)->actor_id;
-    Task *task = (*it)->task_in_progress;
-
-    if (!actor_id.is_nil()) {
-      /* This is an actor. */
-      RAY_CHECK(state->actor_mapping.count(actor_id) == 1);
-      if (state->actor_mapping[actor_id].driver_id == driver_id) {
-        /* This actor was created by the removed driver, so kill the actor. */
-        RAY_LOG(DEBUG) << "Killing an actor for a removed driver.";
-        kill_worker(state, *it, false, true);
-      }
-    } else if (task != NULL) {
-      TaskSpec *spec = Task_task_execution_spec(task)->Spec();
-      if (TaskSpec_driver_id(spec) == driver_id) {
-        RAY_LOG(DEBUG) << "Killing a worker executing a task for a removed "
-                       << "driver.";
-        kill_worker(state, *it, false, true);
-      }
-    }
-
-    it = next_it;
-  }
-
-  /* Add the driver to a list of dead drivers. */
-  state->removed_drivers.insert(driver_id);
-
-  /* Notify the scheduling algorithm that the driver has been removed. It should
-   * remove tasks for that driver from its data structures. */
-  handle_driver_removed(state, state->algorithm_state, driver_id);
-}
-
-void handle_client_disconnect(LocalSchedulerState *state,
-                              LocalSchedulerClient *worker) {
-  if (!worker->registered || worker->is_worker) {
-  } else {
-    /* In this case, a driver is disconnecting. */
-    driver_table_send_driver_death(state->db, worker->client_id, NULL);
-  }
-  /* Suppress the warning message if the worker already disconnected. */
-  kill_worker(state, worker, false, worker->disconnected);
-}
-
-void handle_get_actor_frontier(LocalSchedulerState *state,
-                               LocalSchedulerClient *worker,
-                               ActorID actor_id) {
-  auto task_counters =
-      get_actor_task_counters(state->algorithm_state, actor_id);
-  auto frontier = get_actor_frontier(state->algorithm_state, actor_id);
-
-  /* Build the ActorFrontier flatbuffer. */
-  std::vector handle_vector;
-  std::vector task_counter_vector;
-  std::vector frontier_vector;
-  for (auto handle : task_counters) {
-    handle_vector.push_back(handle.first);
-    task_counter_vector.push_back(handle.second);
-    frontier_vector.push_back(frontier[handle.first]);
-  }
-  flatbuffers::FlatBufferBuilder fbb;
-  auto reply = ray::local_scheduler::protocol::CreateActorFrontier(
-      fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, handle_vector),
-      fbb.CreateVector(task_counter_vector), to_flatbuf(fbb, frontier_vector));
-  fbb.Finish(reply);
-  /* Respond with the built ActorFrontier. */
-  if (write_message(worker->sock,
-                    static_cast<int64_t>(MessageType::GetActorFrontierReply),
-                    fbb.GetSize(), (uint8_t *) fbb.GetBufferPointer()) < 0) {
-    if (errno == EPIPE || errno == EBADF) {
-      /* Something went wrong, so kill the worker. */
-      kill_worker(state, worker, false, false);
-      RAY_LOG(WARNING) << "Failed to return actor frontier to worker on fd "
-                       << worker->sock << ".
The client may have hung up."; - } else { - RAY_LOG(FATAL) << "Failed to give task to client on fd " << worker->sock; - } - } -} - -void handle_set_actor_frontier( - LocalSchedulerState *state, - LocalSchedulerClient *worker, - ray::local_scheduler::protocol::ActorFrontier const &frontier) { - /* Parse the ActorFrontier flatbuffer. */ - ActorID actor_id = from_flatbuf(*frontier.actor_id()); - std::unordered_map task_counters; - std::unordered_map frontier_dependencies; - for (size_t i = 0; i < frontier.handle_ids()->size(); ++i) { - ActorID handle_id = from_flatbuf(*frontier.handle_ids()->Get(i)); - task_counters[handle_id] = frontier.task_counters()->Get(i); - frontier_dependencies[handle_id] = - from_flatbuf(*frontier.frontier_dependencies()->Get(i)); - } - /* Set the actor's frontier. */ - set_actor_task_counters(state->algorithm_state, actor_id, task_counters); - set_actor_frontier(state, state->algorithm_state, actor_id, - frontier_dependencies); -} - -void process_message(event_loop *loop, - int client_sock, - void *context, - int events) { - int64_t start_time = current_time_ms(); - - LocalSchedulerClient *worker = (LocalSchedulerClient *) context; - LocalSchedulerState *state = worker->local_scheduler_state; - - int64_t type; - read_vector(client_sock, &type, state->input_buffer); - uint8_t *input = state->input_buffer.data(); - - RAY_LOG(DEBUG) << "New event of type " << type; - - switch (type) { - case static_cast(MessageType::SubmitTask): { - auto message = - flatbuffers::GetRoot( - input); - TaskExecutionSpec execution_spec = - TaskExecutionSpec(from_flatbuf(*message->execution_dependencies()), - (TaskSpec *) message->task_spec()->data(), - message->task_spec()->size()); - /* Set the tasks's local scheduler entrypoint time. */ - execution_spec.SetLastTimeStamp(current_time_ms()); - TaskSpec *spec = execution_spec.Spec(); - /* Update the result table, which holds mappings of object ID -> ID of the - * task that created it. */ - if (state->db != NULL) { - TaskID task_id = TaskSpec_task_id(spec); - for (int64_t i = 0; i < TaskSpec_num_returns(spec); ++i) { - ObjectID return_id = TaskSpec_return(spec, i); - result_table_add(state->db, return_id, task_id, false, NULL, NULL, - NULL); - } - } - - /* Handle the task submission. */ - if (TaskSpec_actor_id(spec).is_nil()) { - handle_task_submitted(state, state->algorithm_state, execution_spec); - } else { - handle_actor_task_submitted(state, state->algorithm_state, - execution_spec); - } - } break; - case static_cast(MessageType::TaskDone): { - } break; - case static_cast(MessageType::DisconnectClient): { - finish_task(state, worker); - RAY_CHECK(!worker->disconnected); - worker->disconnected = true; - /* If the disconnected worker was not an actor, start a new worker to make - * sure there are enough workers in the pool. */ - if (worker->actor_id.is_nil()) { - start_worker(state); - } - } break; - case static_cast(MessageType::EventLogMessage): { - /* Parse the message. 
*/ - auto message = - flatbuffers::GetRoot( - input); - if (state->db != NULL) { - RayLogger_log_event(state->db, (uint8_t *) message->key()->data(), - message->key()->size(), - (uint8_t *) message->value()->data(), - message->value()->size(), message->timestamp()); - } - } break; - case static_cast(MessageType::RegisterClientRequest): { - auto message = flatbuffers::GetRoot< - ray::local_scheduler::protocol::RegisterClientRequest>(input); - handle_client_register(state, worker, message); - } break; - case static_cast(MessageType::GetTask): { - /* If this worker reports a completed task, account for resources. */ - finish_task(state, worker); - /* Let the scheduling algorithm process the fact that there is an available - * worker. */ - if (worker->actor_id.is_nil()) { - handle_worker_available(state, state->algorithm_state, worker); - } else { - handle_actor_worker_available(state, state->algorithm_state, worker); - } - } break; - case static_cast(MessageType::ReconstructObjects): { - auto message = flatbuffers::GetRoot< - ray::local_scheduler::protocol::ReconstructObjects>(input); - RAY_CHECK(!message->fetch_only()); - if (worker->task_in_progress != NULL && !worker->is_blocked) { - /* If the worker was executing a task (i.e. non-driver) and it wasn't - * already blocked on an object that's not locally available, update its - * state to blocked. */ - worker->is_blocked = true; - // Return the CPU resources that the blocked worker was using, but not - // other resources. If the worker is an actor, this will not return the - // CPU resources that the worker has acquired for its lifetime. It will - // only return the ones associated with the current method. - TaskSpec *spec = - Task_task_execution_spec(worker->task_in_progress)->Spec(); - std::unordered_map cpu_resources; - cpu_resources["CPU"] = TaskSpec_get_required_resource(spec, "CPU"); - release_resources(state, worker, cpu_resources); - /* Let the scheduling algorithm process the fact that the worker is - * blocked. */ - if (worker->actor_id.is_nil()) { - handle_worker_blocked(state, state->algorithm_state, worker); - } else { - handle_actor_worker_blocked(state, state->algorithm_state, worker); - } - print_worker_info("Reconstructing", state->algorithm_state); - } - RAY_CHECK(message->object_ids()->size() == 1); - ObjectID object_id = from_flatbuf(*message->object_ids()->Get(0)); - reconstruct_object(state, object_id); - } break; - case static_cast(CommonMessageType::DISCONNECT_CLIENT): { - RAY_LOG(DEBUG) << "Disconnecting client on fd " << client_sock; - handle_client_disconnect(state, worker); - } break; - case static_cast(MessageType::NotifyUnblocked): { - /* TODO(rkn): A driver may call this as well, right? */ - if (worker->task_in_progress != NULL) { - /* If the worker was executing a task (i.e. non-driver), update its - * state to not blocked. */ - RAY_CHECK(worker->is_blocked); - worker->is_blocked = false; - /* Lease back the CPU resources that the blocked worker needs (note that - * it never released its GPU resources). TODO(swang): Leasing back the - * resources to blocked workers can cause us to transiently exceed the - * maximum number of resources. This could be fixed by having blocked - * workers explicitly yield and wait to be given back resources before - * continuing execution. 
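A distilled sketch of the hand-off implemented by the ``ReconstructObjects`` and ``NotifyUnblocked`` branches above; the types here are hypothetical, and the over-subscription comment restates the TODO(swang) note rather than adding new behavior:

.. code-block:: cpp

    #include <string>
    #include <unordered_map>

    // When a worker blocks on a missing object, only the CPU share of its
    // current task returns to the pool; GPUs stay pinned for the task's (or
    // actor's) lifetime. NotifyUnblocked performs the exact inverse.
    struct Pool { std::unordered_map<std::string, double> free; };

    void on_blocked(Pool &pool, double task_cpus, bool &is_blocked) {
      pool.free["CPU"] += task_cpus;  // lease CPUs back to the scheduler
      is_blocked = true;
    }

    void on_unblocked(Pool &pool, double task_cpus, bool &is_blocked) {
      pool.free["CPU"] -= task_cpus;  // may transiently over-subscribe
      is_blocked = false;
    }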
*/ - TaskSpec *spec = - Task_task_execution_spec(worker->task_in_progress)->Spec(); - std::unordered_map cpu_resources; - cpu_resources["CPU"] = TaskSpec_get_required_resource(spec, "CPU"); - acquire_resources(state, worker, cpu_resources); - /* Let the scheduling algorithm process the fact that the worker is - * unblocked. */ - if (worker->actor_id.is_nil()) { - handle_worker_unblocked(state, state->algorithm_state, worker); - } else { - handle_actor_worker_unblocked(state, state->algorithm_state, worker); - } - } - print_worker_info("Worker unblocked", state->algorithm_state); - } break; - case static_cast(MessageType::PutObject): { - auto message = - flatbuffers::GetRoot(input); - result_table_add(state->db, from_flatbuf(*message->object_id()), - from_flatbuf(*message->task_id()), true, NULL, NULL, NULL); - } break; - case static_cast(MessageType::GetActorFrontierRequest): { - auto message = flatbuffers::GetRoot< - ray::local_scheduler::protocol::GetActorFrontierRequest>(input); - ActorID actor_id = from_flatbuf(*message->actor_id()); - handle_get_actor_frontier(state, worker, actor_id); - } break; - case static_cast(MessageType::SetActorFrontier): { - auto message = - flatbuffers::GetRoot( - input); - handle_set_actor_frontier(state, worker, *message); - } break; - default: - /* This code should be unreachable. */ - RAY_CHECK(0); - } - - /* Print a warning if this method took too long. */ - int64_t end_time = current_time_ms(); - if (end_time - start_time > - RayConfig::instance().max_time_for_handler_milliseconds()) { - RAY_LOG(WARNING) << "process_message of type " << type << " took " - << end_time - start_time << " milliseconds."; - } -} - -void new_client_connection(event_loop *loop, - int listener_sock, - void *context, - int events) { - LocalSchedulerState *state = (LocalSchedulerState *) context; - int new_socket = accept_client(listener_sock); - /* Create a struct for this worker. This will be freed when we free the local - * scheduler state. */ - LocalSchedulerClient *worker = new LocalSchedulerClient(); - worker->sock = new_socket; - worker->registered = false; - worker->disconnected = false; - /* We don't know whether this is a worker or not, so just initialize is_worker - * to false. */ - worker->is_worker = true; - worker->client_id = WorkerID::nil(); - worker->task_in_progress = NULL; - worker->is_blocked = false; - worker->pid = 0; - worker->is_child = false; - worker->actor_id = ActorID::nil(); - worker->local_scheduler_state = state; - state->workers.push_back(worker); - event_loop_add_file(loop, new_socket, EVENT_LOOP_READ, process_message, - worker); - RAY_LOG(DEBUG) << "new connection with fd " << new_socket; -} - -/* We need this code so we can clean up when we get a SIGTERM signal. */ - -LocalSchedulerState *g_state = NULL; - -void signal_handler(int signal) { - RAY_LOG(DEBUG) << "Signal was " << signal; - if (signal == SIGTERM) { - /* NOTE(swang): This call removes the SIGTERM handler to ensure that we - * free the local scheduler state at most once. If another SIGTERM is - * caught during this call, there is the possibility of orphan worker - * processes. */ - if (g_state) { - LocalSchedulerState_free(g_state); - } - exit(0); - } -} - -/* End of the cleanup code. 
*/ - -void handle_task_scheduled_callback(Task *original_task, - void *subscribe_context) { - LocalSchedulerState *state = (LocalSchedulerState *) subscribe_context; - TaskExecutionSpec *execution_spec = Task_task_execution_spec(original_task); - TaskSpec *spec = execution_spec->Spec(); - - /* Set the tasks's local scheduler entrypoint time. */ - execution_spec->SetLastTimeStamp(current_time_ms()); - - /* If the driver for this task has been removed, then don't bother telling the - * scheduling algorithm. */ - WorkerID driver_id = TaskSpec_driver_id(spec); - if (!is_driver_alive(state, driver_id)) { - RAY_LOG(DEBUG) << "Ignoring scheduled task for removed driver."; - return; - } - - if (TaskSpec_actor_id(spec).is_nil()) { - /* This task does not involve an actor. Handle it normally. */ - handle_task_scheduled(state, state->algorithm_state, *execution_spec); - } else { - /* This task involves an actor. Call the scheduling algorithm's actor - * handler. */ - handle_actor_task_scheduled(state, state->algorithm_state, *execution_spec); - } -} - -/** - * Process a notification about the creation of a new actor. Use this to update - * the mapping from actor ID to the local scheduler ID of the local scheduler - * that is responsible for the actor. If this local scheduler is responsible for - * the actor, then launch a new worker process to create that actor. - * - * @param actor_id The ID of the actor being created. - * @param local_scheduler_id The ID of the local scheduler that is responsible - * for creating the actor. - * @param context The context for this callback. - * @return Void. - */ -void handle_actor_creation_callback(const ActorID &actor_id, - const WorkerID &driver_id, - const DBClientID &local_scheduler_id, - void *context) { - LocalSchedulerState *state = (LocalSchedulerState *) context; - - /* If the driver has been removed, don't bother doing anything. */ - if (state->removed_drivers.count(driver_id) == 1) { - return; - } - - // TODO(rkn): If we do not have perfect task suppression and it is possible - // for a task to be executed simultaneously on two nodes, then we will need to - // detect and handle that case. - - if (state->actor_mapping.count(actor_id) != 0) { - // This actor already exists. - auto it = state->actor_mapping.find(actor_id); - if (it->second.local_scheduler_id == get_db_client_id(state->db)) { - // TODO(rkn): The actor was previously assigned to this local scheduler. - // We should kill the actor here if it is still around. Also, if it hasn't - // registered yet, we should keep track of its PID so we can kill it - // anyway. - // TODO(swang): Evict actor dummy objects as part of actor cleanup. - } - } - - /* Create a new entry and add it to the actor mapping table. TODO(rkn): - * Currently this is never removed (except when the local scheduler state is - * deleted). */ - ActorMapEntry entry; - entry.local_scheduler_id = local_scheduler_id; - entry.driver_id = driver_id; - state->actor_mapping[actor_id] = entry; - - /* Let the scheduling algorithm process the fact that a new actor has been - * created. */ - handle_actor_creation_notification(state, state->algorithm_state, actor_id); -} - -int heartbeat_handler(event_loop *loop, timer_id id, void *context) { - LocalSchedulerState *state = (LocalSchedulerState *) context; - SchedulingAlgorithmState *algorithm_state = state->algorithm_state; - - // Spillback policy invocation is synchronized with the heartbeats. - spillback_tasks_handler(state); - - /* Check that the last heartbeat was not sent too long ago. 
*/ - int64_t current_time = current_time_ms(); - RAY_CHECK(current_time >= state->previous_heartbeat_time); - if (current_time - state->previous_heartbeat_time > - RayConfig::instance().num_heartbeats_timeout() * - RayConfig::instance().heartbeat_timeout_milliseconds()) { - RAY_LOG(FATAL) << "The last heartbeat was sent " - << current_time - state->previous_heartbeat_time - << " milliseconds ago."; - } - state->previous_heartbeat_time = current_time; - - LocalSchedulerInfo info; - /* Ask the scheduling algorithm to fill out the scheduler info struct. */ - provide_scheduler_info(state, algorithm_state, &info); - /* Publish the heartbeat to all subscribers of the local scheduler table. */ - local_scheduler_table_send_info(state->db, &info, NULL); - /* Reset the timer. */ - return RayConfig::instance().heartbeat_timeout_milliseconds(); -} - -void start_server( - const char *node_ip_address, - const char *socket_name, - const char *redis_primary_addr, - int redis_primary_port, - const char *plasma_store_socket_name, - const char *plasma_manager_socket_name, - const char *plasma_manager_address, - bool global_scheduler_exists, - const std::unordered_map &static_resource_conf, - const char *start_worker_command, - int num_workers) { - /* Ignore SIGPIPE signals. If we don't do this, then when we attempt to write - * to a client that has already died, the local scheduler could die. */ - signal(SIGPIPE, SIG_IGN); - /* Ignore SIGCHLD signals. If we don't do this, then worker processes will - * become zombies instead of dying gracefully. */ - signal(SIGCHLD, SIG_IGN); - int fd = bind_ipc_sock(socket_name, true); - event_loop *loop = event_loop_create(); - g_state = LocalSchedulerState_init( - node_ip_address, loop, redis_primary_addr, redis_primary_port, - socket_name, plasma_store_socket_name, plasma_manager_socket_name, - plasma_manager_address, global_scheduler_exists, static_resource_conf, - start_worker_command, num_workers); - /* Register a callback for registering new clients. */ - event_loop_add_file(loop, fd, EVENT_LOOP_READ, new_client_connection, - g_state); - /* Subscribe to receive notifications about tasks that are assigned to this - * local scheduler by the global scheduler or by other local schedulers. - * TODO(rkn): we also need to get any tasks that were assigned to this local - * scheduler before the call to subscribe. */ - if (g_state->db != NULL) { - task_table_subscribe(g_state->db, get_db_client_id(g_state->db), - TaskStatus::SCHEDULED, handle_task_scheduled_callback, - g_state, NULL, NULL, NULL); - } - /* Subscribe to notifications about newly created actors. */ - if (g_state->db != NULL) { - actor_notification_table_subscribe( - g_state->db, handle_actor_creation_callback, g_state, NULL); - } - /* Subscribe to notifications about removed drivers. */ - if (g_state->db != NULL) { - driver_table_subscribe(g_state->db, handle_driver_removed_callback, g_state, - NULL); - } - /* Create a timer for publishing information about the load on the local - * scheduler to the local scheduler table. This message also serves as a - * heartbeat. */ - if (g_state->db != NULL) { - event_loop_add_timer(loop, - RayConfig::instance().heartbeat_timeout_milliseconds(), - heartbeat_handler, g_state); - } - /* Listen for new and deleted db clients. */ - if (g_state->db != NULL) { - db_client_table_cache_init(g_state->db); - } - /* Create a timer for fetching queued tasks' missing object dependencies. 
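The liveness check in ``heartbeat_handler`` above boils down to one comparison; a sketch with the two ``RayConfig`` values passed in as parameters:

.. code-block:: cpp

    #include <cstdint>

    // True when more than num_heartbeats_timeout whole heartbeat periods
    // elapsed since the previous beat, i.e. the scheduler looks wedged.
    bool heartbeat_overdue(int64_t now_ms, int64_t previous_ms,
                           int64_t num_heartbeats_timeout,
                           int64_t heartbeat_timeout_ms) {
      return (now_ms - previous_ms) >
             num_heartbeats_timeout * heartbeat_timeout_ms;
    }

Returning ``heartbeat_timeout_milliseconds()`` from the handler is what re-arms the event-loop timer for the next beat.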
-/* Only declare the main function if we are not in testing mode, since the
- * test suite has its own declaration of main. */
-#ifndef LOCAL_SCHEDULER_TEST
-int main(int argc, char *argv[]) {
-  InitShutdownRAII ray_log_shutdown_raii(
-      ray::RayLog::StartRayLog, ray::RayLog::ShutDownRayLog, argv[0],
-      ray::RayLogLevel::INFO, /*log_dir=*/"");
-  ray::RayLog::InstallFailureSignalHandler();
-  signal(SIGTERM, signal_handler);
-  /* Path of the listening socket of the local scheduler. */
-  char *scheduler_socket_name = NULL;
-  /* IP address and port of the primary redis instance. */
-  char *redis_primary_addr_port = NULL;
-  /* Socket name for the local Plasma store. */
-  char *plasma_store_socket_name = NULL;
-  /* Socket name for the local Plasma manager. */
-  char *plasma_manager_socket_name = NULL;
-  /* Address for the plasma manager associated with this local scheduler
-   * instance. */
-  char *plasma_manager_address = NULL;
-  /* The IP address of the node that this local scheduler is running on. */
-  char *node_ip_address = NULL;
-  /* Comma-separated list of configured resource capabilities for this node. */
-  char *static_resource_list = NULL;
-  std::unordered_map<std::string, double> static_resource_conf;
-  /* The command to run when starting new workers. */
-  char *start_worker_command = NULL;
-  /* The number of workers to start. */
-  char *num_workers_str = NULL;
-  int c;
-  bool global_scheduler_exists = true;
-  while ((c = getopt(argc, argv, "s:r:p:m:ga:h:c:w:n:")) != -1) {
-    switch (c) {
-    case 's':
-      scheduler_socket_name = optarg;
-      break;
-    case 'r':
-      redis_primary_addr_port = optarg;
-      break;
-    case 'p':
-      plasma_store_socket_name = optarg;
-      break;
-    case 'm':
-      plasma_manager_socket_name = optarg;
-      break;
-    case 'g':
-      global_scheduler_exists = false;
-      break;
-    case 'a':
-      plasma_manager_address = optarg;
-      break;
-    case 'h':
-      node_ip_address = optarg;
-      break;
-    case 'c':
-      static_resource_list = optarg;
-      break;
-    case 'w':
-      start_worker_command = optarg;
-      break;
-    case 'n':
-      num_workers_str = optarg;
-      break;
-    default:
-      RAY_LOG(FATAL) << "unknown option " << c;
-    }
-  }
-  if (!static_resource_list) {
-    RAY_LOG(FATAL) << "please specify a static resource list with the -c "
-                   << "switch";
-  }
-  // Parse the resource list.
-  std::istringstream resource_string(static_resource_list);
-  std::string resource_name;
-  std::string resource_quantity;
-
-  while (std::getline(resource_string, resource_name, ',')) {
-    RAY_CHECK(std::getline(resource_string, resource_quantity, ','));
-    // TODO(rkn): The line below could throw an exception. What should we do
-    // about this?
- static_resource_conf[resource_name] = std::stod(resource_quantity); - } - - if (!scheduler_socket_name) { - RAY_LOG(FATAL) << "please specify socket for incoming connections with " - << "-s switch"; - } - if (!plasma_store_socket_name) { - RAY_LOG(FATAL) << "please specify socket for connecting to Plasma store " - << "with -p switch"; - } - if (!node_ip_address) { - RAY_LOG(FATAL) << "please specify the node IP address with -h switch"; - } - int num_workers = 0; - if (num_workers_str) { - num_workers = strtol(num_workers_str, NULL, 10); - if (num_workers < 0) { - RAY_LOG(FATAL) << "Number of workers must be nonnegative"; - } - } - - char redis_primary_addr[16]; - char *redis_addr = NULL; - int redis_port = -1; - if (!redis_primary_addr_port) { - /* Start the local scheduler without connecting to Redis. In this case, all - * submitted tasks will be queued and scheduled locally. */ - if (plasma_manager_socket_name) { - RAY_LOG(FATAL) << "if a plasma manager socket name is provided with the " - << "-m switch, then a redis address must be provided with " - << "the -r switch"; - } - } else { - int redis_primary_port; - /* Parse the primary Redis address into an IP address and a port. */ - if (parse_ip_addr_port(redis_primary_addr_port, redis_primary_addr, - &redis_primary_port) == -1) { - RAY_LOG(FATAL) << "if a redis address is provided with the -r switch, it " - << "should be formatted like 127.0.0.1:6379"; - } - if (!plasma_manager_socket_name) { - RAY_LOG(FATAL) << "please specify socket for connecting to Plasma " - << "manager with -m switch"; - } - redis_addr = redis_primary_addr; - redis_port = redis_primary_port; - } - - start_server(node_ip_address, scheduler_socket_name, redis_addr, redis_port, - plasma_store_socket_name, plasma_manager_socket_name, - plasma_manager_address, global_scheduler_exists, - static_resource_conf, start_worker_command, num_workers); -} -#endif diff --git a/src/local_scheduler/local_scheduler.h b/src/local_scheduler/local_scheduler.h deleted file mode 100644 index 39c7523fe7ed..000000000000 --- a/src/local_scheduler/local_scheduler.h +++ /dev/null @@ -1,176 +0,0 @@ -#ifndef LOCAL_SCHEDULER_H -#define LOCAL_SCHEDULER_H - -#include "event_loop.h" -#include "local_scheduler_shared.h" -#include "task.h" - -/** - * Establish a connection to a new client. - * - * @param loop Event loop of the local scheduler. - * @param listener_socket Socket the local scheduler is listening on for new - * client requests. - * @param context State of the local scheduler. - * @param events Flag for events that are available on the listener socket. - * @return Void. - */ -void new_client_connection(event_loop *loop, - int listener_sock, - void *context, - int events); - -/** - * Check if a driver is still alive. - * - * @param driver_id The ID of the driver. - * @return True if the driver is still alive and false otherwise. - */ -bool is_driver_alive(WorkerID driver_id); - -/** - * This function can be called by the scheduling algorithm to assign a task - * to a worker. - * - * @param info - * @param task The task that is submitted to the worker. - * @param worker The worker to assign the task to. - * @return Void. - */ -void assign_task_to_worker(LocalSchedulerState *state, - TaskExecutionSpec &task, - LocalSchedulerClient *worker); - -/* - * This function is called whenever a task has finished on one of the workers. - * It updates the resource accounting and the global state store. - * - * @param state The local scheduler state. - * @param worker The worker that finished the task. 
- * @return Void.
- */
-void finish_task(LocalSchedulerState *state, LocalSchedulerClient *worker);
-
-/**
- * This is the callback that is used to process a notification from the Plasma
- * store that an object has been sealed.
- *
- * @param loop The local scheduler's event loop.
- * @param client_sock The file descriptor to read the notification from.
- * @param context The local scheduler state.
- * @param events Flag for events that are available on the socket.
- * @return Void.
- */
-void process_plasma_notification(event_loop *loop,
-                                 int client_sock,
-                                 void *context,
-                                 int events);
-
-/**
- * Reconstruct an object. If the object does not exist on any nodes, according
- * to the state tables, and if the object is not already being reconstructed,
- * this triggers a single reexecution of the task that originally created the
- * object.
- *
- * @param state The local scheduler state.
- * @param object_id The ID of the object to reconstruct.
- * @return Void.
- */
-void reconstruct_object(LocalSchedulerState *state, ObjectID object_id);
-
-void print_resource_info(const LocalSchedulerState *s, const TaskSpec *spec);
-
-/**
- * Kill a worker, if it is a child process, and clean up all of its associated
- * state. Note that this function is also called on drivers, but it should not
- * actually send a kill signal to drivers.
- *
- * @param state The local scheduler state.
- * @param worker The local scheduler client to kill.
- * @param wait A boolean representing whether to wait for the killed worker to
- *        exit.
- * @param suppress_warning A bool that is true if we should not warn the
- *        driver, and false otherwise. This should only be true when a driver
- *        is removed.
- * @return Void.
- */
-void kill_worker(LocalSchedulerState *state,
-                 LocalSchedulerClient *worker,
-                 bool wait,
-                 bool suppress_warning);
-
-/**
- * Start a worker. This forks a new worker process that can be added to the
- * pool of available workers, pending registration of its PID with the local
- * scheduler.
- *
- * @param state The local scheduler state.
- * @return Void.
- */
-void start_worker(LocalSchedulerState *state);
-
-/**
- * Check if a certain quantity of dynamic resources is available. If zero CPUs
- * are requested, the dynamic number of available CPUs (which may be negative)
- * is ignored.
- *
- * @param state The state of the local scheduler.
- * @param resources The resources to check.
- * @return True if there are enough CPUs and GPUs and false otherwise.
- */
-bool check_dynamic_resources(
-    LocalSchedulerState *state,
-    const std::unordered_map<std::string, double> &resources);
-
-/**
- * Acquire additional resources (CPUs and GPUs) for a worker.
- *
- * @param state The local scheduler state.
- * @param worker The worker who is acquiring resources.
- * @param resources The resources to acquire.
- * @return Void.
- */
-void acquire_resources(
-    LocalSchedulerState *state,
-    LocalSchedulerClient *worker,
-    const std::unordered_map<std::string, double> &resources);
-
-/**
- * Return resources (CPUs and GPUs) being used by a worker to the local
- * scheduler.
- *
- * @param state The local scheduler state.
- * @param worker The worker who is returning resources.
- * @param resources The resources to release.
- * @return Void.
- */
-void release_resources(
-    LocalSchedulerState *state,
-    LocalSchedulerClient *worker,
-    const std::unordered_map<std::string, double> &resources);
-
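acquire_resources and release_resources are strict inverses over the scheduler's dynamic resource map: a worker that acquires a bundle must eventually return exactly that bundle. A reduced sketch of the bookkeeping, with the state collapsed to the single field involved (all names here are illustrative, not the real implementation):

    #include <string>
    #include <unordered_map>

    struct SketchState {
      std::unordered_map<std::string, double> dynamic_resources;
    };

    void acquire(SketchState &state,
                 const std::unordered_map<std::string, double> &resources) {
      for (const auto &pair : resources) {
        state.dynamic_resources[pair.first] -= pair.second;  // may go negative
      }
    }

    void release(SketchState &state,
                 const std::unordered_map<std::string, double> &resources) {
      for (const auto &pair : resources) {
        state.dynamic_resources[pair.first] += pair.second;
      }
    }
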
-/** The following methods are for testing purposes only. */
-#ifdef LOCAL_SCHEDULER_TEST
-LocalSchedulerState *LocalSchedulerState_init(
-    const char *node_ip_address,
-    event_loop *loop,
-    const char *redis_addr,
-    int redis_port,
-    const char *local_scheduler_socket_name,
-    const char *plasma_manager_socket_name,
-    const char *plasma_store_socket_name,
-    const char *plasma_manager_address,
-    bool global_scheduler_exists,
-    const std::unordered_map<std::string, double> &static_resource_vector,
-    const char *worker_path,
-    int num_workers);
-
-SchedulingAlgorithmState *get_algorithm_state(LocalSchedulerState *state);
-
-void process_message(event_loop *loop,
-                     int client_sock,
-                     void *context,
-                     int events);
-
-#endif
-
-#endif /* LOCAL_SCHEDULER_H */
diff --git a/src/local_scheduler/local_scheduler_algorithm.cc b/src/local_scheduler/local_scheduler_algorithm.cc
deleted file mode 100644
index 89d6c8d6df56..000000000000
--- a/src/local_scheduler/local_scheduler_algorithm.cc
+++ /dev/null
@@ -1,1851 +0,0 @@
-#include "local_scheduler_algorithm.h"
-
-#include <algorithm>
-#include <list>
-#include <vector>
-
-#include "state/task_table.h"
-#include "state/actor_notification_table.h"
-#include "state/db_client_table.h"
-#include "state/error_table.h"
-#include "state/local_scheduler_table.h"
-#include "state/object_table.h"
-#include "local_scheduler_shared.h"
-#include "local_scheduler.h"
-#include "common/task.h"
-
-/* Declared for convenience. */
-void remove_actor(SchedulingAlgorithmState *algorithm_state, ActorID actor_id);
-
-void give_task_to_global_scheduler(LocalSchedulerState *state,
-                                   SchedulingAlgorithmState *algorithm_state,
-                                   TaskExecutionSpec &execution_spec);
-
-void give_task_to_local_scheduler(LocalSchedulerState *state,
-                                  SchedulingAlgorithmState *algorithm_state,
-                                  TaskExecutionSpec &execution_spec,
-                                  DBClientID local_scheduler_id);
-
-void clear_missing_dependencies(
-    SchedulingAlgorithmState *algorithm_state,
-    std::list<TaskExecutionSpec>::iterator it);
-
-/** A data structure used to track which objects are available locally and
- * which objects are being actively fetched. Objects of this type are used for
- * both the scheduling algorithm state's local_objects and remote_objects
- * tables. An ObjectEntry should be in at most one of the tables and not both
- * simultaneously. */
-struct ObjectEntry {
-  /** A vector of tasks dependent on this object. These tasks are a subset of
-   * the tasks in the waiting queue. Each element actually stores a reference
-   * to the corresponding task's queue entry in the waiting queue, for fast
-   * deletion when all of the task's dependencies become available. */
-  std::vector<std::list<TaskExecutionSpec>::iterator> dependent_tasks;
-  /** Whether or not to request a transfer of this object. This should be set
-   * to true for all objects except for actor dummy objects, where the object
-   * must be generated by executing the task locally. */
-  bool request_transfer;
-};
-
-/** This struct contains information about a specific actor. This struct will
- * be used inside of a hash table. */
-typedef struct {
-  /** The number of tasks that have been executed on this actor so far, per
-   * handle. This is used to guarantee execution of tasks on actors in the
-   * order that the tasks were submitted, per handle. Tasks from different
-   * handles to the same actor may be interleaved. */
-  std::unordered_map<ActorHandleID, int64_t> task_counters;
-  /** These are the execution dependencies that make up the frontier of the
-   * actor's runnable tasks. For each actor handle, we store the object ID
-   * that represents the execution dependency for the next runnable task
-   * submitted by that handle. */
-  std::unordered_map<ActorHandleID, ObjectID> frontier_dependencies;
-  /** The return value of the most recently executed task. The next task to
-   * execute should take this as an execution dependency at dispatch time. Set
-   * to nil if there are no execution dependencies (e.g., this is the first
-   * task to execute). */
-  ObjectID execution_dependency;
-  /** A queue of tasks to be executed on this actor. The tasks will be sorted
-   * by the order of their actor counters. */
-  std::list<TaskExecutionSpec> *task_queue;
-  /** The worker that the actor is running on. */
-  LocalSchedulerClient *worker;
-  /** True if the worker is available and false otherwise. */
-  bool worker_available;
-} LocalActorInfo;
-
-/** Part of the local scheduler state that is maintained by the scheduling
- * algorithm. */
-struct SchedulingAlgorithmState {
-  /** A queue of tasks that are waiting for their dependencies. */
-  std::list<TaskExecutionSpec> *waiting_task_queue;
-  /** A queue of tasks whose dependencies are ready but that are waiting to be
-   * assigned to a worker. */
-  std::list<TaskExecutionSpec> *dispatch_task_queue;
-  /** This is a hash table from actor ID to information about that actor. In
-   * particular, a queue of tasks that are waiting to execute on that actor.
-   * This is only used for actors that exist locally. */
-  std::unordered_map<ActorID, LocalActorInfo> local_actor_infos;
-  /** This is a set of the IDs of the actors that have tasks waiting to run.
-   * The purpose is to make it easier to dispatch tasks without looping over
-   * all of the actors. Note that this is an optimization and is not strictly
-   * necessary. */
-  std::unordered_set<ActorID> actors_with_pending_tasks;
-  /** A vector of actor tasks that have been submitted but this local
-   * scheduler doesn't know which local scheduler is responsible for them, so
-   * it cannot assign them to the correct local scheduler yet. Whenever a
-   * notification about a new local scheduler arrives, we will resubmit all of
-   * these tasks locally. */
-  std::vector<TaskExecutionSpec> cached_submitted_actor_tasks;
-  /** A vector of pointers to workers in the worker pool. These are workers
-   * that have registered a PID with us and that are now waiting to be
-   * assigned a task to execute. */
-  std::vector<LocalSchedulerClient *> available_workers;
-  /** A vector of pointers to workers that are currently executing a task,
-   * unblocked. These are the workers that are leasing some number of
-   * resources. */
-  std::vector<LocalSchedulerClient *> executing_workers;
-  /** A vector of pointers to workers that are currently executing a task,
-   * blocked on some object(s) that isn't available locally yet. These are the
-   * workers that are executing a task, but that have temporarily returned the
-   * task's required resources. */
-  std::vector<LocalSchedulerClient *> blocked_workers;
-  /** A hash map of the objects that are available in the local Plasma store.
-   * The key is the object ID. This information could be a little stale. */
-  std::unordered_map<ObjectID, ObjectEntry> local_objects;
-  /** A hash map of the objects that are not available locally. These are
-   * currently being fetched by this local scheduler. The key is the object
-   * ID. Every local_scheduler_fetch_timeout_milliseconds, a Plasma fetch
-   * request will be sent for the object IDs in this table. Each entry also
-   * holds references to the queued tasks that are dependent on it. */
-  std::unordered_map<ObjectID, ObjectEntry> remote_objects;
-};
-
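The waiting and dispatch queues above form a two-stage pipeline: a task sits in waiting until its dependencies are local, then moves to dispatch until a worker and resources free up. std::list is used here, plausibly so that entries can be moved and erased without invalidating the iterators that ObjectEntry::dependent_tasks holds. A toy sketch of the promotion step (tasks reduced to ints; names are illustrative):

    #include <list>

    struct Pipeline {
      std::list<int> waiting;   // dependencies not yet local
      std::list<int> dispatch;  // runnable, waiting for a worker
    };

    // Move one task from waiting to dispatch in O(1). splice() leaves the
    // iterator valid (now pointing into dispatch), which is what makes
    // holding iterators to queued tasks safe.
    void promote(Pipeline &p, std::list<int>::iterator it) {
      p.dispatch.splice(p.dispatch.end(), p.waiting, it);
    }
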
-SchedulingAlgorithmState *SchedulingAlgorithmState_init(void) {
-  SchedulingAlgorithmState *algorithm_state = new SchedulingAlgorithmState();
-  /* Initialize the local data structures used for queuing tasks and
-   * workers. */
-  algorithm_state->waiting_task_queue = new std::list<TaskExecutionSpec>();
-  algorithm_state->dispatch_task_queue = new std::list<TaskExecutionSpec>();
-
-  return algorithm_state;
-}
-
-void SchedulingAlgorithmState_free(SchedulingAlgorithmState *algorithm_state) {
-  /* Free all of the tasks in the waiting queue. */
-  delete algorithm_state->waiting_task_queue;
-  /* Free all the tasks in the dispatch queue. */
-  delete algorithm_state->dispatch_task_queue;
-  /* Remove all of the remaining actors. */
-  while (algorithm_state->local_actor_infos.size() != 0) {
-    auto it = algorithm_state->local_actor_infos.begin();
-    ActorID actor_id = it->first;
-    remove_actor(algorithm_state, actor_id);
-  }
-  /* Free the algorithm state. */
-  delete algorithm_state;
-}
-
-/**
- * This is a helper method to check if a worker is in a vector of workers.
- *
- * @param worker_vector A vector of workers.
- * @param worker The worker to look for in the vector.
- * @return True if the worker is in the vector and false otherwise.
- */
-bool worker_in_vector(std::vector<LocalSchedulerClient *> &worker_vector,
-                      LocalSchedulerClient *worker) {
-  auto it = std::find(worker_vector.begin(), worker_vector.end(), worker);
-  return it != worker_vector.end();
-}
-
-/**
- * This is a helper method to remove a worker from a vector of workers if it
- * is present in the vector.
- *
- * @param worker_vector A vector of workers.
- * @param worker The worker to remove.
- * @return True if the worker was removed and false otherwise.
- */
-bool remove_worker_from_vector(
-    std::vector<LocalSchedulerClient *> &worker_vector,
-    LocalSchedulerClient *worker) {
-  /* Find the worker in the list of executing workers. */
-  auto it = std::find(worker_vector.begin(), worker_vector.end(), worker);
-  bool remove_worker = (it != worker_vector.end());
-  if (remove_worker) {
-    /* Remove the worker from the list of workers. */
-    using std::swap;
-    swap(*it, worker_vector.back());
-    worker_vector.pop_back();
-  }
-  return remove_worker;
-}
-
-void provide_scheduler_info(LocalSchedulerState *state,
-                            SchedulingAlgorithmState *algorithm_state,
-                            LocalSchedulerInfo *info) {
-  info->total_num_workers = state->workers.size();
-  /* TODO(swang): Provide separate counts for tasks that are waiting for
-   * dependencies vs tasks that are waiting to be assigned. */
-  int64_t waiting_task_queue_length =
-      algorithm_state->waiting_task_queue->size();
-  int64_t dispatch_task_queue_length =
-      algorithm_state->dispatch_task_queue->size();
-  info->task_queue_length =
-      waiting_task_queue_length + dispatch_task_queue_length;
-  info->available_workers = algorithm_state->available_workers.size();
-  /* Copy static and dynamic resource information. */
-  info->dynamic_resources = state->dynamic_resources;
-  info->static_resources = state->static_resources;
-}
-
-/**
- * Create the LocalActorInfo struct for an actor worker that this local
- * scheduler is responsible for. For a given actor, this will either be done
- * when the first task for that actor arrives or when the worker running that
- * actor connects to the local scheduler.
- *
- * @param algorithm_state The state of the scheduling algorithm.
- * @param actor_id The actor ID of the actor being created.
- * @param initial_execution_dependency The dummy object ID of the actor
- *        creation task.
- * @param worker The worker struct for the worker that is running this actor.
- *        If the worker struct has not been created yet (meaning that the
- *        worker that is running this actor has not registered with the local
- *        scheduler yet, and so create_actor is being called because a task
- *        for that actor has arrived), then this should be NULL.
- * @return Void.
- */
-void create_actor(SchedulingAlgorithmState *algorithm_state,
-                  const ActorID &actor_id,
-                  const ObjectID &initial_execution_dependency,
-                  LocalSchedulerClient *worker) {
-  LocalActorInfo entry;
-  entry.task_counters[ActorHandleID::nil()] = 0;
-  entry.frontier_dependencies[ActorHandleID::nil()] = ObjectID::nil();
-  /* The actor has not yet executed any tasks, so there are no execution
-   * dependencies for the next task to be scheduled. */
-  entry.execution_dependency = initial_execution_dependency;
-  entry.task_queue = new std::list<TaskExecutionSpec>();
-  entry.worker = worker;
-  entry.worker_available = false;
-  RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) == 0);
-  algorithm_state->local_actor_infos[actor_id] = entry;
-
-  /* Log some useful information about the actor that we created. */
-  RAY_LOG(DEBUG) << "Creating actor with ID " << actor_id;
-}
-
-void remove_actor(SchedulingAlgorithmState *algorithm_state,
-                  ActorID actor_id) {
-  RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) == 1);
-  LocalActorInfo &entry =
-      algorithm_state->local_actor_infos.find(actor_id)->second;
-
-  /* Log some useful information about the actor that we're removing. */
-  size_t count = entry.task_queue->size();
-  if (count > 0) {
-    RAY_LOG(WARNING) << "Removing actor with ID " << actor_id << " and "
-                     << count << " remaining tasks.";
-  }
-
-  entry.task_queue->clear();
-  delete entry.task_queue;
-  /* Remove the entry from the hash table. */
-  algorithm_state->local_actor_infos.erase(actor_id);
-
-  /* Remove the actor ID from the set of actors with pending tasks. */
-  algorithm_state->actors_with_pending_tasks.erase(actor_id);
-}
-
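dispatch_actor_task below enforces per-handle FIFO order: a queued task runs only when its counter equals the number of tasks already executed from the same handle, so interleavings across handles stay legal while each handle's submissions run in order. A minimal sketch of that rule (handle IDs reduced to ints; names are illustrative):

    #include <cstdint>
    #include <unordered_map>

    struct ActorFrontier {
      // Tasks executed so far, per handle.
      std::unordered_map<int, int64_t> task_counters;

      // A task may be dispatched only in submission order for its handle.
      bool can_dispatch(int handle_id, int64_t task_counter) const {
        auto it = task_counters.find(handle_id);
        int64_t executed = (it == task_counters.end()) ? 0 : it->second;
        return task_counter == executed;
      }

      void on_dispatched(int handle_id) { task_counters[handle_id] += 1; }
    };
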
-/**
- * Dispatch a task to an actor if possible.
- *
- * @param state The state of the local scheduler.
- * @param algorithm_state The state of the scheduling algorithm.
- * @param actor_id The ID of the actor corresponding to the worker.
- * @return True if a task was dispatched to the actor and false otherwise.
- */
-bool dispatch_actor_task(LocalSchedulerState *state,
-                         SchedulingAlgorithmState *algorithm_state,
-                         ActorID actor_id) {
-  /* Make sure this worker actually is an actor. */
-  RAY_CHECK(!actor_id.is_nil());
-  /* Return if this actor doesn't have any pending tasks. */
-  if (algorithm_state->actors_with_pending_tasks.find(actor_id) ==
-      algorithm_state->actors_with_pending_tasks.end()) {
-    return false;
-  }
-  /* Make sure this actor belongs to this local scheduler. */
-  if (state->actor_mapping.count(actor_id) != 1) {
-    /* The creation notification for this actor has not yet arrived at the
-     * local scheduler. This should be rare. */
-    return false;
-  }
-  RAY_CHECK(state->actor_mapping[actor_id].local_scheduler_id ==
-            get_db_client_id(state->db));
-
-  /* Get the local actor entry for this actor. */
-  RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0);
-  LocalActorInfo &entry =
-      algorithm_state->local_actor_infos.find(actor_id)->second;
-
-  /* There should be some queued tasks for this actor. */
-  RAY_CHECK(!entry.task_queue->empty());
-  /* If the worker is not available, we cannot assign a task to it. */
-  if (!entry.worker_available) {
-    return false;
-  }
-
-  /* Check whether we can execute the first task in the queue. */
-  auto task = entry.task_queue->begin();
-  TaskSpec *spec = task->Spec();
-  ActorHandleID next_task_handle_id = TaskSpec_actor_handle_id(spec);
-  /* We can only execute tasks in order of task_counter. */
-  if (TaskSpec_actor_counter(spec) !=
-      entry.task_counters[next_task_handle_id]) {
-    return false;
-  }
-
-  /* If there are not enough resources available, we cannot assign the
-   * task. */
-  RAY_CHECK(0 == TaskSpec_get_required_resource(spec, "GPU"));
-  if (!check_dynamic_resources(state,
-                               TaskSpec_get_required_resources(spec))) {
-    return false;
-  }
-
-  /* Update the task's execution dependencies to reflect the actual execution
-   * order to support deterministic reconstruction. */
-  /* NOTE(swang): The update of an actor task's execution dependencies is
-   * performed asynchronously. This means that if this local scheduler dies,
-   * we may lose updates that are in flight to the task table. We only
-   * guarantee deterministic reconstruction ordering for tasks whose updates
-   * are reflected in the task table. */
-  std::vector<ObjectID> ordered_execution_dependencies;
-  ordered_execution_dependencies.push_back(entry.execution_dependency);
-  task->SetExecutionDependencies(ordered_execution_dependencies);
-
-  /* Assign the first task in the task queue to the worker and mark the
-   * worker as unavailable. */
-  assign_task_to_worker(state, *task, entry.worker);
-  entry.execution_dependency = TaskSpec_actor_dummy_object(spec);
-  entry.worker_available = false;
-  /* Extend the frontier to include the assigned task. */
-  entry.task_counters[next_task_handle_id] += 1;
-  entry.frontier_dependencies[next_task_handle_id] =
-      entry.execution_dependency;
-
-  /* Remove the task from the actor's task queue. */
-  entry.task_queue->erase(task);
-  /* If there are no more tasks in the queue, then indicate that the actor
-   * has no tasks. */
-  if (entry.task_queue->empty()) {
-    algorithm_state->actors_with_pending_tasks.erase(actor_id);
-  }
-
-  return true;
-}
-
-void handle_convert_worker_to_actor(
-    LocalSchedulerState *state,
-    SchedulingAlgorithmState *algorithm_state,
-    const ActorID &actor_id,
-    const ObjectID &initial_execution_dependency,
-    LocalSchedulerClient *worker) {
-  if (algorithm_state->local_actor_infos.count(actor_id) == 0) {
-    create_actor(algorithm_state, actor_id, initial_execution_dependency,
-                 worker);
-  } else {
-    /* In this case, the LocalActorInfo struct has already been created by the
-     * first call to add_task_to_actor_queue. However, the worker field was
-     * not filled out, so fill out the correct worker field now. */
-    algorithm_state->local_actor_infos[actor_id].worker = worker;
-  }
-  /* Increment the task counter for the creator's handle to account for the
-   * actor creation task. */
-  auto &task_counters =
-      algorithm_state->local_actor_infos[actor_id].task_counters;
-  RAY_CHECK(task_counters[ActorHandleID::nil()] == 0);
-  task_counters[ActorHandleID::nil()]++;
-}
-
-/**
- * Finishes a killed task by inserting dummy objects for each of its returns.
- */
-void finish_killed_task(LocalSchedulerState *state,
-                        TaskExecutionSpec &execution_spec) {
-  TaskSpec *spec = execution_spec.Spec();
-  int64_t num_returns = TaskSpec_num_returns(spec);
-  for (int i = 0; i < num_returns; i++) {
-    ObjectID object_id = TaskSpec_return(spec, i);
-    std::shared_ptr<Buffer> data;
-    // TODO(ekl): this writes an invalid arrow object, which is sufficient to
-    // signal that the worker failed, but it would be nice to return more
-    // detailed failure metadata in the future.
- arrow::Status status = - state->plasma_conn->Create(object_id.to_plasma_id(), 1, NULL, 0, &data); - if (!status.IsPlasmaObjectExists()) { - ARROW_CHECK_OK(status); - ARROW_CHECK_OK(state->plasma_conn->Seal(object_id.to_plasma_id())); - } - } - /* Mark the task as done. */ - if (state->db != NULL) { - Task *task = Task_alloc(execution_spec, TaskStatus::DONE, - get_db_client_id(state->db)); - // In most cases, task_table_update would be appropriate, however, it is - // possible in some cases that the task has not yet been added to the task - // table (e.g., if it is an actor task that is queued locally because the - // actor has not been created yet). - task_table_add_task(state->db, task, NULL, NULL, NULL); - } -} - -/** - * Insert a task queue entry into an actor's dispatch queue. The task is - * inserted in sorted order by task counter. If this is the first task - * scheduled to this actor and the worker process has not yet connected, then - * this also creates a LocalActorInfo entry for the actor. - * - * @param state The state of the local scheduler. - * @param algorithm_state The state of the scheduling algorithm. - * @param task_entry The task queue entry to add to the actor's queue. - * @return Void. - */ -void insert_actor_task_queue(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec task_entry) { - TaskSpec *spec = task_entry.Spec(); - /* Get the local actor entry for this actor. */ - ActorID actor_id = TaskSpec_actor_id(spec); - ActorHandleID task_handle_id = TaskSpec_actor_handle_id(spec); - int64_t task_counter = TaskSpec_actor_counter(spec); - - /* Fail the task immediately; it's destined for a dead actor. */ - if (state->removed_actors.find(actor_id) != state->removed_actors.end()) { - finish_killed_task(state, task_entry); - return; - } - - LocalActorInfo &entry = - algorithm_state->local_actor_infos.find(actor_id)->second; - if (entry.task_counters.count(task_handle_id) == 0) { - entry.task_counters[task_handle_id] = 0; - } - /* Extend the frontier to include the new handle. */ - if (entry.frontier_dependencies.count(task_handle_id) == 0) { - RAY_CHECK(task_entry.ExecutionDependencies().size() == 1); - entry.frontier_dependencies[task_handle_id] = - task_entry.ExecutionDependencies()[0]; - } - - /* As a sanity check, the counter of the new task should be greater than the - * number of tasks that have executed on this actor so far (since we are - * guaranteeing in-order execution of the tasks on the actor). TODO(rkn): This - * check will fail if the fault-tolerance mechanism resubmits a task on an - * actor. */ - if (task_counter < entry.task_counters[task_handle_id]) { - RAY_LOG(INFO) << "A task that has already been executed has been " - << "resubmitted, so we are ignoring it. This should only " - << "happen during reconstruction."; - return; - } - - /* Insert the task spec to the actor's task queue in sorted order, per actor - * handle ID. Find the first task in the queue with a counter greater than - * the submitted task's and the same handle ID. */ - auto it = entry.task_queue->begin(); - for (; it != entry.task_queue->end(); it++) { - TaskSpec *pending_task_spec = it->Spec(); - /* Skip tasks submitted by a different handle. */ - if (!(task_handle_id == TaskSpec_actor_handle_id(pending_task_spec))) { - continue; - } - /* A duplicate task submitted by the same handle. */ - if (task_counter == TaskSpec_actor_counter(pending_task_spec)) { - RAY_LOG(INFO) << "A task was resubmitted, so we are ignoring it. 
This " - << "should only happen during reconstruction."; - return; - } - /* We found a task with the same handle ID and a greater task counter. */ - if (task_counter < TaskSpec_actor_counter(pending_task_spec)) { - break; - } - } - entry.task_queue->insert(it, std::move(task_entry)); - - /* Record the fact that this actor has a task waiting to execute. */ - algorithm_state->actors_with_pending_tasks.insert(actor_id); -} - -/** - * Queue a task to be dispatched for an actor. Update the task table for the - * queued task. TODO(rkn): Should we also update the task table in the case - * where the tasks are cached locally? - * - * @param state The state of the local scheduler. - * @param algorithm_state The state of the scheduling algorithm. - * @param spec The task spec to add. - * @param from_global_scheduler True if the task was assigned to this local - * scheduler by the global scheduler and false if it was submitted - * locally by a worker. - * @return Void. - */ -void queue_actor_task(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - bool from_global_scheduler) { - TaskSpec *spec = execution_spec.Spec(); - ActorID actor_id = TaskSpec_actor_id(spec); - RAY_CHECK(!actor_id.is_nil()); - - /* Update the task table. */ - if (state->db != NULL) { - Task *task = Task_alloc(execution_spec, TaskStatus::QUEUED, - get_db_client_id(state->db)); - if (from_global_scheduler) { - /* If the task is from the global scheduler, it's already been added to - * the task table, so just update the entry. */ - task_table_update(state->db, task, NULL, NULL, NULL); - } else { - /* Otherwise, this is the first time the task has been seen in the - * system (unless it's a resubmission of a previous task), so add the - * entry. */ - task_table_add_task(state->db, task, NULL, NULL, NULL); - } - } - - // Create a new task queue entry. This must come after the above block because - // insert_actor_task_queue may call task_table_update internally, which must - // come after the prior call to task_table_add_task. - TaskExecutionSpec copy = TaskExecutionSpec(&execution_spec); - insert_actor_task_queue(state, algorithm_state, std::move(copy)); -} - -/** - * Fetch a queued task's missing object dependency. The fetch request will be - * retried every local_scheduler_fetch_timeout_milliseconds until the object is - * available locally. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param task_entry_it A reference to the task entry in the waiting queue. - * @param obj_id The ID of the object that the task is dependent on. - * @param request_transfer Whether to request a transfer of this object from - * other plasma managers. This should be set to false for execution - * dependencies, which should be fulfilled by executing the - * corresponding task locally. - * @returns Void. - */ -void fetch_missing_dependency( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - std::list::iterator task_entry_it, - plasma::ObjectID obj_id, - bool request_transfer) { - if (algorithm_state->remote_objects.count(obj_id) == 0) { - /* We weren't actively fetching this object. Try the fetch once - * immediately. */ - if (state->plasma_conn->get_manager_fd() != -1) { - auto arrow_status = state->plasma_conn->Fetch(1, &obj_id); - if (!arrow_status.ok()) { - LocalSchedulerState_free(state); - /* TODO(swang): Local scheduler should also exit even if there are no - * pending fetches. 
This could be done by subscribing to the db_client - * table, or pinging the plasma manager in the heartbeat handler. */ - RAY_LOG(FATAL) << "Lost connection to the plasma manager, local " - << "scheduler is exiting. Error: " - << arrow_status.ToString(); - } - } - /* Create an entry and add it to the list of active fetch requests to - * ensure that the fetch actually happens. The entry will be moved to the - * hash table of locally available objects in handle_object_available when - * the object becomes available locally. It will get freed if the object is - * subsequently removed locally. */ - ObjectEntry entry; - entry.request_transfer = request_transfer; - algorithm_state->remote_objects[obj_id] = entry; - } - algorithm_state->remote_objects[obj_id].dependent_tasks.push_back( - task_entry_it); -} - -/** - * Fetch a queued task's missing object dependencies. The fetch requests will - * be retried every local_scheduler_fetch_timeout_milliseconds until all - * objects are available locally. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param task_entry_it A reference to the task entry in the waiting queue. - * @returns Void. - */ -void fetch_missing_dependencies( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - std::list::iterator task_entry_it) { - int64_t num_dependencies = task_entry_it->NumDependencies(); - int num_missing_dependencies = 0; - for (int64_t i = 0; i < num_dependencies; ++i) { - int count = task_entry_it->DependencyIdCount(i); - for (int j = 0; j < count; ++j) { - ObjectID obj_id = task_entry_it->DependencyId(i, j); - /* If the entry is not yet available locally, record the dependency. */ - if (algorithm_state->local_objects.count(obj_id) == 0) { - /* Do not request a transfer from other plasma managers if this is an - * execution dependency. */ - bool request_transfer = task_entry_it->IsStaticDependency(i); - fetch_missing_dependency(state, algorithm_state, task_entry_it, - obj_id.to_plasma_id(), request_transfer); - ++num_missing_dependencies; - } - } - } - RAY_CHECK(num_missing_dependencies > 0); -} - -/** - * Clear a queued task's missing object dependencies. This is the inverse of - * fetch_missing_dependencies. - * TODO(swang): Test this function. - * - * @param algorithm_state The scheduling algorithm state. - * @param task_entry_it A reference to the task entry in the waiting queue. - * @returns Void. - */ -void clear_missing_dependencies( - SchedulingAlgorithmState *algorithm_state, - std::list::iterator task_entry_it) { - int64_t num_dependencies = task_entry_it->NumDependencies(); - for (int64_t i = 0; i < num_dependencies; ++i) { - int count = task_entry_it->DependencyIdCount(i); - for (int j = 0; j < count; ++j) { - ObjectID obj_id = task_entry_it->DependencyId(i, j); - /* If this object dependency is missing, remove this task from the - * object's list of dependent tasks. */ - auto entry = algorithm_state->remote_objects.find(obj_id); - if (entry != algorithm_state->remote_objects.end()) { - /* Find and remove the given task. */ - auto &dependent_tasks = entry->second.dependent_tasks; - for (auto dependent_task_it = dependent_tasks.begin(); - dependent_task_it != dependent_tasks.end();) { - if (*dependent_task_it == task_entry_it) { - dependent_task_it = dependent_tasks.erase(dependent_task_it); - } else { - dependent_task_it++; - } - } - /* If the missing object dependency has no more dependent tasks, then - * remove it. 
*/ - if (dependent_tasks.empty()) { - algorithm_state->remote_objects.erase(entry); - } - } - } - } -} - -/** - * Check if all of the remote object arguments for a task are available in the - * local object store. - * - * @param algorithm_state The scheduling algorithm state. - * @param task Task specification of the task to check. - * @return bool This returns true if all of the remote object arguments for the - * task are present in the local object store, otherwise it returns - * false. - */ -bool can_run(SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &task) { - int64_t num_dependencies = task.NumDependencies(); - for (int i = 0; i < num_dependencies; ++i) { - int count = task.DependencyIdCount(i); - for (int j = 0; j < count; ++j) { - ObjectID obj_id = task.DependencyId(i, j); - if (algorithm_state->local_objects.count(obj_id) == 0) { - /* The object is not present locally, so this task cannot be scheduled - * right now. */ - return false; - } - } - } - return true; -} - -bool object_locally_available(SchedulingAlgorithmState *algorithm_state, - ObjectID object_id) { - return algorithm_state->local_objects.count(object_id) == 1; -} - -/* TODO(swang): This method is not covered by any valgrind tests. */ -int fetch_object_timeout_handler(event_loop *loop, timer_id id, void *context) { - int64_t start_time = current_time_ms(); - - LocalSchedulerState *state = (LocalSchedulerState *) context; - /* Only try the fetches if we are connected to the object store manager. */ - if (state->plasma_conn->get_manager_fd() == -1) { - RAY_LOG(INFO) - << "Local scheduler is not connected to a object store manager"; - return RayConfig::instance().local_scheduler_fetch_timeout_milliseconds(); - } - - std::vector object_id_vec; - for (auto const &entry : state->algorithm_state->remote_objects) { - if (entry.second.request_transfer) { - object_id_vec.push_back(entry.first); - } - } - - ObjectID *object_ids = object_id_vec.data(); - int64_t num_object_ids = object_id_vec.size(); - - /* Divide very large fetch requests into smaller fetch requests so that a - * single fetch request doesn't block the plasma manager for a long time. */ - for (int64_t j = 0; j < num_object_ids; - j += RayConfig::instance().local_scheduler_fetch_request_size()) { - int num_objects_in_request = - std::min( - num_object_ids, - j + RayConfig::instance().local_scheduler_fetch_request_size()) - - j; - auto arrow_status = state->plasma_conn->Fetch( - num_objects_in_request, - reinterpret_cast(&object_ids[j])); - if (!arrow_status.ok()) { - LocalSchedulerState_free(state); - RAY_LOG(FATAL) << "Lost connection to the plasma manager, local " - << "scheduler is exiting. Error: " - << arrow_status.ToString(); - } - } - - /* Print a warning if this method took too long. */ - int64_t end_time = current_time_ms(); - if (end_time - start_time > - RayConfig::instance().max_time_for_handler_milliseconds()) { - RAY_LOG(WARNING) << "fetch_object_timeout_handler took " - << end_time - start_time << " milliseconds."; - } - - /* Wait at least local_scheduler_fetch_timeout_milliseconds before running - * this timeout handler again. But if we're waiting for a large number of - * objects, wait longer (e.g., 10 seconds for one million objects) so that we - * don't overwhelm the plasma manager. */ - return std::max( - RayConfig::instance().local_scheduler_fetch_timeout_milliseconds(), - int64_t(0.01 * num_object_ids)); -} - -/* TODO(swang): This method is not covered by any valgrind tests. 
*/
-int reconstruct_object_timeout_handler(event_loop *loop,
-                                       timer_id id,
-                                       void *context) {
-  int64_t start_time = current_time_ms();
-
-  LocalSchedulerState *state = (LocalSchedulerState *) context;
-
-  /* This vector is used to track which object IDs to reconstruct next. If
-   * the vector is empty, we repopulate it with all of the keys of the remote
-   * object table. During every pass through this handler, we call reconstruct
-   * on up to max_num_to_reconstruct elements of the vector (after first
-   * checking that the object IDs are still missing). */
-  static std::vector<ObjectID> object_ids_to_reconstruct;
-
-  /* If the vector is empty, repopulate it. */
-  if (object_ids_to_reconstruct.size() == 0) {
-    for (auto const &entry : state->algorithm_state->remote_objects) {
-      object_ids_to_reconstruct.push_back(entry.first);
-    }
-  }
-
-  int64_t num_reconstructed = 0;
-  for (size_t i = 0; i < object_ids_to_reconstruct.size(); i++) {
-    ObjectID object_id = object_ids_to_reconstruct[i];
-    /* Only call reconstruct if we are still missing the object. */
-    if (state->algorithm_state->remote_objects.find(object_id) !=
-        state->algorithm_state->remote_objects.end()) {
-      reconstruct_object(state, object_id);
-    }
-    num_reconstructed++;
-    if (num_reconstructed == RayConfig::instance().max_num_to_reconstruct()) {
-      break;
-    }
-  }
-  object_ids_to_reconstruct.erase(
-      object_ids_to_reconstruct.begin(),
-      object_ids_to_reconstruct.begin() + num_reconstructed);
-
-  /* Print a warning if this method took too long. */
-  int64_t end_time = current_time_ms();
-  if (end_time - start_time >
-      RayConfig::instance().max_time_for_handler_milliseconds()) {
-    RAY_LOG(WARNING) << "reconstruct_object_timeout_handler took "
-                     << end_time - start_time << " milliseconds.";
-  }
-
-  return RayConfig::instance()
-      .local_scheduler_reconstruction_timeout_milliseconds();
-}
-
-int rerun_actor_creation_tasks_timeout_handler(event_loop *loop,
-                                               timer_id id,
-                                               void *context) {
-  int64_t start_time = current_time_ms();
-
-  LocalSchedulerState *state = (LocalSchedulerState *) context;
-
-  // Create a set of the dummy object IDs for the actor creation tasks to
-  // reconstruct.
-  std::unordered_set<ObjectID> actor_dummy_objects;
-  for (auto const &execution_spec :
-       state->algorithm_state->cached_submitted_actor_tasks) {
-    ObjectID actor_creation_dummy_object_id =
-        TaskSpec_actor_creation_dummy_object_id(execution_spec.Spec());
-    actor_dummy_objects.insert(actor_creation_dummy_object_id);
-  }
-
-  // Issue reconstruct calls.
-  for (auto const &object_id : actor_dummy_objects) {
-    reconstruct_object(state, object_id);
-  }
-
-  // Print a warning if this method took too long.
-  int64_t end_time = current_time_ms();
-  if (end_time - start_time >
-      RayConfig::instance().max_time_for_handler_milliseconds()) {
-    RAY_LOG(WARNING) << "rerun_actor_creation_tasks_timeout_handler took "
-                     << end_time - start_time << " milliseconds.";
-  }
-
-  return RayConfig::instance()
-      .local_scheduler_reconstruction_timeout_milliseconds();
-}
-
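reconstruct_object_timeout_handler above bounds the work done per tick by rotating through a snapshot of the missing-object IDs: at most max_num_to_reconstruct are processed, and the processed prefix is dropped so the next tick resumes where this one stopped. The same pattern in isolation (IDs reduced to ints; process_batch is a hypothetical name, and limit is assumed positive):

    #include <cstddef>
    #include <vector>

    size_t process_batch(std::vector<int> &pending, size_t limit) {
      size_t processed = 0;
      for (int id : pending) {
        (void)id;  // reconstruct_object(state, id) would go here.
        if (++processed == limit) break;
      }
      // Drop the processed prefix; the next call resumes after it.
      pending.erase(pending.begin(), pending.begin() + processed);
      return processed;
    }
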
-/**
- * Return true if there are still some resources available and false
- * otherwise.
- *
- * @param state The scheduler state.
- * @return True if there are still some resources and false if there are not.
- */
-bool resources_available(LocalSchedulerState *state) {
-  bool resources_available = false;
-  for (auto const &resource_pair : state->dynamic_resources) {
-    if (resource_pair.second > 0) {
-      resources_available = true;
-    }
-  }
-  return resources_available;
-}
-
-void spillback_tasks_handler(LocalSchedulerState *state) {
-  SchedulingAlgorithmState *algorithm_state = state->algorithm_state;
-
-  int64_t num_to_spillback = std::min(
-      static_cast<int64_t>(algorithm_state->dispatch_task_queue->size()),
-      RayConfig::instance().max_tasks_to_spillback());
-
-  auto it = algorithm_state->dispatch_task_queue->end();
-  for (int64_t i = 0; i < num_to_spillback; i++) {
-    it--;
-  }
-
-  for (int64_t i = 0; i < num_to_spillback; i++) {
-    it->IncrementSpillbackCount();
-    // If an actor hasn't been created for a while, push a warning to the
-    // driver.
-    if (it->SpillbackCount() %
-            RayConfig::instance().actor_creation_num_spillbacks_warning() ==
-        0) {
-      TaskSpec *spec = it->Spec();
-      if (TaskSpec_is_actor_creation_task(spec)) {
-        std::ostringstream error_message;
-        error_message << "The actor with ID "
-                      << TaskSpec_actor_creation_id(spec) << " is taking a "
-                      << "while to be created. It is possible that the "
-                      << "cluster does not have enough resources to place "
-                      << "this actor (this may be normal while the cluster "
-                      << "is scaling up). Consider reducing the number of "
-                      << "actors created, or increasing the number of slots "
-                      << "available by using the --num-cpus, --num-gpus, and "
-                      << "--resources flags. The actor creation task is "
-                      << "requesting ";
-        for (auto const &resource_pair :
-             TaskSpec_get_required_resources(spec)) {
-          error_message << resource_pair.second << " " << resource_pair.first
-                        << " ";
-        }
-        push_error(state->db, TaskSpec_driver_id(spec),
-                   ErrorIndex::ACTOR_NOT_CREATED, error_message.str());
-      }
-    }
-
-    give_task_to_global_scheduler(state, algorithm_state, *it);
-    // Dequeue the task.
-    it = algorithm_state->dispatch_task_queue->erase(it);
-  }
-}
-
-/**
- * Assign as many tasks from the dispatch queue as possible.
- *
- * @param state The scheduler state.
- * @param algorithm_state The scheduling algorithm state.
- * @return Void.
- */
-void dispatch_tasks(LocalSchedulerState *state,
-                    SchedulingAlgorithmState *algorithm_state) {
-  /* Assign as many tasks as we can, while there are workers available. */
-  for (auto it = algorithm_state->dispatch_task_queue->begin();
-       it != algorithm_state->dispatch_task_queue->end();) {
-    TaskSpec *spec = it->Spec();
-    /* If there is a task to assign, but there are no more available workers
-     * in the worker pool, then exit. Ensure that there will be an available
-     * worker during a future invocation of dispatch_tasks. */
-    if (algorithm_state->available_workers.size() == 0) {
-      if (state->child_pids.size() == 0) {
-        /* If there are no workers, including those pending PID registration,
-         * then we must start a new one to replenish the worker pool. */
-        start_worker(state);
-      }
-      return;
-    }
-
-    /* Terminate early if there are no more resources available. */
-    if (!resources_available(state)) {
-      return;
-    }
-
-    /* Skip to the next task if this task cannot currently be satisfied. */
-    if (!check_dynamic_resources(state,
-                                 TaskSpec_get_required_resources(spec))) {
-      /* This task could not be satisfied -- proceed to the next task. */
-      ++it;
-      continue;
-    }
-
-    /* Dispatch this task to an available worker and dequeue the task. */
-    RAY_LOG(DEBUG) << "Dispatching task";
-    /* Get the last available worker in the available worker queue.
*/ - LocalSchedulerClient *worker = algorithm_state->available_workers.back(); - /* Tell the available worker to execute the task. */ - assign_task_to_worker(state, *it, worker); - /* Remove the worker from the available queue, and add it to the executing - * workers. */ - algorithm_state->available_workers.pop_back(); - algorithm_state->executing_workers.push_back(worker); - print_resource_info(state, spec); - /* Dequeue the task. */ - it = algorithm_state->dispatch_task_queue->erase(it); - } /* End for each task in the dispatch queue. */ -} - -/** - * Attempt to dispatch both regular tasks and actor tasks. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @return Void. - */ -void dispatch_all_tasks(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state) { - /* First attempt to dispatch regular tasks. */ - dispatch_tasks(state, algorithm_state); - - /* Attempt to dispatch actor tasks. */ - auto it = algorithm_state->actors_with_pending_tasks.begin(); - while (it != algorithm_state->actors_with_pending_tasks.end()) { - // We cannot short-circuit and exit here if there are no resources - // available because actor methods may require 0 CPUs. - - /* We increment the iterator ahead of time because the call to - * dispatch_actor_task may invalidate the current iterator. */ - ActorID actor_id = *it; - it++; - /* Dispatch tasks for the current actor. */ - dispatch_actor_task(state, algorithm_state, actor_id); - } -} - -/** - * A helper function to allocate a queue entry for a task specification and - * push it onto a generic queue. - * - * @param state The state of the local scheduler. - * @param task_queue A pointer to a task queue. NOTE: Because we are using - * utlist.h, we must pass in a pointer to the queue we want to append - * to. If we passed in the queue itself and the queue was empty, this - * would append the task to a queue that we don't have a reference to. - * @param task_entry A pointer to the task entry to queue. - * @param from_global_scheduler Whether or not the task was from a global - * scheduler. If false, the task was submitted by a worker. - * @return A reference to the entry in the queue that was pushed. - */ -std::list::iterator queue_task( - LocalSchedulerState *state, - std::list *task_queue, - TaskExecutionSpec &task_entry, - bool from_global_scheduler) { - /* The task has been added to a local scheduler queue. Write the entry in the - * task table to notify others that we have queued it. */ - if (state->db != NULL) { - Task *task = - Task_alloc(task_entry, TaskStatus::QUEUED, get_db_client_id(state->db)); - if (from_global_scheduler) { - /* If the task is from the global scheduler, it's already been added to - * the task table, so just update the entry. */ - task_table_update(state->db, task, NULL, NULL, NULL); - } else { - /* Otherwise, this is the first time the task has been seen in the system - * (unless it's a resubmission of a previous task), so add the entry. */ - task_table_add_task(state->db, task, NULL, NULL, NULL); - } - } - - /* Copy the spec and add it to the task queue. The allocated spec will be - * freed when it is assigned to a worker. */ - TaskExecutionSpec copy = TaskExecutionSpec(&task_entry); - task_queue->push_back(std::move(copy)); - /* Since we just queued the task, we can get a reference to it by going to - * the last element in the queue. */ - auto it = task_queue->end(); - --it; - - return it; -} - -/** - * Queue a task whose dependencies are missing. 
When the task's object - * dependencies become available, the task will be moved to the dispatch queue. - * If we have a connection to a plasma manager, begin trying to fetch the - * dependencies. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to queue. - * @param from_global_scheduler Whether or not the task was from a global - * scheduler. If false, the task was submitted by a worker. - * @return Void. - */ -void queue_waiting_task(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - bool from_global_scheduler) { - /* For actor tasks, do not queue tasks that have already been executed. */ - auto spec = execution_spec.Spec(); - if (!TaskSpec_actor_id(spec).is_nil()) { - auto entry = - algorithm_state->local_actor_infos.find(TaskSpec_actor_id(spec)); - if (entry != algorithm_state->local_actor_infos.end()) { - /* Find the highest task counter with the same handle ID as the task to - * queue. */ - auto &task_counters = entry->second.task_counters; - auto task_counter = task_counters.find(TaskSpec_actor_handle_id(spec)); - if (task_counter != task_counters.end() && - TaskSpec_actor_counter(spec) < task_counter->second) { - /* If the task to queue has a lower task counter, do not queue it. */ - RAY_LOG(INFO) << "A task that has already been executed has been " - << "resubmitted, so we are ignoring it. This should only " - << "happen during reconstruction."; - return; - } - } - } - - RAY_LOG(DEBUG) << "Queueing task in waiting queue"; - auto it = queue_task(state, algorithm_state->waiting_task_queue, - execution_spec, from_global_scheduler); - fetch_missing_dependencies(state, algorithm_state, it); -} - -/** - * Queue a task whose dependencies are ready. When the task reaches the front - * of the dispatch queue and workers are available, it will be assigned. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to queue. - * @param from_global_scheduler Whether or not the task was from a global - * scheduler. If false, the task was submitted by a worker. - * @return Void. - */ -void queue_dispatch_task(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - bool from_global_scheduler) { - RAY_LOG(DEBUG) << "Queueing task in dispatch queue"; - TaskSpec *spec = execution_spec.Spec(); - if (TaskSpec_is_actor_task(spec)) { - queue_actor_task(state, algorithm_state, execution_spec, - from_global_scheduler); - } else { - queue_task(state, algorithm_state->dispatch_task_queue, execution_spec, - from_global_scheduler); - } -} - -/** - * Add the task to the proper local scheduler queue. This assumes that the - * scheduling decision to place the task on this node has already been made, - * whether locally or by the global scheduler. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to queue. - * @param from_global_scheduler Whether or not the task was from a global - * scheduler. If false, the task was submitted by a worker. - * @return Void. 
- */ -void queue_task_locally(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - bool from_global_scheduler) { - if (can_run(algorithm_state, execution_spec)) { - /* Dependencies are ready, so push the task to the dispatch queue. */ - queue_dispatch_task(state, algorithm_state, execution_spec, - from_global_scheduler); - } else { - /* Dependencies are not ready, so push the task to the waiting queue. */ - queue_waiting_task(state, algorithm_state, execution_spec, - from_global_scheduler); - } -} - -void give_task_to_local_scheduler_retry(UniqueID id, - void *user_context, - void *user_data) { - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - Task *task = (Task *) user_data; - RAY_CHECK(Task_state(task) == TaskStatus::SCHEDULED); - - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - RAY_CHECK(TaskSpec_is_actor_task(spec)); - - ActorID actor_id = TaskSpec_actor_id(spec); - - if (state->actor_mapping.count(actor_id) == 0) { - // Process the actor task submission again. This will cache the task - // locally until a new actor creation notification is broadcast. We will - // attempt to reissue the actor creation tasks for all cached actor tasks - // in rerun_actor_creation_tasks_timeout_handler. - handle_actor_task_submitted(state, state->algorithm_state, *execution_spec); - return; - } - - DBClientID remote_local_scheduler_id = - state->actor_mapping[actor_id].local_scheduler_id; - - // TODO(rkn): db_client_table_cache_get is a blocking call, is this a - // performance issue? - DBClient remote_local_scheduler = - db_client_table_cache_get(state->db, remote_local_scheduler_id); - - // Check if the local scheduler that we're assigning this task to is still - // alive. - if (remote_local_scheduler.is_alive) { - // The local scheduler is still alive, which means that perhaps it hasn't - // subscribed to the appropriate channel yet, so retrying should suffice. - // This should be rare. - give_task_to_local_scheduler( - state, state->algorithm_state, *execution_spec, - state->actor_mapping[actor_id].local_scheduler_id); - } else { - // The local scheduler is dead, so we will need to recreate the actor by - // invoking reconstruction. - RAY_LOG(INFO) << "Local scheduler " << remote_local_scheduler_id - << " that was running actor " << actor_id << " died."; - RAY_CHECK(state->actor_mapping.count(actor_id) == 1); - // Update the actor mapping. - state->actor_mapping.erase(actor_id); - // Process the actor task submission again. This will cache the task - // locally until a new actor creation notification is broadcast. We will - // attempt to reissue the actor creation tasks for all cached actor tasks - // in rerun_actor_creation_tasks_timeout_handler. - handle_actor_task_submitted(state, state->algorithm_state, *execution_spec); - } -} - -/** - * Give a task directly to another local scheduler. This is currently only used - * for assigning actor tasks to the local scheduler responsible for that actor. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to schedule. - * @param local_scheduler_id The ID of the local scheduler to give the task to. - * @return Void. 
- */ -void give_task_to_local_scheduler(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec, - DBClientID local_scheduler_id) { - if (local_scheduler_id == get_db_client_id(state->db)) { - RAY_LOG(WARNING) << "Local scheduler is trying to assign a task to itself."; - } - RAY_CHECK(state->db != NULL); - /* Assign the task to the relevant local scheduler. */ - RAY_CHECK(state->config.global_scheduler_exists); - Task *task = - Task_alloc(execution_spec, TaskStatus::SCHEDULED, local_scheduler_id); - auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = give_task_to_local_scheduler_retry, - }; - - task_table_add_task(state->db, task, &retryInfo, NULL, state); -} - -void give_task_to_global_scheduler_retry(UniqueID id, - void *user_context, - void *user_data) { - LocalSchedulerState *state = (LocalSchedulerState *) user_context; - Task *task = (Task *) user_data; - RAY_CHECK(Task_state(task) == TaskStatus::WAITING); - - TaskExecutionSpec *execution_spec = Task_task_execution_spec(task); - TaskSpec *spec = execution_spec->Spec(); - RAY_CHECK(!TaskSpec_is_actor_task(spec)); - - give_task_to_global_scheduler(state, state->algorithm_state, *execution_spec); -} - -/** - * Give a task to the global scheduler to schedule. - * - * @param state The scheduler state. - * @param algorithm_state The scheduling algorithm state. - * @param spec The task specification to schedule. - * @return Void. - */ -void give_task_to_global_scheduler(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - if (state->db == NULL || !state->config.global_scheduler_exists) { - /* A global scheduler is not available, so queue the task locally. */ - queue_task_locally(state, algorithm_state, execution_spec, false); - return; - } - /* Pass on the task to the global scheduler. */ - RAY_CHECK(state->config.global_scheduler_exists); - Task *task = Task_alloc(execution_spec, TaskStatus::WAITING, - get_db_client_id(state->db)); - RAY_CHECK(state->db != NULL); - auto retryInfo = RetryInfo{ - .num_retries = 0, // This value is unused. - .timeout = 0, // This value is unused. - .fail_callback = give_task_to_global_scheduler_retry, - }; - task_table_add_task(state->db, task, &retryInfo, NULL, state); -} - -bool resource_constraints_satisfied(LocalSchedulerState *state, - TaskSpec *spec) { - /* At the local scheduler, if required resource vector exceeds either static - * or dynamic resource vector, the resource constraint is not satisfied. */ - for (auto const &resource_pair : TaskSpec_get_required_resources(spec)) { - double required_resource = resource_pair.second; - if (required_resource > state->static_resources[resource_pair.first] || - required_resource > state->dynamic_resources[resource_pair.first]) { - return false; - } - } - - if (TaskSpec_is_actor_creation_task(spec) && - state->static_resources["CPU"] != 0) { - return false; - } - - return true; -} - -void handle_task_submitted(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - TaskSpec *spec = execution_spec.Spec(); - /* TODO(atumanov): if static is satisfied and local objects ready, but dynamic - * resource is currently unavailable, then consider queueing task locally and - * recheck dynamic next time. 
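To make the comparison in resource_constraints_satisfied above concrete: a task's demand is satisfiable only if it exceeds neither the node's total (static) resource vector nor what is currently free (dynamic). A hedged, self-contained sketch of the same check, using plain maps instead of the scheduler state and omitting the actor-creation special case:

.. code-block:: cpp

    #include <iostream>
    #include <string>
    #include <unordered_map>

    using ResourceMap = std::unordered_map<std::string, double>;

    // Mirrors the loop in resource_constraints_satisfied() above: reject the
    // task if any required resource exceeds the static or dynamic vector.
    bool ConstraintsSatisfied(const ResourceMap &required,
                              const ResourceMap &static_resources,
                              const ResourceMap &dynamic_resources) {
      for (const auto &pair : required) {
        double needed = pair.second;
        auto total = static_resources.find(pair.first);
        auto free_now = dynamic_resources.find(pair.first);
        // A resource the node does not have at all can never be satisfied.
        if (total == static_resources.end() || needed > total->second)
          return false;
        if (free_now == dynamic_resources.end() || needed > free_now->second)
          return false;
      }
      return true;
    }

    int main() {
      ResourceMap static_res = {{"CPU", 4}, {"GPU", 1}};
      ResourceMap dynamic_res = {{"CPU", 1}, {"GPU", 1}};  // 3 CPUs in use
      std::cout << ConstraintsSatisfied({{"CPU", 1}}, static_res, dynamic_res)  // 1
                << ConstraintsSatisfied({{"CPU", 2}}, static_res, dynamic_res)  // 0
                << ConstraintsSatisfied({{"TPU", 1}}, static_res, dynamic_res)  // 0
                << "\n";
      return 0;
    }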
*/ - - // If this task's constraints are satisfied, dependencies are available - // locally, and there is an available worker, then enqueue the task in the - // dispatch queue and trigger task dispatch. Otherwise, pass the task along to - // the global scheduler if there is one. - // Note that actor creation tasks automatically go to the global scheduler. - // See https://github.com/ray-project/ray/issues/1756 for more discussion. - // This is a hack to improve actor load balancing (and to prevent the scenario - // where all actors are started locally). - if (resource_constraints_satisfied(state, spec) && - (algorithm_state->available_workers.size() > 0) && - can_run(algorithm_state, execution_spec) && - !TaskSpec_is_actor_creation_task(spec)) { - queue_dispatch_task(state, algorithm_state, execution_spec, false); - } else { - /* Give the task to the global scheduler to schedule, if it exists. */ - give_task_to_global_scheduler(state, algorithm_state, execution_spec); - } - - /* Try to dispatch tasks, since we may have added one to the queue. */ - dispatch_tasks(state, algorithm_state); -} - -void handle_actor_task_submitted(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - TaskSpec *task_spec = execution_spec.Spec(); - RAY_CHECK(TaskSpec_is_actor_task(task_spec)); - ActorID actor_id = TaskSpec_actor_id(task_spec); - - if (state->actor_mapping.count(actor_id) == 0) { - // Create a copy of the task to write to the task table. - Task *task = Task_alloc( - task_spec, execution_spec.SpecSize(), TaskStatus::ACTOR_CACHED, - get_db_client_id(state->db), execution_spec.ExecutionDependencies()); - - /* Add this task to a queue of tasks that have been submitted but the local - * scheduler doesn't know which actor is responsible for them. These tasks - * will be resubmitted (internally by the local scheduler) whenever a new - * actor notification arrives. NOTE(swang): These tasks have not yet been - * added to the task table. */ - TaskExecutionSpec task_entry = TaskExecutionSpec(&execution_spec); - algorithm_state->cached_submitted_actor_tasks.push_back( - std::move(task_entry)); - - // Even if the task can't be assigned to a worker yet, we should still write - // it to the task table. TODO(rkn): There's no need to do this more than - // once, and we could run into problems if we have very large numbers of - // tasks in this cache. - task_table_add_task(state->db, task, NULL, NULL, NULL); - - return; - } - - if (state->actor_mapping[actor_id].local_scheduler_id == - get_db_client_id(state->db)) { - /* This local scheduler is responsible for the actor, so handle the task - * locally. */ - queue_task_locally(state, algorithm_state, execution_spec, false); - /* Attempt to dispatch tasks to this actor. */ - dispatch_actor_task(state, algorithm_state, actor_id); - } else { - /* This local scheduler is not responsible for the task, so find the local - * scheduler that is responsible for this actor and assign the task directly - * to that local scheduler. 
*/ - give_task_to_local_scheduler( - state, algorithm_state, execution_spec, - state->actor_mapping[actor_id].local_scheduler_id); - } -} - -void handle_actor_creation_notification( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id) { - int num_cached_actor_tasks = - algorithm_state->cached_submitted_actor_tasks.size(); - - for (int i = 0; i < num_cached_actor_tasks; ++i) { - TaskExecutionSpec &task = algorithm_state->cached_submitted_actor_tasks[i]; - /* Note that handle_actor_task_submitted may append the spec to the end of - * the cached_submitted_actor_tasks array. */ - handle_actor_task_submitted(state, algorithm_state, task); - } - /* Remove all the tasks that were resubmitted. This does not erase the tasks - * that were newly appended to the cached_submitted_actor_tasks array. */ - auto begin = algorithm_state->cached_submitted_actor_tasks.begin(); - algorithm_state->cached_submitted_actor_tasks.erase( - begin, begin + num_cached_actor_tasks); -} - -void handle_task_scheduled(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - /* This callback handles tasks that were assigned to this local scheduler by - * the global scheduler, so we can safely assert that there is a connection to - * the database. */ - RAY_CHECK(state->db != NULL); - RAY_CHECK(state->config.global_scheduler_exists); - - // Currently, the global scheduler will never assign a task to a local - // scheduler that has 0 CPUs. - RAY_CHECK(state->static_resources["CPU"] != 0); - - // Push the task to the appropriate queue. - queue_task_locally(state, algorithm_state, execution_spec, true); - dispatch_tasks(state, algorithm_state); -} - -void handle_actor_task_scheduled(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec) { - TaskSpec *spec = execution_spec.Spec(); - /* This callback handles tasks that were assigned to this local scheduler by - * the global scheduler or by other workers, so we can safely assert that - * there is a connection to the database. */ - RAY_CHECK(state->db != NULL); - RAY_CHECK(state->config.global_scheduler_exists); - /* Check that the task is meant to run on an actor that this local scheduler - * is responsible for. */ - RAY_CHECK(TaskSpec_is_actor_task(spec)); - ActorID actor_id = TaskSpec_actor_id(spec); - if (state->actor_mapping.count(actor_id) == 1) { - RAY_CHECK(state->actor_mapping[actor_id].local_scheduler_id == - get_db_client_id(state->db)); - } else { - /* This means that an actor has been assigned to this local scheduler, and a - * task for that actor has been received by this local scheduler, but this - * local scheduler has not yet processed the notification about the actor - * creation. This is possible, though it should be very uncommon. If it does - * happen, it's OK. */ - RAY_LOG(INFO) << "handle_actor_task_scheduled called on local scheduler " - << "but the corresponding actor_map_entry is not present. " - << "This should be rare."; - } - /* Push the task to the appropriate queue. */ - queue_task_locally(state, algorithm_state, execution_spec, true); - dispatch_actor_task(state, algorithm_state, actor_id); -} - -void handle_worker_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - RAY_CHECK(worker->task_in_progress == NULL); - /* Check that the worker isn't in the pool of available workers.
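The replay loop in handle_actor_creation_notification above is easy to misread: resubmission may append new entries to the cache, so only the snapshot-sized prefix is erased afterwards. Here is that pattern in isolation, with a toy string cache and a resubmit callback standing in for handle_actor_task_submitted:

.. code-block:: cpp

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Replays exactly the entries that were cached when the notification
    // arrived. `resubmit` may push_back new entries (e.g., tasks whose actor
    // is still unknown), so we snapshot the size first and erase only that
    // prefix after. `resubmit` takes its argument by value on purpose, since
    // a push_back could invalidate references into `cache`.
    void ReplayCachedTasks(std::vector<std::string> &cache,
                           const std::function<void(std::string)> &resubmit) {
      size_t num_cached = cache.size();
      for (size_t i = 0; i < num_cached; ++i) {
        resubmit(cache[i]);  // may append to `cache`
      }
      cache.erase(cache.begin(), cache.begin() + num_cached);
    }

    int main() {
      std::vector<std::string> cache = {"task_a", "task_b"};
      ReplayCachedTasks(cache, [&cache](std::string task) {
        // Pretend task_b still has no known actor and gets re-cached.
        if (task == "task_b") cache.push_back("task_b_retry");
      });
      // task_a and task_b were consumed; task_b_retry remains cached.
      std::cout << cache.size() << " still cached: " << cache[0] << "\n";
      return 0;
    }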
*/ - RAY_CHECK(!worker_in_vector(algorithm_state->available_workers, worker)); - - /* Check that the worker isn't in the list of blocked workers. */ - RAY_CHECK(!worker_in_vector(algorithm_state->blocked_workers, worker)); - - /* If the worker was executing a task, it must have finished, so remove it - * from the list of executing workers. If the worker is connecting for the - * first time, it will not be in the list of executing workers. */ - remove_worker_from_vector(algorithm_state->executing_workers, worker); - /* Double check that we successfully removed the worker. */ - RAY_CHECK(!worker_in_vector(algorithm_state->executing_workers, worker)); - - /* Add worker to the list of available workers. */ - algorithm_state->available_workers.push_back(worker); - - /* Try to dispatch tasks. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_worker_removed(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - /* Make sure this is not an actor. */ - RAY_CHECK(worker->actor_id.is_nil()); - - /* Make sure that we remove the worker at most once. */ - int num_times_removed = 0; - - /* Remove the worker from available workers, if it's there. */ - bool removed_from_available = - remove_worker_from_vector(algorithm_state->available_workers, worker); - num_times_removed += removed_from_available; - /* Double check that we actually removed the worker. */ - RAY_CHECK(!worker_in_vector(algorithm_state->available_workers, worker)); - - /* Remove the worker from executing workers, if it's there. */ - bool removed_from_executing = - remove_worker_from_vector(algorithm_state->executing_workers, worker); - num_times_removed += removed_from_executing; - /* Double check that we actually removed the worker. */ - RAY_CHECK(!worker_in_vector(algorithm_state->executing_workers, worker)); - - /* Remove the worker from blocked workers, if it's there. */ - bool removed_from_blocked = - remove_worker_from_vector(algorithm_state->blocked_workers, worker); - num_times_removed += removed_from_blocked; - /* Double check that we actually removed the worker. */ - RAY_CHECK(!worker_in_vector(algorithm_state->blocked_workers, worker)); - - /* Make sure we removed the worker at most once. */ - RAY_CHECK(num_times_removed <= 1); - - /* Attempt to dispatch some tasks because some resources may have freed up. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_actor_worker_disconnect(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker, - bool cleanup) { - /* Fail all in progress or queued tasks of the actor. */ - if (!cleanup) { - if (state->db != NULL) { - actor_table_mark_removed(state->db, worker->actor_id); - } - - if (worker->task_in_progress != NULL) { - finish_killed_task(state, - *Task_task_execution_spec(worker->task_in_progress)); - } - - state->removed_actors.insert(worker->actor_id); - - RAY_CHECK(algorithm_state->local_actor_infos.count(worker->actor_id) != 0); - LocalActorInfo &entry = - algorithm_state->local_actor_infos.find(worker->actor_id)->second; - for (auto &task : *entry.task_queue) { - finish_killed_task(state, task); - } - } - - remove_actor(algorithm_state, worker->actor_id); - - /* Attempt to dispatch some tasks because some resources may have freed up. */ - dispatch_all_tasks(state, algorithm_state); - - /* Start a worker to replace the removed actor's worker and replenish the - * worker pool. 
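worker_in_vector and remove_worker_from_vector are used throughout this hunk but defined elsewhere; a plausible implementation consistent with the call sites (a membership test, and a removal that reports success so callers can count removals across the available/executing/blocked lists) might look like the following. Treat it as an assumption, not the actual helpers:

.. code-block:: cpp

    #include <algorithm>
    #include <cassert>
    #include <vector>

    struct LocalSchedulerClientStub {};  // stand-in for LocalSchedulerClient

    // Assumed shape: true iff `worker` is present in the vector.
    bool WorkerInVector(const std::vector<LocalSchedulerClientStub *> &workers,
                        LocalSchedulerClientStub *worker) {
      return std::find(workers.begin(), workers.end(), worker) != workers.end();
    }

    // Assumed shape: erase `worker` if present and report whether anything was
    // removed, which is what lets handle_worker_removed() assert the
    // at-most-once invariant across the three lists.
    bool RemoveWorkerFromVector(std::vector<LocalSchedulerClientStub *> &workers,
                                LocalSchedulerClientStub *worker) {
      auto it = std::find(workers.begin(), workers.end(), worker);
      if (it == workers.end()) return false;
      workers.erase(it);
      return true;
    }

    int main() {
      LocalSchedulerClientStub w;
      std::vector<LocalSchedulerClientStub *> available = {&w}, executing, blocked;
      int removed = RemoveWorkerFromVector(available, &w) +
                    RemoveWorkerFromVector(executing, &w) +
                    RemoveWorkerFromVector(blocked, &w);
      assert(removed <= 1);  // a worker lives in at most one list at a time
      return 0;
    }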
*/ - start_worker(state); -} - -/* NOTE(swang): For tasks that saved a checkpoint, we should consider - * overwriting the result table entries for the current task frontier to - * avoid duplicate task submissions during reconstruction. */ -void handle_actor_worker_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - ActorID actor_id = worker->actor_id; - RAY_CHECK(!actor_id.is_nil()); - /* Get the actor info for this worker. */ - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) == 1); - LocalActorInfo &entry = - algorithm_state->local_actor_infos.find(actor_id)->second; - RAY_CHECK(worker == entry.worker); - RAY_CHECK(!entry.worker_available); - /* If an actor task was assigned, mark returned dummy object as locally - * available. This is not added to the object table, so the update will be - * invisible to other nodes. */ - /* NOTE(swang): These objects are never cleaned up. We should consider - * removing the objects, e.g., when an actor is terminated. */ - if (!entry.execution_dependency.is_nil()) { - handle_object_available(state, algorithm_state, entry.execution_dependency); - } - /* Unset the fields indicating an assigned task. */ - entry.worker_available = true; - /* Assign new tasks if possible. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_worker_blocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - /* Find the worker in the list of executing workers. */ - RAY_CHECK( - remove_worker_from_vector(algorithm_state->executing_workers, worker)); - - /* Check that the worker isn't in the list of blocked workers. */ - RAY_CHECK(!worker_in_vector(algorithm_state->blocked_workers, worker)); - - /* Add the worker to the list of blocked workers. */ - algorithm_state->blocked_workers.push_back(worker); - - /* Try to dispatch tasks, since we may have freed up some resources. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_actor_worker_blocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - /* The actor case doesn't use equivalents of the blocked_workers and - * executing_workers lists. Are these necessary? */ - /* Try to dispatch tasks, since we may have freed up some resources. */ - dispatch_all_tasks(state, algorithm_state); -} - -void handle_worker_unblocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) { - /* Find the worker in the list of blocked workers. */ - RAY_CHECK( - remove_worker_from_vector(algorithm_state->blocked_workers, worker)); - - /* Check that the worker isn't in the list of executing workers. */ - RAY_CHECK(!worker_in_vector(algorithm_state->executing_workers, worker)); - - /* Add the worker to the list of executing workers. */ - algorithm_state->executing_workers.push_back(worker); -} - -void handle_actor_worker_unblocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker) {} - -void handle_object_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ObjectID object_id) { - auto object_entry_it = algorithm_state->remote_objects.find(object_id); - - ObjectEntry entry; - /* Get the entry for this object from the active fetch request, or allocate - * one if needed. */ - if (object_entry_it != algorithm_state->remote_objects.end()) { - /* Remove the object from the active fetch requests. 
*/ - entry = object_entry_it->second; - algorithm_state->remote_objects.erase(object_id); - } - - /* Add the entry to the set of locally available objects. */ - RAY_CHECK(algorithm_state->local_objects.count(object_id) == 0); - algorithm_state->local_objects[object_id] = entry; - - if (!entry.dependent_tasks.empty()) { - /* Out of the tasks that were dependent on this object, if they are now - * ready to run, move them to the dispatch queue. */ - for (auto &it : entry.dependent_tasks) { - if (can_run(algorithm_state, *it)) { - if (TaskSpec_is_actor_task(it->Spec())) { - insert_actor_task_queue(state, algorithm_state, std::move(*it)); - } else { - algorithm_state->dispatch_task_queue->push_back(std::move(*it)); - } - /* Remove the entry with a matching TaskSpec pointer from the waiting - * queue, but do not free the task spec. */ - algorithm_state->waiting_task_queue->erase(it); - } - } - /* Try to dispatch tasks, since we may have added some from the waiting - * queue. */ - dispatch_all_tasks(state, algorithm_state); - /* Clean up the records for dependent tasks. */ - entry.dependent_tasks.clear(); - } -} - -void handle_object_removed(LocalSchedulerState *state, - ObjectID removed_object_id) { - /* Remove the object from the set of locally available objects. */ - SchedulingAlgorithmState *algorithm_state = state->algorithm_state; - - RAY_CHECK(algorithm_state->local_objects.count(removed_object_id) == 1); - algorithm_state->local_objects.erase(removed_object_id); - - /* Track queued tasks that were dependent on this object. - * NOTE: Since objects often get removed in batches (e.g., during eviction), - * we may end up iterating through the queues many times in a row. If this - * turns out to be a bottleneck, consider tracking dependencies even for - * tasks in the dispatch queue, or batching object notifications. */ - /* Track the dependency for tasks that were in the dispatch queue. Remove - * these tasks from the dispatch queue and push them to the waiting queue. */ - for (auto it = algorithm_state->dispatch_task_queue->begin(); - it != algorithm_state->dispatch_task_queue->end();) { - if (it->DependsOn(removed_object_id)) { - /* This task was dependent on the removed object. */ - RAY_LOG(DEBUG) << "Moved task from dispatch queue back to waiting queue"; - algorithm_state->waiting_task_queue->push_back(std::move(*it)); - /* Remove the task from the dispatch queue, but do not free the task - * spec. */ - it = algorithm_state->dispatch_task_queue->erase(it); - } else { - /* The task can still run, so continue to the next task. */ - ++it; - } - } - - std::vector<ActorID> empty_actor_queues; - for (auto it = algorithm_state->actors_with_pending_tasks.begin(); - it != algorithm_state->actors_with_pending_tasks.end(); it++) { - auto actor_info = algorithm_state->local_actor_infos[*it]; - for (auto queue_it = actor_info.task_queue->begin(); - queue_it != actor_info.task_queue->end();) { - if (queue_it->DependsOn(removed_object_id)) { - /* This task was dependent on the removed object. */ - RAY_LOG(DEBUG) << "Moved task from actor dispatch queue back to " - << "waiting queue"; - algorithm_state->waiting_task_queue->push_back(std::move(*queue_it)); - /* Remove the task from the dispatch queue, but do not free the task - * spec.
*/ - queue_it = actor_info.task_queue->erase(queue_it); - if (actor_info.task_queue->size() == 0) { - empty_actor_queues.push_back(*it); - } - } else { - ++queue_it; - } - } - } - for (auto actor_id : empty_actor_queues) { - algorithm_state->actors_with_pending_tasks.erase(actor_id); - } - - /* Track the dependency for tasks that are in the waiting queue, including - * those that were just moved from the dispatch queue. */ - for (auto it = algorithm_state->waiting_task_queue->begin(); - it != algorithm_state->waiting_task_queue->end(); ++it) { - int64_t num_dependencies = it->NumDependencies(); - for (int64_t i = 0; i < num_dependencies; ++i) { - int count = it->DependencyIdCount(i); - for (int j = 0; j < count; ++j) { - ObjectID dependency_id = it->DependencyId(i, j); - if (dependency_id == removed_object_id) { - /* Do not request a transfer from other plasma managers if this is an - * execution dependency. */ - bool request_transfer = it->IsStaticDependency(i); - fetch_missing_dependency(state, algorithm_state, it, - removed_object_id.to_plasma_id(), - request_transfer); - } - } - } - } -} - -void handle_driver_removed(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - WorkerID driver_id) { - /* Loop over fetch requests. This must be done before we clean up the waiting - * task queue and the dispatch task queue because this map contains iterators - * for those lists, which will be invalidated when we clean up those lists.*/ - for (auto it = algorithm_state->remote_objects.begin(); - it != algorithm_state->remote_objects.end();) { - /* Loop over the tasks that are waiting for this object and remove the tasks - * for the removed driver. */ - auto task_it_it = it->second.dependent_tasks.begin(); - while (task_it_it != it->second.dependent_tasks.end()) { - /* If the dependent task was a task for the removed driver, remove it from - * this vector. */ - TaskSpec *spec = (*task_it_it)->Spec(); - if (TaskSpec_driver_id(spec) == driver_id) { - task_it_it = it->second.dependent_tasks.erase(task_it_it); - } else { - task_it_it++; - } - } - /* If there are no more dependent tasks for this object, then remove the - * ObjectEntry. */ - if (it->second.dependent_tasks.size() == 0) { - it = algorithm_state->remote_objects.erase(it); - } else { - it++; - } - } - - /* Remove this driver's tasks from the waiting task queue. */ - auto it = algorithm_state->waiting_task_queue->begin(); - while (it != algorithm_state->waiting_task_queue->end()) { - TaskSpec *spec = it->Spec(); - if (TaskSpec_driver_id(spec) == driver_id) { - it = algorithm_state->waiting_task_queue->erase(it); - } else { - it++; - } - } - - /* Remove this driver's tasks from the dispatch task queue. */ - it = algorithm_state->dispatch_task_queue->begin(); - while (it != algorithm_state->dispatch_task_queue->end()) { - TaskSpec *spec = it->Spec(); - if (TaskSpec_driver_id(spec) == driver_id) { - it = algorithm_state->dispatch_task_queue->erase(it); - } else { - it++; - } - } - - // Remove this driver's tasks from the cached actor tasks. Note that this loop - // could be very slow if the vector of cached actor tasks is very long. - for (auto it = algorithm_state->cached_submitted_actor_tasks.begin(); - it != algorithm_state->cached_submitted_actor_tasks.end();) { - TaskSpec *spec = (*it).Spec(); - if (TaskSpec_driver_id(spec) == driver_id) { - it = algorithm_state->cached_submitted_actor_tasks.erase(it); - } else { - ++it; - } - } - - /* TODO(rkn): Should we clean up the actor data structures? 
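The dispatch-to-waiting demotion in handle_object_removed above leans on the erase-returns-next-iterator idiom to filter a queue in place. The same mechanic in standalone form, with DependsOn reduced to a set lookup (illustrative names, not the real task API):

.. code-block:: cpp

    #include <deque>
    #include <iostream>
    #include <set>
    #include <string>

    struct Task {
      std::string name;
      std::set<std::string> deps;
      bool DependsOn(const std::string &obj) const { return deps.count(obj) > 0; }
    };

    // Move every task that depended on `removed` from `dispatch` back to
    // `waiting`, erasing in place without invalidating the loop iterator.
    void DemoteOnObjectRemoved(std::deque<Task> &dispatch,
                               std::deque<Task> &waiting,
                               const std::string &removed) {
      for (auto it = dispatch.begin(); it != dispatch.end();) {
        if (it->DependsOn(removed)) {
          waiting.push_back(std::move(*it));
          it = dispatch.erase(it);  // erase returns the next valid iterator
        } else {
          ++it;
        }
      }
    }

    int main() {
      std::deque<Task> dispatch = {{"t1", {"obj1"}}, {"t2", {"obj2"}}};
      std::deque<Task> waiting;
      DemoteOnObjectRemoved(dispatch, waiting, "obj1");
      std::cout << dispatch.size() << " dispatchable, " << waiting.size()
                << " waiting\n";  // 1 dispatchable, 1 waiting
      return 0;
    }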
*/ -} - -int num_waiting_tasks(SchedulingAlgorithmState *algorithm_state) { - return algorithm_state->waiting_task_queue->size(); -} - -int num_dispatch_tasks(SchedulingAlgorithmState *algorithm_state) { - return algorithm_state->dispatch_task_queue->size(); -} - -void print_worker_info(const char *message, - SchedulingAlgorithmState *algorithm_state) { - RAY_LOG(DEBUG) << message << ": " << algorithm_state->available_workers.size() - << " available, " << algorithm_state->executing_workers.size() - << " executing, " << algorithm_state->blocked_workers.size() - << " blocked"; -} - -std::unordered_map<ActorHandleID, int64_t> get_actor_task_counters( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id) { - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0); - return algorithm_state->local_actor_infos[actor_id].task_counters; -} - -void set_actor_task_counters( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id, - const std::unordered_map<ActorHandleID, int64_t> &task_counters) { - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0); - /* Overwrite the current task counters for the actor. This is necessary - * during reconstruction when resuming from a checkpoint so that we can - * resume the task frontier at the time that the checkpoint was saved. */ - auto &entry = algorithm_state->local_actor_infos[actor_id]; - entry.task_counters = task_counters; - - /* Filter out tasks for the actor that were submitted earlier than the new - * task counter. These represent tasks that executed before the actor's - * resumed checkpoint, and therefore should not be re-executed. */ - for (auto it = entry.task_queue->begin(); it != entry.task_queue->end();) { - /* Filter out duplicate tasks for the actor that are runnable. */ - TaskSpec *pending_task_spec = it->Spec(); - ActorHandleID handle_id = TaskSpec_actor_handle_id(pending_task_spec); - auto task_counter = entry.task_counters.find(handle_id); - if (task_counter != entry.task_counters.end() && - TaskSpec_actor_counter(pending_task_spec) < task_counter->second) { - /* If the task's counter is less than the highest count for that handle, - * then remove it from the actor's runnable queue. */ - it = entry.task_queue->erase(it); - } else { - it++; - } - } - for (auto it = algorithm_state->waiting_task_queue->begin(); - it != algorithm_state->waiting_task_queue->end();) { - /* Filter out duplicate tasks for the actor that are waiting on a missing - * dependency. */ - TaskSpec *spec = it->Spec(); - if (TaskSpec_actor_id(spec) == actor_id && - TaskSpec_actor_counter(spec) < - entry.task_counters[TaskSpec_actor_handle_id(spec)]) { - /* If the waiting task is for the same actor and its task counter is less - * than the highest count for that handle, then clear its object - * dependencies and remove it from the queue.
*/ - clear_missing_dependencies(algorithm_state, it); - it = algorithm_state->waiting_task_queue->erase(it); - } else { - it++; - } - } -} - -std::unordered_map<ActorHandleID, ObjectID> get_actor_frontier( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id) { - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0); - return algorithm_state->local_actor_infos[actor_id].frontier_dependencies; -} - -void set_actor_frontier( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id, - const std::unordered_map<ActorHandleID, ObjectID> &frontier_dependencies) { - RAY_CHECK(algorithm_state->local_actor_infos.count(actor_id) != 0); - auto &entry = algorithm_state->local_actor_infos[actor_id]; - entry.frontier_dependencies = frontier_dependencies; - for (auto frontier_dependency : entry.frontier_dependencies) { - if (algorithm_state->local_objects.count(frontier_dependency.second) == 0) { - handle_object_available(state, algorithm_state, - frontier_dependency.second); - } - } -} diff --git a/src/local_scheduler/local_scheduler_algorithm.h b/src/local_scheduler/local_scheduler_algorithm.h deleted file mode 100644 index 9238d5db58e5..000000000000 --- a/src/local_scheduler/local_scheduler_algorithm.h +++ /dev/null @@ -1,438 +0,0 @@ -#ifndef LOCAL_SCHEDULER_ALGORITHM_H -#define LOCAL_SCHEDULER_ALGORITHM_H - -#include "local_scheduler_shared.h" -#include "common/task.h" -#include "state/local_scheduler_table.h" - -/* ==== The scheduling algorithm ==== - * - * This file contains declarations for all functions and data structures - * that need to be provided if you want to implement a new algorithm - * for the local scheduler. - * - */ - -/** - * Initialize the scheduler state. - * - * @return State managed by the scheduling algorithm. - */ -SchedulingAlgorithmState *SchedulingAlgorithmState_init(void); - -/** - * Free the scheduler state. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @return Void. - */ -void SchedulingAlgorithmState_free(SchedulingAlgorithmState *algorithm_state); - -/** - * Fill out the given scheduler info struct with the current state of the - * scheduling algorithm. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param info The scheduler info struct to fill out. - * @return Void. - */ -void provide_scheduler_info(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerInfo *info); - -/** - * This function will be called when a new task is submitted by a worker for - * execution. The task will either be: - * 1. Put into the waiting queue, where it will wait for its dependencies to - * become available. - * 2. Put into the dispatch queue, where it will wait for an available worker. - * 3. Given to the global scheduler to be scheduled. - * - * Currently, the local scheduler policy is to keep the task if its - * dependencies are ready and there is an available worker. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param task Task that is submitted by the worker. - * @return Void. - */ -void handle_task_submitted(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec); - -/** - * This version of handle_task_submitted is used when the task being submitted - * is a method of an actor. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param task Task that is submitted by the worker. - * @return Void.
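The filtering that set_actor_task_counters performs above, over both the actor's runnable queue and the waiting queue, reduces to one predicate: drop any task whose per-handle counter is below the restored count. A hedged standalone sketch, with string handle IDs standing in for ActorHandleID:

.. code-block:: cpp

    #include <deque>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    // A queued actor task, keyed by the submitting handle and its per-handle
    // sequence number (illustrative stand-ins for the real IDs).
    struct ActorTask {
      std::string handle_id;
      int counter;
    };

    // Drop queued tasks whose counter is below the restored per-handle count,
    // i.e., tasks that already executed before the checkpoint was taken.
    void FilterExecutedTasks(std::deque<ActorTask> &queue,
                             const std::unordered_map<std::string, int> &counters) {
      for (auto it = queue.begin(); it != queue.end();) {
        auto restored = counters.find(it->handle_id);
        if (restored != counters.end() && it->counter < restored->second) {
          it = queue.erase(it);  // already executed before the checkpoint
        } else {
          ++it;
        }
      }
    }

    int main() {
      std::deque<ActorTask> queue = {{"h1", 3}, {"h1", 5}, {"h2", 0}};
      FilterExecutedTasks(queue, {{"h1", 5}});  // handle h1 has executed 5 tasks
      std::cout << queue.size() << " tasks remain\n";  // {h1,5} and {h2,0} remain
      return 0;
    }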
- */ -void handle_actor_task_submitted(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec); - -/** - * This function will be called when the local scheduler receives a notification - * about the creation of a new actor. This can be used by the scheduling - * algorithm to resubmit cached actor tasks. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor being created. - * @return Void. - */ -void handle_actor_creation_notification( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id); - -/** - * This function will be called when a task is assigned by the global scheduler - * for execution on this local scheduler. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param task Task that is assigned by the global scheduler. - * @return Void. - */ -void handle_task_scheduled(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec); - -/** - * This function will be called when an actor task is assigned by the global - * scheduler or by another local scheduler for execution on this local - * scheduler. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param task Task that is assigned by the global scheduler. - * @return Void. - */ -void handle_actor_task_scheduled(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - TaskExecutionSpec &execution_spec); - -/** - * This function is called if a new object becomes available in the local - * plasma store. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param object_id ID of the object that became available. - * @return Void. - */ -void handle_object_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ObjectID object_id); - -/** - * This function is called if an object is removed from the local plasma store. - * - * @param state The state of the local scheduler. - * @param object_id ID of the object that was removed. - * @return Void. - */ -void handle_object_removed(LocalSchedulerState *state, ObjectID object_id); - -/** - * This function is called when a new worker becomes available. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is available. - * @return Void. - */ -void handle_worker_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This function is called when a worker is removed. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is removed. - * @return Void. - */ -void handle_worker_removed(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This version of handle_worker_available is called whenever the worker that is - * available is running an actor. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is available. 
- * @return Void. - */ -void handle_actor_worker_available(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * Handle the fact that a new worker is available for running an actor. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor running on the worker. - * @param initial_execution_dependency The dummy object ID of the actor - * creation task. - * @param worker The worker that was converted to an actor. - * @return Void. - */ -void handle_convert_worker_to_actor( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - const ActorID &actor_id, - const ObjectID &initial_execution_dependency, - LocalSchedulerClient *worker); - -/** - * Handle the fact that a worker running an actor has disconnected. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that was disconnected. - * @param cleanup Whether the disconnect was during cleanup. - * @return Void. - */ -void handle_actor_worker_disconnect(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker, - bool cleanup); - -/** - * This function is called when a worker that was executing a task becomes - * blocked on an object that isn't available locally yet. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is blocked. - * @return Void. - */ -void handle_worker_blocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This function is called when an actor that was executing a task becomes - * blocked on an object that isn't available locally yet. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is blocked. - * @return Void. - */ -void handle_actor_worker_blocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This function is called when a worker that was blocked on an object that - * wasn't available locally yet becomes unblocked. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is now unblocked. - * @return Void. - */ -void handle_worker_unblocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * This function is called when an actor that was blocked on an object that - * wasn't available locally yet becomes unblocked. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. - * @param worker The worker that is now unblocked. - * @return Void. - */ -void handle_actor_worker_unblocked(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - LocalSchedulerClient *worker); - -/** - * Process the fact that a driver has been removed. This will remove all of the - * tasks for that driver from the scheduling algorithm's internal data - * structures. - * - * @param state The state of the local scheduler. - * @param algorithm_state State maintained by the scheduling algorithm. 
- * @param driver_id The ID of the driver that was removed. - * @return Void. - */ -void handle_driver_removed(LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - WorkerID driver_id); - -/** - * This function fetches queued tasks' missing object dependencies. It is - * called every local_scheduler_fetch_timeout_milliseconds. - * - * @param loop The local scheduler's event loop. - * @param id The ID of the timer that triggers this function. - * @param context The function's context. - * @return An integer representing the time interval in seconds before the - * next invocation of the function. - */ -int fetch_object_timeout_handler(event_loop *loop, timer_id id, void *context); - -/** - * This function initiates reconstruction for tasks' missing object - * dependencies. It is called every - * local_scheduler_reconstruction_timeout_milliseconds, but it may not initiate - * reconstruction for every missing object. - * - * @param loop The local scheduler's event loop. - * @param id The ID of the timer that triggers this function. - * @param context The function's context. - * @return An integer representing the time interval in seconds before the - * next invocation of the function. - */ -int reconstruct_object_timeout_handler(event_loop *loop, - timer_id id, - void *context); - -/// This function initiates reconstruction for the actor creation tasks -/// corresponding to the actor tasks cached in the local scheduler. -/// -/// \param loop The local scheduler's event loop. -/// \param id The ID of the timer that triggers this function. -/// \param context The function's context. -/// \return An integer representing the time interval in seconds before the -/// next invocation of the function. -int rerun_actor_creation_tasks_timeout_handler(event_loop *loop, - timer_id id, - void *context); - -/** - * Check whether an object, including actor dummy objects, is locally - * available. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param object_id The ID of the object to check for. - * @return A bool representing whether the object is locally available. - */ -bool object_locally_available(SchedulingAlgorithmState *algorithm_state, - ObjectID object_id); - -/// Spill some tasks back to the global scheduler. This function implements the -/// spillback policy. -/// -/// @param state The scheduler state. -/// @return Void. -void spillback_tasks_handler(LocalSchedulerState *state); - -/** - * A helper function to print debug information about the current state and - * number of workers. - * - * @param message A message to identify the log message. - * @param algorithm_state State maintained by the scheduling algorithm. - * @return Void. - */ -void print_worker_info(const char *message, - SchedulingAlgorithmState *algorithm_state); - -/* - * The actor frontier consists of the number of tasks executed so far and the - * execution dependencies required by the current runnable tasks, according to - * the actor's local scheduler. Since an actor may have multiple handles, the - * tasks submitted to the actor form a DAG, where nodes are tasks and edges are - * execution dependencies. The frontier is a cut across this DAG. The number of - * tasks so far is the number of nodes included in the DAG root's partition. - * - * The actor gets the current frontier of tasks from the local scheduler during - * a checkpoint save, so that it can save the point in the actor's lifetime at - * which the checkpoint was taken.
If the actor later resumes from that - * checkpoint, the actor can set the current frontier of tasks in the local - * scheduler so that the same frontier of tasks can be made runnable again - * during reconstruction, and so that we do not duplicate execution of tasks - * that already executed before the checkpoint. - */ - -/** - * Get the number of tasks, per actor handle, that have been executed on an - * actor so far. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor whose task counters are returned. - * @return A map from handle ID to the number of tasks submitted by that handle - * that have executed so far. - */ -std::unordered_map<ActorHandleID, int64_t> get_actor_task_counters( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id); - -/** - * Set the number of tasks, per actor handle, that have been executed on an - * actor so far. All previous counts will be overwritten. Tasks that are - * waiting or runnable on the local scheduler that have a lower task count will - * be discarded, so that we don't duplicate execution. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor whose task counters are returned. - * @param task_counters A map from handle ID to the number of tasks submitted - * by that handle that have executed so far. - * @return Void. - */ -void set_actor_task_counters( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id, - const std::unordered_map<ActorHandleID, int64_t> &task_counters); - -/** - * Get the actor's frontier of task dependencies. - * NOTE(swang): The returned frontier only includes handles known by the local - * scheduler. It does not include handles for which the local scheduler has not - * seen a runnable task yet. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor whose task counters are returned. - * @return A map from handle ID to execution dependency for the earliest - * runnable task submitted through that handle. - */ -std::unordered_map<ActorHandleID, ObjectID> get_actor_frontier( - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id); - -/** - * Set the actor's frontier of task dependencies. The previous frontier will be - * overwritten. Any tasks that have an execution dependency on the new frontier - * (and that have all other dependencies fulfilled) will become runnable. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @param actor_id The ID of the actor whose task counters are returned. - * @param frontier_dependencies A map from handle ID to execution dependency - * for the earliest runnable task submitted through that handle. - * @return Void. - */ -void set_actor_frontier( - LocalSchedulerState *state, - SchedulingAlgorithmState *algorithm_state, - ActorID actor_id, - const std::unordered_map<ActorHandleID, ObjectID> &frontier_dependencies); - -/** The following methods are for testing purposes only. */ -#ifdef LOCAL_SCHEDULER_TEST -/** - * Get the number of tasks currently waiting for object dependencies to become - * available locally. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @return The number of tasks queued. - */ -int num_waiting_tasks(SchedulingAlgorithmState *algorithm_state); - -/** - * Get the number of tasks currently waiting for a worker to become available. - * - * @param algorithm_state State maintained by the scheduling algorithm. - * @return The number of tasks queued.
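As a companion to the frontier discussion above: restoring a saved frontier amounts to marking each handle's execution dependency locally available, which is what makes the first post-checkpoint task per handle runnable again. A sketch of that restore step, with string IDs standing in for ActorHandleID and ObjectID:

.. code-block:: cpp

    #include <iostream>
    #include <set>
    #include <string>
    #include <unordered_map>

    // Illustrative stand-in: a frontier maps each actor handle to the dummy
    // object that the next task submitted through that handle depends on.
    using Frontier = std::unordered_map<std::string, std::string>;

    // Mirrors the restore side of set_actor_frontier(): any frontier
    // dependency that is not yet locally available is marked available
    // (the analogue of calling handle_object_available above).
    void RestoreFrontier(const Frontier &frontier,
                         std::set<std::string> &local_objects) {
      for (const auto &dep : frontier) {
        if (local_objects.count(dep.second) == 0) {
          local_objects.insert(dep.second);
        }
      }
    }

    int main() {
      Frontier saved = {{"handle_0", "dummy_obj_12"}, {"handle_1", "dummy_obj_4"}};
      std::set<std::string> local_objects = {"dummy_obj_4"};
      RestoreFrontier(saved, local_objects);
      std::cout << local_objects.size() << " objects locally available\n";  // 2
      return 0;
    }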
- */ -int num_dispatch_tasks(SchedulingAlgorithmState *algorithm_state); -#endif - -#endif /* LOCAL_SCHEDULER_ALGORITHM_H */ diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc deleted file mode 100644 index 09bda7f5bd8d..000000000000 --- a/src/local_scheduler/local_scheduler_client.cc +++ /dev/null @@ -1,385 +0,0 @@ -#include "local_scheduler_client.h" - -#include "common_protocol.h" -#include "format/local_scheduler_generated.h" -#include "ray/raylet/format/node_manager_generated.h" - -#include "common/io.h" -#include "common/task.h" -#include -#include -#include - -using MessageType = ray::local_scheduler::protocol::MessageType; - -LocalSchedulerConnection *LocalSchedulerConnection_init( - const char *local_scheduler_socket, - const UniqueID &client_id, - bool is_worker, - const JobID &driver_id, - bool use_raylet, - const Language &language) { - LocalSchedulerConnection *result = new LocalSchedulerConnection(); - result->use_raylet = use_raylet; - result->conn = connect_ipc_sock_retry(local_scheduler_socket, -1, -1); - - /* Register with the local scheduler. - * NOTE(swang): If the local scheduler exits and we are registered as a - * worker, we will get killed. */ - flatbuffers::FlatBufferBuilder fbb; - if (use_raylet) { - auto message = ray::protocol::CreateRegisterClientRequest( - fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), - to_flatbuf(fbb, driver_id), language); - fbb.Finish(message); - } else { - auto message = ray::local_scheduler::protocol::CreateRegisterClientRequest( - fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), - to_flatbuf(fbb, driver_id)); - fbb.Finish(message); - } - /* Register the process ID with the local scheduler. */ - int success = write_message( - result->conn, static_cast<int64_t>(MessageType::RegisterClientRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &result->write_mutex); - RAY_CHECK(success == 0) << "Unable to register worker with local scheduler"; - - return result; -} - -void LocalSchedulerConnection_free(LocalSchedulerConnection *conn) { - close(conn->conn); - delete conn; -} - -void local_scheduler_disconnect_client(LocalSchedulerConnection *conn) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::local_scheduler::protocol::CreateDisconnectClient(fbb); - fbb.Finish(message); - if (conn->use_raylet) { - write_message(conn->conn, static_cast<int64_t>( - MessageType::IntentionalDisconnectClient), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - } else { - write_message(conn->conn, - static_cast<int64_t>(MessageType::DisconnectClient), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - } -} - -void local_scheduler_log_event(LocalSchedulerConnection *conn, - uint8_t *key, - int64_t key_length, - uint8_t *value, - int64_t value_length, - double timestamp) { - flatbuffers::FlatBufferBuilder fbb; - auto key_string = fbb.CreateString((char *) key, key_length); - auto value_string = fbb.CreateString((char *) value, value_length); - auto message = ray::local_scheduler::protocol::CreateEventLogMessage( - fbb, key_string, value_string, timestamp); - fbb.Finish(message); - write_message(conn->conn, static_cast<int64_t>(MessageType::EventLogMessage), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_submit(LocalSchedulerConnection *conn, - const TaskExecutionSpec &execution_spec) { - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies = - to_flatbuf(fbb, execution_spec.ExecutionDependencies()); - auto task_spec = -
fbb.CreateString(reinterpret_cast<const char *>(execution_spec.Spec()), - execution_spec.SpecSize()); - auto message = ray::local_scheduler::protocol::CreateSubmitTaskRequest( - fbb, execution_dependencies, task_spec); - fbb.Finish(message); - write_message(conn->conn, static_cast<int64_t>(MessageType::SubmitTask), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_submit_raylet( - LocalSchedulerConnection *conn, - const std::vector<ObjectID> &execution_dependencies, - const ray::raylet::TaskSpecification &task_spec) { - flatbuffers::FlatBufferBuilder fbb; - auto execution_dependencies_message = to_flatbuf(fbb, execution_dependencies); - auto message = ray::local_scheduler::protocol::CreateSubmitTaskRequest( - fbb, execution_dependencies_message, task_spec.ToFlatbuffer(fbb)); - fbb.Finish(message); - write_message(conn->conn, static_cast<int64_t>(MessageType::SubmitTask), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn, - int64_t *task_size) { - int64_t type; - int64_t reply_size; - uint8_t *reply; - { - std::unique_lock<std::mutex> guard(conn->mutex); - write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0, - NULL, &conn->write_mutex); - /* Receive a task from the local scheduler. This will block until the local - * scheduler gives this client a task. */ - read_message(conn->conn, &type, &reply_size, &reply); - } - if (type == static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT)) { - RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection."; - exit(1); - } - RAY_CHECK(static_cast<MessageType>(type) == MessageType::ExecuteTask); - - /* Parse the flatbuffer object. */ - auto reply_message = - flatbuffers::GetRoot<ray::local_scheduler::protocol::GetTaskReply>(reply); - - /* Create a copy of the task spec so we can free the reply. */ - *task_size = reply_message->task_spec()->size(); - TaskSpec *data = (TaskSpec *) reply_message->task_spec()->data(); - TaskSpec *spec = TaskSpec_copy(data, *task_size); - - // Set the GPU IDs for this task. We only do this for non-actor tasks because - // for actors the GPUs are associated with the actor itself and not with the - // actor methods. Note that this also processes GPUs for actor creation tasks. - if (!TaskSpec_is_actor_task(spec)) { - conn->gpu_ids.clear(); - for (size_t i = 0; i < reply_message->gpu_ids()->size(); ++i) { - conn->gpu_ids.push_back(reply_message->gpu_ids()->Get(i)); - } - } - - /* Free the original message from the local scheduler. */ - free(reply); - /* Return the copy of the task spec and pass ownership to the caller. */ - return spec; -} - -// This is temporarily duplicated from local_scheduler_get_task while we have -// the raylet and non-raylet code paths. -TaskSpec *local_scheduler_get_task_raylet(LocalSchedulerConnection *conn, - int64_t *task_size) { - int64_t type; - int64_t reply_size; - uint8_t *reply; - { - std::unique_lock<std::mutex> guard(conn->mutex); - write_message(conn->conn, static_cast<int64_t>(MessageType::GetTask), 0, - NULL, &conn->write_mutex); - // Receive a task from the local scheduler. This will block until the local - // scheduler gives this client a task. - read_message(conn->conn, &type, &reply_size, &reply); - } - if (type == static_cast<int64_t>(CommonMessageType::DISCONNECT_CLIENT)) { - RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection."; - exit(1); - } - RAY_CHECK(type == static_cast<int64_t>(MessageType::ExecuteTask)); - - // Parse the flatbuffer object. - auto reply_message = flatbuffers::GetRoot<ray::protocol::GetTaskReply>(reply); - - // Create a copy of the task spec so we can free the reply.
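A usage sketch for the calls above: a worker registers once, then alternates between blocking in local_scheduler_get_task and reporting completion. The socket path, UniqueID::from_random(), JobID::nil(), Language::PYTHON, and TaskSpec_free are assumptions about the legacy headers, not code taken from this patch:

.. code-block:: cpp

    #include "local_scheduler_client.h"

    #include <cstdint>

    // Placeholder for the application's task execution logic.
    static void execute_task(TaskSpec *spec, int64_t task_size) {
      (void) spec;
      (void) task_size;
    }

    int main() {
      LocalSchedulerConnection *conn = LocalSchedulerConnection_init(
          "/tmp/scheduler_socket",   // hypothetical socket path
          UniqueID::from_random(),   // assumed helper for a fresh worker ID
          /*is_worker=*/true,
          JobID::nil(),              // assumed: workers pass a nil driver ID
          /*use_raylet=*/false,      // legacy (non-raylet) code path
          Language::PYTHON);         // assumed enum value

      // local_scheduler_get_task() exits the process when the scheduler closes
      // the connection (see its implementation above), so loop indefinitely.
      while (true) {
        int64_t task_size = 0;
        TaskSpec *spec = local_scheduler_get_task(conn, &task_size);
        execute_task(spec, task_size);
        local_scheduler_task_done(conn);  // ask to be marked available again
        TaskSpec_free(spec);              // assumed counterpart of TaskSpec_copy
      }
    }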
- *task_size = reply_message->task_spec()->size(); - const TaskSpec *data = - reinterpret_cast<const TaskSpec *>(reply_message->task_spec()->data()); - TaskSpec *spec = TaskSpec_copy(const_cast<TaskSpec *>(data), *task_size); - - // Set the resource IDs for this task. - conn->resource_ids_.clear(); - for (size_t i = 0; i < reply_message->fractional_resource_ids()->size(); - ++i) { - auto const &fractional_resource_ids = - reply_message->fractional_resource_ids()->Get(i); - auto &acquired_resources = conn->resource_ids_[string_from_flatbuf( - *fractional_resource_ids->resource_name())]; - - size_t num_resource_ids = fractional_resource_ids->resource_ids()->size(); - size_t num_resource_fractions = - fractional_resource_ids->resource_fractions()->size(); - RAY_CHECK(num_resource_ids == num_resource_fractions); - RAY_CHECK(num_resource_ids > 0); - for (size_t j = 0; j < num_resource_ids; ++j) { - int64_t resource_id = fractional_resource_ids->resource_ids()->Get(j); - double resource_fraction = - fractional_resource_ids->resource_fractions()->Get(j); - if (num_resource_ids > 1) { - int64_t whole_fraction = resource_fraction; - RAY_CHECK(whole_fraction == resource_fraction); - } - acquired_resources.push_back( - std::make_pair(resource_id, resource_fraction)); - } - } - - // Free the original message from the local scheduler. - free(reply); - // Return the copy of the task spec and pass ownership to the caller. - return spec; -} - -void local_scheduler_task_done(LocalSchedulerConnection *conn) { - write_message(conn->conn, static_cast<int64_t>(MessageType::TaskDone), 0, - NULL, &conn->write_mutex); -} - -void local_scheduler_reconstruct_objects( - LocalSchedulerConnection *conn, - const std::vector<ObjectID> &object_ids, - bool fetch_only) { - flatbuffers::FlatBufferBuilder fbb; - auto object_ids_message = to_flatbuf(fbb, object_ids); - auto message = ray::local_scheduler::protocol::CreateReconstructObjects( - fbb, object_ids_message, fetch_only); - fbb.Finish(message); - write_message(conn->conn, - static_cast<int64_t>(MessageType::ReconstructObjects), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - /* TODO(swang): Propagate the error.
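Every call in this file shares one wire format, inferred from the write_message/read_message call sites: a 64-bit message type, a 64-bit payload length, then the flatbuffer bytes, written under write_mutex. A toy re-implementation of that framing over a raw file descriptor follows; the real helpers live in common/io and may differ in detail:

.. code-block:: cpp

    #include <cstdint>
    #include <mutex>
    #include <unistd.h>
    #include <vector>

    // Assumed frame layout, inferred from the call sites above:
    //   [int64 type][int64 length][length bytes of flatbuffer payload]
    static bool write_all(int fd, const void *buf, size_t n) {
      const uint8_t *p = static_cast<const uint8_t *>(buf);
      while (n > 0) {
        ssize_t w = write(fd, p, n);
        if (w <= 0) return false;
        p += w;
        n -= static_cast<size_t>(w);
      }
      return true;
    }

    int WriteFramedMessage(int fd, int64_t type, int64_t length,
                           const uint8_t *payload, std::mutex *write_mutex) {
      std::lock_guard<std::mutex> guard(*write_mutex);  // one frame at a time
      if (!write_all(fd, &type, sizeof(type))) return -1;
      if (!write_all(fd, &length, sizeof(length))) return -1;
      if (length > 0 && !write_all(fd, payload, static_cast<size_t>(length)))
        return -1;
      return 0;  // mirrors the 0-on-success convention checked via RAY_CHECK
    }

    int main() {
      int fds[2];
      if (pipe(fds) != 0) return 1;
      std::mutex m;
      std::vector<uint8_t> payload = {1, 2, 3};
      int rc = WriteFramedMessage(fds[1], /*type=*/7,
                                  static_cast<int64_t>(payload.size()),
                                  payload.data(), &m);
      close(fds[0]);
      close(fds[1]);
      return rc;
    }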
*/ -} - -void local_scheduler_log_message(LocalSchedulerConnection *conn) { - write_message(conn->conn, static_cast<int64_t>(MessageType::EventLogMessage), - 0, NULL, &conn->write_mutex); -} - -void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn) { - write_message(conn->conn, static_cast<int64_t>(MessageType::NotifyUnblocked), - 0, NULL, &conn->write_mutex); -} - -void local_scheduler_put_object(LocalSchedulerConnection *conn, - TaskID task_id, - ObjectID object_id) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::local_scheduler::protocol::CreatePutObject( - fbb, to_flatbuf(fbb, task_id), to_flatbuf(fbb, object_id)); - fbb.Finish(message); - - write_message(conn->conn, static_cast<int64_t>(MessageType::PutObject), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -const std::vector<uint8_t> local_scheduler_get_actor_frontier( - LocalSchedulerConnection *conn, - ActorID actor_id) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::local_scheduler::protocol::CreateGetActorFrontierRequest( - fbb, to_flatbuf(fbb, actor_id)); - fbb.Finish(message); - int64_t type; - std::vector<uint8_t> reply; - { - std::unique_lock<std::mutex> guard(conn->mutex); - write_message(conn->conn, - static_cast<int64_t>(MessageType::GetActorFrontierRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - - read_vector(conn->conn, &type, reply); - } - if (static_cast<CommonMessageType>(type) == - CommonMessageType::DISCONNECT_CLIENT) { - RAY_LOG(DEBUG) << "Exiting because local scheduler closed connection."; - exit(1); - } - RAY_CHECK(static_cast<MessageType>(type) == - MessageType::GetActorFrontierReply); - return reply; -} - -void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn, - const std::vector<uint8_t> &frontier) { - write_message(conn->conn, static_cast<int64_t>(MessageType::SetActorFrontier), - frontier.size(), const_cast<uint8_t *>(frontier.data()), - &conn->write_mutex); -} - -std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait( - LocalSchedulerConnection *conn, - const std::vector<ObjectID> &object_ids, - int num_returns, - int64_t timeout_milliseconds, - bool wait_local) { - // Write request. - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateWaitRequest( - fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, - wait_local); - fbb.Finish(message); - int64_t type; - int64_t reply_size; - uint8_t *reply; - { - std::unique_lock<std::mutex> guard(conn->mutex); - write_message(conn->conn, - static_cast<int64_t>(ray::protocol::MessageType::WaitRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - // Read result. - read_message(conn->conn, &type, &reply_size, &reply); - } - RAY_CHECK(static_cast<ray::protocol::MessageType>(type) == - ray::protocol::MessageType::WaitReply); - auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply); - // Convert result. - std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result; - auto found = reply_message->found(); - for (uint i = 0; i < found->size(); i++) { - ObjectID object_id = ObjectID::from_binary(found->Get(i)->str()); - result.first.push_back(object_id); - } - auto remaining = reply_message->remaining(); - for (uint i = 0; i < remaining->size(); i++) { - ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str()); - result.second.push_back(object_id); - } - /* Free the original message from the local scheduler.
*/ - free(reply); - return result; -} - -void local_scheduler_push_error(LocalSchedulerConnection *conn, - const JobID &job_id, - const std::string &type, - const std::string &error_message, - double timestamp) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreatePushErrorRequest( - fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), - fbb.CreateString(error_message), timestamp); - fbb.Finish(message); - - write_message(conn->conn, static_cast<int64_t>( - ray::protocol::MessageType::PushErrorRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_push_profile_events( - LocalSchedulerConnection *conn, - const ProfileTableDataT &profile_events) { - flatbuffers::FlatBufferBuilder fbb; - - auto message = CreateProfileTableData(fbb, &profile_events); - fbb.Finish(message); - - write_message(conn->conn, - static_cast<int64_t>( - ray::protocol::MessageType::PushProfileEventsRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); -} - -void local_scheduler_free_objects_in_object_store( - LocalSchedulerConnection *conn, - const std::vector<ObjectID> &object_ids, - bool local_only) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateFreeObjectsRequest( - fbb, local_only, to_flatbuf(fbb, object_ids)); - fbb.Finish(message); - - int success = write_message( - conn->conn, - static_cast<int64_t>( - ray::protocol::MessageType::FreeObjectsInObjectStoreRequest), - fbb.GetSize(), fbb.GetBufferPointer(), &conn->write_mutex); - RAY_CHECK(success == 0) << "Failed to write message to raylet."; -} diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h deleted file mode 100644 index bb4fdb345896..000000000000 --- a/src/local_scheduler/local_scheduler_client.h +++ /dev/null @@ -1,260 +0,0 @@ -#ifndef LOCAL_SCHEDULER_CLIENT_H -#define LOCAL_SCHEDULER_CLIENT_H - -#include - -#include "common/task.h" -#include "local_scheduler_shared.h" -#include "ray/raylet/task_spec.h" - -struct LocalSchedulerConnection { - /// True if we should use the raylet code path and false otherwise. - bool use_raylet; - /** File descriptor of the Unix domain socket that connects to local - * scheduler. */ - int conn; - /** The IDs of the GPUs that this client can use. NOTE(rkn): This is only used - * by legacy Ray and will be deprecated. */ - std::vector<int> gpu_ids; - /// A map from resource name to the resource IDs that are currently reserved - /// for this worker. Each pair consists of the resource ID and the fraction - /// of that resource allocated for this worker. - std::unordered_map<std::string, std::vector<std::pair<int64_t, double>>> - resource_ids_; - /// A mutex to protect stateful operations of the local scheduler client. - std::mutex mutex; - /// A mutex to protect write operations of the local scheduler client. - std::mutex write_mutex; -}; - -/** - * Connect to the local scheduler. - * - * @param local_scheduler_socket The name of the socket to use to connect to the - * local scheduler. - * @param worker_id A unique ID to represent the worker. - * @param is_worker Whether this client is a worker. If it is a worker, an - * additional message will be sent to register as one. - * @param driver_id The ID of the driver. This is non-nil if the client is a - * driver. - * @param use_raylet True if we should use the raylet code path and false - * otherwise. - * @return The connection information.
-
-/**
- * Disconnect from the local scheduler.
- *
- * @param conn Local scheduler connection information returned by
- *        LocalSchedulerConnection_init.
- * @return Void.
- */
-void LocalSchedulerConnection_free(LocalSchedulerConnection *conn);
-
-/**
- * Submit a task to the local scheduler.
- *
- * @param conn The connection information.
- * @param execution_spec The execution spec for the task to submit.
- * @return Void.
- */
-void local_scheduler_submit(LocalSchedulerConnection *conn,
-                            const TaskExecutionSpec &execution_spec);
-
-/// Submit a task using the raylet code path.
-///
-/// \param conn The connection information.
-/// \param execution_dependencies The execution dependencies.
-/// \param task_spec The task specification.
-/// \return Void.
-void local_scheduler_submit_raylet(
-    LocalSchedulerConnection *conn,
-    const std::vector<ObjectID> &execution_dependencies,
-    const ray::raylet::TaskSpecification &task_spec);
-
-/**
- * Notify the local scheduler that this client is disconnecting gracefully.
- * This is used by actors to exit gracefully so that the local scheduler
- * doesn't propagate an error message to the driver.
- *
- * @param conn The connection information.
- * @return Void.
- */
-void local_scheduler_disconnect_client(LocalSchedulerConnection *conn);
-
-/**
- * Log an event to the event log. This will call RPUSH key value. We use RPUSH
- * instead of SET so that it is possible to flush the log multiple times with
- * the same key (for example the key might be shared across logging calls in
- * the same task on a worker).
- *
- * @param conn The connection information.
- * @param key The key to store the event in.
- * @param key_length The length of the key.
- * @param value The value to store.
- * @param value_length The length of the value.
- * @param timestamp The time that the event is logged.
- * @return Void.
- */
-void local_scheduler_log_event(LocalSchedulerConnection *conn,
-                               uint8_t *key,
-                               int64_t key_length,
-                               uint8_t *value,
-                               int64_t value_length,
-                               double timestamp);
-
-/**
- * Get next task for this client. This will block until the scheduler assigns
- * a task to this worker. This allocates and returns a task, and so the task
- * must be freed by the caller.
- *
- * @todo When does this actually get freed?
- *
- * @param conn The connection information.
- * @param task_size A pointer to fill out with the task size.
- * @return The address of the assigned task.
- */
-TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn,
-                                   int64_t *task_size);
-
-/// Get next task for this client. This will block until the scheduler assigns
-/// a task to this worker. This allocates and returns a task, and so the task
-/// must be freed by the caller.
-///
-/// \param conn The connection information.
-/// \param task_size A pointer to fill out with the task size.
-/// \return The address of the assigned task.
-TaskSpec *local_scheduler_get_task_raylet(LocalSchedulerConnection *conn,
                                          int64_t *task_size);
-
-/**
- * Tell the local scheduler that the client has finished executing a task.
- *
- * @param conn The connection information.
- * @return Void.
- */
-void local_scheduler_task_done(LocalSchedulerConnection *conn);
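As the comment on local_scheduler_get_task notes, the returned TaskSpec is heap-allocated and ownership transfers to the caller, who must free() it (the tests below do exactly that). One way a caller could make that contract explicit is a unique_ptr with a free-based deleter; this is a usage sketch, not part of the deleted header (get_task_owned and TaskSpecPtr are illustrative names):

#include <cstdint>
#include <cstdlib>
#include <memory>

struct LocalSchedulerConnection;
struct TaskSpec;
TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn,
                                   int64_t *task_size);

// Owning wrapper: the deleter matches the malloc-based allocation.
using TaskSpecPtr = std::unique_ptr<TaskSpec, decltype(&std::free)>;

TaskSpecPtr get_task_owned(LocalSchedulerConnection *conn, int64_t *task_size) {
  return TaskSpecPtr(local_scheduler_get_task(conn, task_size), &std::free);
}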
-
-/**
- * Tell the local scheduler to reconstruct or fetch objects.
- *
- * @param conn The connection information.
- * @param object_ids The IDs of the objects to reconstruct.
- * @param fetch_only Only fetch objects, do not reconstruct them.
- * @return Void.
- */
-void local_scheduler_reconstruct_objects(
-    LocalSchedulerConnection *conn,
-    const std::vector<ObjectID> &object_ids,
-    bool fetch_only = false);
-
-/**
- * Send a log message to the local scheduler.
- *
- * @param conn The connection information.
- * @return Void.
- */
-void local_scheduler_log_message(LocalSchedulerConnection *conn);
-
-/**
- * Notify the local scheduler that this client (worker) is no longer blocked.
- *
- * @param conn The connection information.
- * @return Void.
- */
-void local_scheduler_notify_unblocked(LocalSchedulerConnection *conn);
-
-/**
- * Record the mapping from object ID to task ID for put events.
- *
- * @param conn The connection information.
- * @param task_id The ID of the task that called put.
- * @param object_id The ID of the object being stored.
- * @return Void.
- */
-void local_scheduler_put_object(LocalSchedulerConnection *conn,
-                                TaskID task_id,
-                                ObjectID object_id);
-
-/**
- * Get an actor's current task frontier.
- *
- * @param conn The connection information.
- * @param actor_id The ID of the actor whose frontier is returned.
- * @return A byte vector that can be traversed as an ActorFrontier flatbuffer.
- */
-const std::vector<uint8_t> local_scheduler_get_actor_frontier(
-    LocalSchedulerConnection *conn,
-    ActorID actor_id);
-
-/**
- * Set an actor's current task frontier.
- *
- * @param conn The connection information.
- * @param frontier An ActorFrontier flatbuffer to set the frontier to.
- * @return Void.
- */
-void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
-                                        const std::vector<uint8_t> &frontier);
-
-/// Wait for the given objects until timeout expires or num_return objects are
-/// found.
-///
-/// \param conn The connection information.
-/// \param object_ids The objects to wait for.
-/// \param num_returns The number of objects to wait for.
-/// \param timeout_milliseconds Duration, in milliseconds, to wait before
-///        returning.
-/// \param wait_local Whether to wait for objects to appear on this node.
-/// \return A pair with the first element containing the object ids that were
-///         found, and the second element the objects that were not found.
-std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
-    LocalSchedulerConnection *conn,
-    const std::vector<ObjectID> &object_ids,
-    int num_returns,
-    int64_t timeout_milliseconds,
-    bool wait_local);
-
-/// Push an error to the relevant driver.
-///
-/// \param conn The connection information.
-/// \param job_id The ID of the job that the error is for.
-/// \param type The type of the error.
-/// \param error_message The error message.
-/// \param timestamp The timestamp of the error.
-/// \return Void.
-void local_scheduler_push_error(LocalSchedulerConnection *conn,
-                                const JobID &job_id,
-                                const std::string &type,
-                                const std::string &error_message,
-                                double timestamp);
-
-/// Store some profile events in the GCS.
-///
-/// \param conn The connection information.
-/// \param profile_events A batch of profiling event information.
-/// \return Void.
-void local_scheduler_push_profile_events(
-    LocalSchedulerConnection *conn,
-    const ProfileTableDataT &profile_events);
-
-/// Free a list of objects from object stores.
-///
-/// \param conn The connection information.
-/// \param object_ids A list of ObjectIDs to be deleted.
-/// \param local_only Whether to keep this request within the local object
-///        store or to send it to all object stores.
-/// \return Void.
-void local_scheduler_free_objects_in_object_store(
-    LocalSchedulerConnection *conn,
-    const std::vector<ObjectID> &object_ids,
-    bool local_only);
-
-#endif
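Taken together, this header supports a simple worker loop: block in get_task, run the task, free the spec, and report task_done. A schematic of that loop with the task execution elided (worker_loop is an illustrative name and error handling is omitted):

#include <cstdint>
#include <cstdlib>

struct LocalSchedulerConnection;
struct TaskSpec;
TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn,
                                   int64_t *task_size);
void local_scheduler_task_done(LocalSchedulerConnection *conn);

void worker_loop(LocalSchedulerConnection *conn) {
  for (;;) {
    int64_t task_size = 0;
    // Blocks until the scheduler assigns a task; ownership transfers to us.
    TaskSpec *spec = local_scheduler_get_task(conn, &task_size);
    // ... deserialize and execute the task here ...
    std::free(spec);                  // get_task allocates; the caller frees.
    local_scheduler_task_done(conn);  // Tell the scheduler we can take more.
  }
}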
diff --git a/src/local_scheduler/local_scheduler_shared.h b/src/local_scheduler/local_scheduler_shared.h
deleted file mode 100644
index 572f14a6fdf7..000000000000
--- a/src/local_scheduler/local_scheduler_shared.h
+++ /dev/null
@@ -1,137 +0,0 @@
-#ifndef LOCAL_SCHEDULER_SHARED_H
-#define LOCAL_SCHEDULER_SHARED_H
-
-#include "common/task.h"
-#include "common/state/table.h"
-#include "common/state/db.h"
-#include "plasma/client.h"
-#include "ray/gcs/client.h"
-
-#include <list>
-#include <unordered_map>
-#include <unordered_set>
-#include <vector>
-
-/** This struct is used to maintain a mapping from actor IDs to the ID of the
- * local scheduler that is responsible for the actor. */
-struct ActorMapEntry {
-  /** The ID of the driver that created the actor. */
-  WorkerID driver_id;
-  /** The ID of the local scheduler that is responsible for the actor. */
-  DBClientID local_scheduler_id;
-};
-
-/** Internal state of the scheduling algorithm. */
-typedef struct SchedulingAlgorithmState SchedulingAlgorithmState;
-
-struct LocalSchedulerClient;
-
-/** A struct storing the configuration state of the local scheduler. This
- * should consist of values that don't change over the lifetime of the local
- * scheduler. */
-typedef struct {
-  /** The script to use when starting a new worker. */
-  const char **start_worker_command;
-  /** Whether there is a global scheduler. */
-  bool global_scheduler_exists;
-} local_scheduler_config;
-
-/** The state of the local scheduler. */
-struct LocalSchedulerState {
-  /** The configuration for the local scheduler. */
-  local_scheduler_config config;
-  /** The local scheduler event loop. */
-  event_loop *loop;
-  /** List of workers available to this node. This is used to free the worker
-   * structs when we free the scheduler state and also to access the worker
-   * structs in the tests. */
-  std::list<LocalSchedulerClient *> workers;
-  /** A set of driver IDs corresponding to drivers that have been removed. This
-   * is used to make sure we don't execute any tasks belonging to dead
-   * drivers. */
-  std::unordered_set<WorkerID> removed_drivers;
-  /** A set of actor IDs corresponding to local actors that have been removed.
-   * This ensures we can reject any tasks destined for dead actors. */
-  std::unordered_set<ActorID> removed_actors;
-  /** List of the process IDs for child processes (workers) started by the
-   * local scheduler that have not sent a REGISTER_PID message yet. */
-  std::vector<pid_t> child_pids;
-  /** A hash table mapping actor IDs to the db_client_id of the local scheduler
-   * that is responsible for the actor. */
-  std::unordered_map<ActorID, ActorMapEntry> actor_mapping;
-  /** The handle to the database. */
-  DBHandle *db;
-  /** The Plasma client. */
-  plasma::PlasmaClient *plasma_conn;
-  /** State for the scheduling algorithm. */
-  SchedulingAlgorithmState *algorithm_state;
-  /** Input buffer, used for reading input in process_message to avoid
-   * allocation for each call to process_message. */
-  std::vector<uint8_t> input_buffer;
-  /** Map of static attributes associated with the node owned by this local
-   * scheduler. */
-  std::unordered_map<std::string, double> static_resources;
-  /** Map of dynamic attributes associated with the node owned by this local
-   * scheduler. */
-  std::unordered_map<std::string, double> dynamic_resources;
-  /** The IDs of the available GPUs. There is redundancy here in that
-   * available_gpus.size() == dynamic_resources[ResourceIndex_GPU] should
-   * always be true. */
-  std::vector<int> available_gpus;
-  /** The time (in milliseconds since the Unix epoch) when the most recent
-   * heartbeat was sent. */
-  int64_t previous_heartbeat_time;
-};
-
-/** Contains all information associated with a local scheduler client. */
-struct LocalSchedulerClient {
-  /** The socket used to communicate with the client. */
-  int sock;
-  /** True if the client has registered and false otherwise. */
-  bool registered;
-  /** True if the client has sent a disconnect message to the local scheduler
-   * and false otherwise. If this is true, then the local scheduler will not
-   * propagate an error message to the driver when the client exits. */
-  bool disconnected;
-  /** True if the client is a worker and false if it is a driver. */
-  bool is_worker;
-  /** The worker ID if the client is a worker and the driver ID if the client
-   * is a driver. */
-  WorkerID client_id;
-  /** A pointer to the task object that is currently running on this client.
-   * If no task is running on the worker, this will be NULL. This is used to
-   * update the task table. */
-  Task *task_in_progress;
-  /** A map of resource counts currently in use by the worker. */
-  std::unordered_map<std::string, double> resources_in_use;
-  /** A vector of the IDs of the GPUs that the worker is currently using. If
-   * the worker is an actor, this will be constant throughout the lifetime of
-   * the actor (and will be equal to the number of GPUs requested by the
-   * actor). If the worker is not an actor, this will be constant for the
-   * duration of a task and will have length equal to the number of GPUs
-   * requested by the task (in particular it will not change if the task
-   * blocks). */
-  std::vector<int> gpus_in_use;
-  /** A flag to indicate whether this worker is currently blocking on an
-   * object(s) that isn't available locally yet. */
-  bool is_blocked;
-  /** The process ID of the client. If this is set to zero, the client has not
-   * yet registered a process ID. */
-  pid_t pid;
-  /** Whether the client is a child process of the local scheduler. */
-  bool is_child;
-  /** The ID of the actor on this worker. If there is no actor running on this
-   * worker, this should be NIL_ACTOR_ID. */
-  ActorID actor_id;
-  /** A pointer to the local scheduler state. */
-  LocalSchedulerState *local_scheduler_state;
-};
-
-/**
- * Free the local scheduler state. This disconnects all clients and notifies
- * the global scheduler of the local scheduler's exit.
- *
- * @param state The state to free.
- * @return Void.
- */
-void LocalSchedulerState_free(LocalSchedulerState *state);
-
-#endif /* LOCAL_SCHEDULER_SHARED_H */
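The comment on available_gpus above documents an invariant: the number of free GPU IDs should always equal the dynamic GPU resource count. A debug-build check in that spirit could look like the following (check_gpu_invariant is an illustrative name; since the struct keys resources by name, the "GPU" string is used here):

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

void check_gpu_invariant(
    const std::vector<int> &available_gpus,
    const std::unordered_map<std::string, double> &dynamic_resources) {
  auto it = dynamic_resources.find("GPU");
  const double gpu_count = (it == dynamic_resources.end()) ? 0.0 : it->second;
  // Holds whenever GPU bookkeeping is consistent.
  assert(static_cast<double>(available_gpus.size()) == gpu_count);
}

int main() {
  check_gpu_invariant({0, 1}, {{"GPU", 2.0}});
  return 0;
}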
diff --git a/src/local_scheduler/test/local_scheduler_tests.cc b/src/local_scheduler/test/local_scheduler_tests.cc
deleted file mode 100644
index b155ea9494c8..000000000000
--- a/src/local_scheduler/test/local_scheduler_tests.cc
+++ /dev/null
@@ -1,704 +0,0 @@
-#include "greatest.h"
-
-#include <signal.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/wait.h>
-#include <unistd.h>
-
-#include <sstream>
-#include <string>
-#include <thread>
-#include <vector>
-
-#include "common.h"
-#include "test/test_common.h"
-#include "test/example_task.h"
-#include "event_loop.h"
-#include "io.h"
-#include "task.h"
-#include "state/object_table.h"
-#include "state/task_table.h"
-#include "state/redis.h"
-
-#include "local_scheduler_shared.h"
-#include "local_scheduler.h"
-#include "local_scheduler_algorithm.h"
-#include "local_scheduler_client.h"
-
-SUITE(local_scheduler_tests);
-
-TaskBuilder *g_task_builder = NULL;
-
-const char *plasma_store_socket_name = "/tmp/plasma_store_socket_1";
-const char *plasma_manager_socket_name_format = "/tmp/plasma_manager_socket_%d";
-const char *local_scheduler_socket_name_format =
-    "/tmp/local_scheduler_socket_%d";
-
-int64_t timeout_handler(event_loop *loop, int64_t id, void *context) {
-  event_loop_stop(loop);
-  return EVENT_LOOP_TIMER_DONE;
-}
-
-typedef struct {
-  /** A socket to mock the Plasma manager. Clients (such as workers) that
-   * connect to this file descriptor must be accepted. */
-  int plasma_manager_fd;
-  /** A socket to communicate with the Plasma store. */
-  int plasma_store_fd;
-  /** Local scheduler's socket for IPC requests. */
-  int local_scheduler_fd;
-  /** Local scheduler's local scheduler state. */
-  LocalSchedulerState *local_scheduler_state;
-  /** Local scheduler's event loop. */
-  event_loop *loop;
-  /** Number of local scheduler client connections, or mock workers. */
-  int num_local_scheduler_conns;
-  /** Local scheduler client connections. */
-  LocalSchedulerConnection **conns;
-} LocalSchedulerMock;
-
-/**
- * Register clients of the local scheduler. This function is started in a
- * separate thread to enable a blocking call to register the clients.
- */
-static void register_clients(int num_mock_workers, LocalSchedulerMock *mock) {
-  for (int i = 0; i < num_mock_workers; ++i) {
-    new_client_connection(mock->loop, mock->local_scheduler_fd,
-                          (void *) mock->local_scheduler_state, 0);
-    LocalSchedulerClient *worker = mock->local_scheduler_state->workers.back();
-    process_message(mock->local_scheduler_state->loop, worker->sock, worker,
-                    0);
-  }
-}
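register_clients runs on a second thread because both sides of the handshake block: the mock scheduler blocks while accepting and registering each client, and LocalSchedulerConnection_init blocks until the scheduler responds, so doing both on one thread could deadlock. The shape of that pattern in miniature (the two functions are placeholders for the calls above; this is a sketch, not the deleted code):

#include <iostream>
#include <thread>

void accept_side() {   // Stands in for new_client_connection + process_message.
  std::cout << "server: accept and register a client\n";
}

void connect_side() {  // Stands in for LocalSchedulerConnection_init.
  std::cout << "client: connect and wait for registration\n";
}

int main() {
  std::thread acceptor(accept_side);  // Would block until a client connects.
  connect_side();                     // Would block until the server accepts.
  acceptor.join();
  return 0;
}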
-
-LocalSchedulerMock *LocalSchedulerMock_init(int num_workers,
-                                            int num_mock_workers) {
-  const char *node_ip_address = "127.0.0.1";
-  const char *redis_addr = node_ip_address;
-  int redis_port = 6379;
-  std::unordered_map<std::string, double> static_resource_conf;
-  static_resource_conf["CPU"] = INT16_MAX;
-  static_resource_conf["GPU"] = 0;
-  LocalSchedulerMock *mock =
-      (LocalSchedulerMock *) malloc(sizeof(LocalSchedulerMock));
-  memset(mock, 0, sizeof(LocalSchedulerMock));
-  mock->loop = event_loop_create();
-  /* Bind to the local scheduler port and initialize the local scheduler. */
-  std::string plasma_manager_socket_name = bind_ipc_sock_retry(
-      plasma_manager_socket_name_format, &mock->plasma_manager_fd);
-  mock->plasma_store_fd =
-      connect_ipc_sock_retry(plasma_store_socket_name, 5, 100);
-  std::string local_scheduler_socket_name = bind_ipc_sock_retry(
-      local_scheduler_socket_name_format, &mock->local_scheduler_fd);
-  RAY_CHECK(mock->plasma_store_fd >= 0 && mock->local_scheduler_fd >= 0);
-
-  /* Construct worker command */
-  std::stringstream worker_command_ss;
-  worker_command_ss << "python ../python/ray/workers/default_worker.py"
-                    << " --node-ip-address=" << node_ip_address
-                    << " --object-store-name=" << plasma_store_socket_name
-                    << " --object-store-manager-name="
-                    << plasma_manager_socket_name
-                    << " --local-scheduler-name=" << local_scheduler_socket_name
-                    << " --redis-address=" << redis_addr << ":" << redis_port;
-  std::string worker_command = worker_command_ss.str();
-
-  mock->local_scheduler_state = LocalSchedulerState_init(
-      "127.0.0.1", mock->loop, redis_addr, redis_port,
-      local_scheduler_socket_name.c_str(), plasma_store_socket_name,
-      plasma_manager_socket_name.c_str(), NULL, false, static_resource_conf,
-      worker_command.c_str(), num_workers);
-
-  /* Accept the workers as clients to the plasma manager. */
-  for (int i = 0; i < num_workers; ++i) {
-    accept_client(mock->plasma_manager_fd);
-  }
-
-  /* Connect a local scheduler client. */
-  mock->num_local_scheduler_conns = num_mock_workers;
-  mock->conns = (LocalSchedulerConnection **) malloc(
-      sizeof(LocalSchedulerConnection *) * num_mock_workers);
-
-  std::thread background_thread =
-      std::thread(register_clients, num_mock_workers, mock);
-
-  for (int i = 0; i < num_mock_workers; ++i) {
-    mock->conns[i] = LocalSchedulerConnection_init(
-        local_scheduler_socket_name.c_str(), WorkerID::nil(), true,
-        JobID::nil(), false, Language::PYTHON);
-  }
-
-  background_thread.join();
-
-  return mock;
-}
-
-void LocalSchedulerMock_free(LocalSchedulerMock *mock) {
-  /* Disconnect clients. */
-  for (int i = 0; i < mock->num_local_scheduler_conns; ++i) {
-    LocalSchedulerConnection_free(mock->conns[i]);
-  }
-  free(mock->conns);
-
-  /* Kill all the workers and run the event loop again so that the task table
-   * updates propagate and the tasks in progress are freed. */
-  while (mock->local_scheduler_state->workers.size() > 0) {
-    LocalSchedulerClient *worker = mock->local_scheduler_state->workers.front();
-    kill_worker(mock->local_scheduler_state, worker, true, false);
-  }
-  event_loop_add_timer(mock->loop, 500,
-                       (event_loop_timer_handler) timeout_handler, NULL);
-  event_loop_run(mock->loop);
-
-  /* This also frees mock->loop. */
-  LocalSchedulerState_free(mock->local_scheduler_state);
-  close(mock->plasma_store_fd);
-  close(mock->plasma_manager_fd);
-  free(mock);
-}
-
-void reset_worker(LocalSchedulerMock *mock, LocalSchedulerClient *worker) {
-  if (worker->task_in_progress) {
-    Task_free(worker->task_in_progress);
-    worker->task_in_progress = NULL;
-  }
-}
-
-/**
- * Test that object reconstruction gets called. If a task gets submitted,
- * assigned to a worker, and then reconstruction is triggered for its return
- * value, the task should get assigned to a worker again.
- */
-TEST object_reconstruction_test(void) {
-  LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1);
-  LocalSchedulerConnection *worker = local_scheduler->conns[0];
-
-  /* Create a task with zero dependencies and one return value. */
-  TaskExecutionSpec execution_spec = example_task_execution_spec(0, 1);
-  TaskSpec *spec = execution_spec.Spec();
-  int64_t task_size = execution_spec.SpecSize();
-  ObjectID return_id = TaskSpec_return(spec, 0);
-
-  /* Add an empty object table entry for the object we want to reconstruct, to
-   * simulate it having been created and evicted. */
-  const char *client_id = "clientid";
-  /* Lookup the shard locations for the object table. */
-  std::vector<std::string> db_shards_addresses;
-  std::vector<int> db_shards_ports;
-  redisContext *context = redisConnect("127.0.0.1", 6379);
-  get_redis_shards(context, db_shards_addresses, db_shards_ports);
-  redisFree(context);
-  /* There should only be one shard, so we can safely add the empty object
-   * table entry to the first one. */
-  ASSERT(db_shards_addresses.size() == 1);
-  context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]);
-  redisReply *reply = (redisReply *) redisCommand(
-      context, "RAY.OBJECT_TABLE_ADD %b %ld %b %s", return_id.data(),
-      sizeof(return_id), 1, NIL_DIGEST, (size_t) DIGEST_SIZE, client_id);
-  freeReplyObject(reply);
-  reply = (redisReply *) redisCommand(context, "RAY.OBJECT_TABLE_REMOVE %b %s",
-                                      return_id.data(), sizeof(return_id),
-                                      client_id);
-  freeReplyObject(reply);
-  redisFree(context);
-
-  pid_t pid = fork();
-  if (pid == 0) {
-    /* Make sure we receive the task twice. First from the initial submission,
-     * and second from the reconstruct request. */
-    int64_t task_assigned_size;
-    local_scheduler_submit(worker, execution_spec);
-    TaskSpec *task_assigned =
-        local_scheduler_get_task(worker, &task_assigned_size);
-    ASSERT_EQ(memcmp(task_assigned, spec, task_size), 0);
-    ASSERT_EQ(task_assigned_size, task_size);
-    int64_t reconstruct_task_size;
-    TaskSpec *reconstruct_task =
-        local_scheduler_get_task(worker, &reconstruct_task_size);
-    ASSERT_EQ(memcmp(reconstruct_task, spec, task_size), 0);
-    ASSERT_EQ(reconstruct_task_size, task_size);
-    /* Clean up. */
-    free(reconstruct_task);
-    free(task_assigned);
-    LocalSchedulerMock_free(local_scheduler);
-    exit(0);
-  } else {
-    /* Run the event loop. NOTE: OSX appears to require the parent process to
-     * listen for events on the open file descriptors. */
-    event_loop_add_timer(local_scheduler->loop, 500,
-                         (event_loop_timer_handler) timeout_handler, NULL);
-    event_loop_run(local_scheduler->loop);
-    /* Set the task's status to TaskStatus::DONE to prevent the race condition
-     * that would suppress object reconstruction. */
-    Task *task = Task_alloc(
-        execution_spec, TaskStatus::DONE,
-        get_db_client_id(local_scheduler->local_scheduler_state->db));
-    task_table_add_task(local_scheduler->local_scheduler_state->db, task, NULL,
-                        NULL, NULL);
-
-    /* Trigger reconstruction, and run the event loop again. */
-    ObjectID return_id = TaskSpec_return(spec, 0);
-    local_scheduler_reconstruct_objects(worker,
-                                        std::vector<ObjectID>({return_id}));
-    event_loop_add_timer(local_scheduler->loop, 500,
-                         (event_loop_timer_handler) timeout_handler, NULL);
-    event_loop_run(local_scheduler->loop);
-    /* Wait for the child process to exit and check that there are no tasks
-     * left in the local scheduler's task queue. Then, clean up. */
-    wait(NULL);
-    ASSERT_EQ(num_waiting_tasks(
-                  local_scheduler->local_scheduler_state->algorithm_state),
-              0);
-    ASSERT_EQ(num_dispatch_tasks(
-                  local_scheduler->local_scheduler_state->algorithm_state),
-              0);
-    LocalSchedulerMock_free(local_scheduler);
-    PASS();
-  }
-}
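Each reconstruction test uses the same process structure: fork(), let the child act as the worker through the client API while the parent drives the scheduler's event loop, then reap the child and assert on the scheduler's queues. The bare skeleton of that pattern, with the Ray-specific calls elided:

#include <cstdlib>
#include <sys/wait.h>
#include <unistd.h>

int main() {
  pid_t pid = fork();
  if (pid == 0) {
    // Child: play the worker, e.g. submit tasks and block in get_task.
    exit(0);
  }
  // Parent: run the scheduler's event loop on the open descriptors,
  // then reap the child and assert that no tasks remain queued.
  wait(NULL);
  return 0;
}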
-
-/**
- * Test that object reconstruction gets recursively called. In a chain of
- * tasks, if all inputs are lost, then reconstruction of the final object
- * should trigger reconstruction of all previous tasks in the lineage.
- */
-TEST object_reconstruction_recursive_test(void) {
-  LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1);
-  LocalSchedulerConnection *worker = local_scheduler->conns[0];
-  /* Create a chain of tasks, each one dependent on the one before it. Mark
-   * each object as available so that tasks will run immediately. */
-  const int NUM_TASKS = 10;
-  std::vector<TaskExecutionSpec> specs;
-  specs.push_back(example_task_execution_spec(0, 1));
-  for (int i = 1; i < NUM_TASKS; ++i) {
-    ObjectID arg_id = TaskSpec_return(specs[i - 1].Spec(), 0);
-    specs.push_back(example_task_execution_spec_with_args(1, 1, &arg_id));
-  }
-  /* Lookup the shard locations for the object table. */
-  const char *client_id = "clientid";
-  std::vector<std::string> db_shards_addresses;
-  std::vector<int> db_shards_ports;
-  redisContext *context = redisConnect("127.0.0.1", 6379);
-  get_redis_shards(context, db_shards_addresses, db_shards_ports);
-  redisFree(context);
-  /* There should only be one shard, so we can safely add the empty object
-   * table entry to the first one. */
-  ASSERT(db_shards_addresses.size() == 1);
-  context = redisConnect(db_shards_addresses[0].c_str(), db_shards_ports[0]);
-  for (int i = 0; i < NUM_TASKS; ++i) {
-    ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0);
-    redisReply *reply = (redisReply *) redisCommand(
-        context, "RAY.OBJECT_TABLE_ADD %b %ld %b %s", return_id.data(),
-        sizeof(return_id), 1, NIL_DIGEST, (size_t) DIGEST_SIZE, client_id);
-    freeReplyObject(reply);
-    reply = (redisReply *) redisCommand(
-        context, "RAY.OBJECT_TABLE_REMOVE %b %s", return_id.data(),
-        sizeof(return_id), client_id);
-    freeReplyObject(reply);
-  }
-  redisFree(context);
-
-  pid_t pid = fork();
-  if (pid == 0) {
-    /* Submit the tasks, and make sure each one gets assigned to a worker. */
-    for (int i = 0; i < NUM_TASKS; ++i) {
-      local_scheduler_submit(worker, specs[i]);
-    }
-    /* Make sure we receive each task from the initial submission. */
-    for (int i = 0; i < NUM_TASKS; ++i) {
-      int64_t task_size;
-      TaskSpec *task_assigned = local_scheduler_get_task(worker, &task_size);
-      ASSERT_EQ(memcmp(task_assigned, specs[i].Spec(), specs[i].SpecSize()), 0);
-      ASSERT_EQ(task_size, specs[i].SpecSize());
-      free(task_assigned);
-    }
-    /* Check that the workers receive all tasks in the final return object's
-     * lineage during reconstruction. */
-    for (int i = 0; i < NUM_TASKS; ++i) {
-      int64_t task_assigned_size;
-      TaskSpec *task_assigned =
-          local_scheduler_get_task(worker, &task_assigned_size);
-      for (auto it = specs.begin(); it != specs.end(); it++) {
-        if (memcmp(task_assigned, it->Spec(), task_assigned_size) == 0) {
-          specs.erase(it);
-          break;
-        }
-      }
-      free(task_assigned);
-    }
-    ASSERT(specs.size() == 0);
-    LocalSchedulerMock_free(local_scheduler);
-    exit(0);
-  } else {
-    /* Simulate each task putting its return values in the object store so
-     * that the next task can run. */
-    for (int i = 0; i < NUM_TASKS; ++i) {
-      ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0);
-      handle_object_available(
-          local_scheduler->local_scheduler_state,
-          local_scheduler->local_scheduler_state->algorithm_state, return_id);
-    }
-    /* Run the event loop. All tasks should now be dispatched. NOTE: OSX
-     * appears to require the parent process to listen for events on the open
-     * file descriptors. */
-    event_loop_add_timer(local_scheduler->loop, 500,
-                         (event_loop_timer_handler) timeout_handler, NULL);
-    event_loop_run(local_scheduler->loop);
-    /* Set the final task's status to TaskStatus::DONE to prevent the race
-     * condition that would suppress object reconstruction. */
-    Task *last_task = Task_alloc(
-        specs[NUM_TASKS - 1], TaskStatus::DONE,
-        get_db_client_id(local_scheduler->local_scheduler_state->db));
-    task_table_add_task(local_scheduler->local_scheduler_state->db, last_task,
-                        NULL, NULL, NULL);
-    /* Simulate eviction of the objects, so that reconstruction is required. */
-    for (int i = 0; i < NUM_TASKS; ++i) {
-      ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0);
-      handle_object_removed(local_scheduler->local_scheduler_state, return_id);
-    }
-    /* Trigger reconstruction for the last object. */
-    ObjectID return_id = TaskSpec_return(specs[NUM_TASKS - 1].Spec(), 0);
-    local_scheduler_reconstruct_objects(worker,
-                                        std::vector<ObjectID>({return_id}));
-    /* Run the event loop again. All tasks should be resubmitted. */
-    event_loop_add_timer(local_scheduler->loop, 500,
-                         (event_loop_timer_handler) timeout_handler, NULL);
-    event_loop_run(local_scheduler->loop);
-    /* Simulate each task putting its return values in the object store so
-     * that the next task can run. */
-    for (int i = 0; i < NUM_TASKS; ++i) {
-      ObjectID return_id = TaskSpec_return(specs[i].Spec(), 0);
-      handle_object_available(
-          local_scheduler->local_scheduler_state,
-          local_scheduler->local_scheduler_state->algorithm_state, return_id);
-    }
-    /* Run the event loop again. All tasks should be dispatched again. */
-    event_loop_add_timer(local_scheduler->loop, 500,
-                         (event_loop_timer_handler) timeout_handler, NULL);
-    event_loop_run(local_scheduler->loop);
-    /* Wait for the child process to exit and check that there are no tasks
-     * left in the local scheduler's task queue. Then, clean up. */
-    wait(NULL);
-    ASSERT_EQ(num_waiting_tasks(
-                  local_scheduler->local_scheduler_state->algorithm_state),
-              0);
-    ASSERT_EQ(num_dispatch_tasks(
-                  local_scheduler->local_scheduler_state->algorithm_state),
-              0);
-    specs.clear();
-    LocalSchedulerMock_free(local_scheduler);
-    PASS();
-  }
-}
-
-/**
- * Test that object reconstruction gets suppressed when there is a location
- * listed for the object in the object table.
- */
-TaskExecutionSpec *object_reconstruction_suppression_spec;
-
-void object_reconstruction_suppression_callback(ObjectID object_id,
-                                                bool success,
-                                                void *user_context) {
-  RAY_CHECK(success);
-  /* Submit the task after adding the object to the object table. */
-  LocalSchedulerConnection *worker = (LocalSchedulerConnection *) user_context;
-  local_scheduler_submit(worker, *object_reconstruction_suppression_spec);
-}
-
-TEST object_reconstruction_suppression_test(void) {
-  LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1);
-  LocalSchedulerConnection *worker = local_scheduler->conns[0];
-
-  TaskExecutionSpec execution_spec = example_task_execution_spec(0, 1);
-  object_reconstruction_suppression_spec = &execution_spec;
-  ObjectID return_id =
-      TaskSpec_return(object_reconstruction_suppression_spec->Spec(), 0);
-  pid_t pid = fork();
-  if (pid == 0) {
-    /* Make sure we receive the task once. This will block until the
-     * object_table_add callback completes. */
-    int64_t task_assigned_size;
-    TaskSpec *task_assigned =
-        local_scheduler_get_task(worker, &task_assigned_size);
-    ASSERT_EQ(
-        memcmp(task_assigned, object_reconstruction_suppression_spec->Spec(),
-               object_reconstruction_suppression_spec->SpecSize()),
-        0);
-    /* Trigger a reconstruction. We will check that no tasks get queued as a
-     * result of this line in the event loop process. */
-    local_scheduler_reconstruct_objects(worker,
-                                        std::vector<ObjectID>({return_id}));
-    /* Clean up. */
-    free(task_assigned);
-    LocalSchedulerMock_free(local_scheduler);
-    exit(0);
-  } else {
-    /* Connect a plasma manager client so we can call object_table_add. */
-    std::vector<std::string> db_connect_args;
-    db_connect_args.push_back("manager_address");
-    db_connect_args.push_back("127.0.0.1:12346");
-    DBHandle *db = db_connect(std::string("127.0.0.1"), 6379, "plasma_manager",
-                              "127.0.0.1", db_connect_args);
-    db_attach(db, local_scheduler->loop, false);
-    /* Add the object to the object table. */
-    object_table_add(db, return_id, 1, (unsigned char *) NIL_DIGEST, NULL,
-                     object_reconstruction_suppression_callback,
-                     (void *) worker);
-    /* Run the event loop. NOTE: OSX appears to require the parent process to
-     * listen for events on the open file descriptors. */
-    event_loop_add_timer(local_scheduler->loop, 1000,
-                         (event_loop_timer_handler) timeout_handler, NULL);
-    event_loop_run(local_scheduler->loop);
-    /* Wait for the child process to exit and check that there are no tasks
-     * left in the local scheduler's task queue. Then, clean up. */
-    wait(NULL);
-    ASSERT_EQ(num_waiting_tasks(
-                  local_scheduler->local_scheduler_state->algorithm_state),
-              0);
-    ASSERT_EQ(num_dispatch_tasks(
-                  local_scheduler->local_scheduler_state->algorithm_state),
-              0);
-    db_disconnect(db);
-    LocalSchedulerMock_free(local_scheduler);
-    PASS();
-  }
-}
-
-TEST task_dependency_test(void) {
-  LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1);
-  LocalSchedulerState *state = local_scheduler->local_scheduler_state;
-  SchedulingAlgorithmState *algorithm_state = state->algorithm_state;
-  /* Get the first worker. */
-  LocalSchedulerClient *worker = state->workers.front();
-  TaskExecutionSpec execution_spec = example_task_execution_spec(1, 1);
-  TaskSpec *spec = execution_spec.Spec();
-  ObjectID oid = TaskSpec_arg_id(spec, 0, 0);
-
-  /* Check that the task gets queued in the waiting queue if the task is
-   * submitted, but the input and workers are not available. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Once the input is available, the task gets moved to the dispatch queue. */
-  handle_object_available(state, algorithm_state, oid);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
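task_dependency_test exercises what is effectively a three-state machine per task: waiting until its inputs are local, dispatchable until a worker frees up, then assigned. A toy transition function capturing the assertions above and below (the names are illustrative, not the deleted scheduler's API):

#include <cassert>

enum class Queue { WAITING, DISPATCH, ASSIGNED };

// Decide where a submitted task belongs given the two conditions the test
// toggles: whether all inputs are local, and whether a worker is free.
Queue next_state(bool inputs_ready, bool worker_free) {
  if (!inputs_ready) return Queue::WAITING;
  if (!worker_free) return Queue::DISPATCH;
  return Queue::ASSIGNED;
}

int main() {
  assert(next_state(false, true) == Queue::WAITING);   // Input missing: wait.
  assert(next_state(true, false) == Queue::DISPATCH);  // Ready, no worker yet.
  assert(next_state(true, true) == Queue::ASSIGNED);   // Ready and worker free.
  return 0;
}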
-
-  /* Check that the task gets queued in the waiting queue if the task is
-   * submitted and a worker is available, but the input is not. */
-  handle_object_removed(state, oid);
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Once the input is available, the task gets assigned. */
-  handle_object_available(state, algorithm_state, oid);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
-
-  /* Check that the task gets queued in the dispatch queue if the task is
-   * submitted and the input is available, but no worker is available yet. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
-
-  /* If an object gets removed, check the first scenario again, where the task
-   * gets queued in the waiting queue if the task is submitted and a worker is
-   * available, but the input is not. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* If the input is removed while a task is in the dispatch queue, the task
-   * gets moved back to the waiting queue. */
-  handle_object_removed(state, oid);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Once the input is available, the task gets moved back to the dispatch
-   * queue. */
-  handle_object_available(state, algorithm_state, oid);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-
-  LocalSchedulerMock_free(local_scheduler);
-  PASS();
-}
-
-TEST task_multi_dependency_test(void) {
-  LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(0, 1);
-  LocalSchedulerState *state = local_scheduler->local_scheduler_state;
-  SchedulingAlgorithmState *algorithm_state = state->algorithm_state;
-  /* Get the first worker. */
-  LocalSchedulerClient *worker = state->workers.front();
-  TaskExecutionSpec execution_spec = example_task_execution_spec(2, 1);
-  TaskSpec *spec = execution_spec.Spec();
-  ObjectID oid1 = TaskSpec_arg_id(spec, 0, 0);
-  ObjectID oid2 = TaskSpec_arg_id(spec, 1, 0);
-
-  /* Check that the task gets queued in the waiting queue if the task is
-   * submitted, but the inputs and workers are not available. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Check that the task stays in the waiting queue if only one input becomes
-   * available. */
-  handle_object_available(state, algorithm_state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
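For tasks with several arguments, the waiting-queue logic has to track which inputs are still missing, and an eviction can re-add a previously satisfied input. That is naturally a set (or counter) of missing object IDs; a self-contained sketch of the bookkeeping the assertions below exercise (plain ints stand in for ObjectIDs):

#include <cassert>
#include <set>

struct PendingTask {
  std::set<int> missing;  // Inputs not yet local.
  bool ready() const { return missing.empty(); }
};

int main() {
  PendingTask t{{1, 2}};
  t.missing.erase(2);   // One input arrives: still waiting on the other.
  assert(!t.ready());
  t.missing.insert(2);  // ...and is evicted again.
  t.missing.erase(1);
  assert(!t.ready());
  t.missing.erase(2);   // All inputs local: eligible for dispatch.
  assert(t.ready());
  return 0;
}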
-  /* Once all inputs are available, the task is moved to the dispatch queue. */
-  handle_object_available(state, algorithm_state, oid1);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
-
-  /* Check that the task gets queued in the dispatch queue if the task is
-   * submitted and the inputs are available, but no worker is available yet. */
-  handle_task_submitted(state, algorithm_state, execution_spec);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* If any input is removed while a task is in the dispatch queue, the task
-   * gets moved back to the waiting queue. */
-  handle_object_removed(state, oid1);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  handle_object_removed(state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Check that the task stays in the waiting queue if only one input becomes
-   * available. */
-  handle_object_available(state, algorithm_state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Check that the task stays in the waiting queue if the one input is
-   * unavailable again. */
-  handle_object_removed(state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Check that the task stays in the waiting queue if the other input becomes
-   * available. */
-  handle_object_available(state, algorithm_state, oid1);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 1);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  /* Once all inputs are available, the task is moved to the dispatch queue. */
-  handle_object_available(state, algorithm_state, oid2);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 1);
-  /* Once a worker is available, the task gets assigned. */
-  handle_worker_available(state, algorithm_state, worker);
-  ASSERT_EQ(num_waiting_tasks(algorithm_state), 0);
-  ASSERT_EQ(num_dispatch_tasks(algorithm_state), 0);
-  reset_worker(local_scheduler, worker);
-
-  LocalSchedulerMock_free(local_scheduler);
-  PASS();
-}
-
-TEST start_kill_workers_test(void) {
-  /* Start some workers. */
-  int num_workers = 4;
-  LocalSchedulerMock *local_scheduler = LocalSchedulerMock_init(num_workers, 0);
-  /* We start off with num_workers child processes, but no workers registered
-   * yet. */
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(),
-            static_cast<size_t>(num_workers));
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(), 0);
-
-  /* Make sure that each worker connects to the local scheduler. This for loop
-   * will hang if one of the workers does not connect. */
-  for (int i = 0; i < num_workers; ++i) {
-    new_client_connection(local_scheduler->loop,
-                          local_scheduler->local_scheduler_fd,
-                          (void *) local_scheduler->local_scheduler_state, 0);
-  }
-
-  /* After handling each worker's initial connection, we should now have all
-   * workers accounted for, but we haven't yet matched up process IDs with our
-   * child processes. */
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(),
-            static_cast<size_t>(num_workers));
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers));
-
-  /* Each worker should register its process ID. */
-  for (auto const &worker : local_scheduler->local_scheduler_state->workers) {
-    process_message(local_scheduler->local_scheduler_state->loop, worker->sock,
-                    worker, 0);
-  }
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers));
-
-  /* After killing a worker, its state is cleaned up. */
-  LocalSchedulerClient *worker =
-      local_scheduler->local_scheduler_state->workers.front();
-  kill_worker(local_scheduler->local_scheduler_state, worker, false, false);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers - 1));
-
-  /* Start a worker after the local scheduler has been initialized. */
-  start_worker(local_scheduler->local_scheduler_state);
-  /* Accept the worker as a client to the plasma manager. */
-  int new_worker_fd = accept_client(local_scheduler->plasma_manager_fd);
-  /* The new worker should register its process ID. */
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 1);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers - 1));
-  /* Make sure the new worker connects to the local scheduler. */
-  new_client_connection(local_scheduler->loop,
-                        local_scheduler->local_scheduler_fd,
-                        (void *) local_scheduler->local_scheduler_state, 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 1);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers));
-  /* Make sure that the new worker registers its process ID. */
-  worker = local_scheduler->local_scheduler_state->workers.back();
-  process_message(local_scheduler->local_scheduler_state->loop, worker->sock,
-                  worker, 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->child_pids.size(), 0);
-  ASSERT_EQ(local_scheduler->local_scheduler_state->workers.size(),
-            static_cast<size_t>(num_workers));
-
-  /* Clean up. */
-  close(new_worker_fd);
-  LocalSchedulerMock_free(local_scheduler);
-  PASS();
-}
-
-SUITE(local_scheduler_tests) {
-  RUN_REDIS_TEST(object_reconstruction_test);
-  RUN_REDIS_TEST(object_reconstruction_recursive_test);
-  RUN_REDIS_TEST(object_reconstruction_suppression_test);
-  RUN_REDIS_TEST(task_dependency_test);
-  RUN_REDIS_TEST(task_multi_dependency_test);
-  RUN_REDIS_TEST(start_kill_workers_test);
-}
-
-GREATEST_MAIN_DEFS();
-
-int main(int argc, char **argv) {
-  g_task_builder = make_task_builder();
-  GREATEST_MAIN_BEGIN();
-  RUN_SUITE(local_scheduler_tests);
-  GREATEST_MAIN_END();
-}
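The scaffolding above comes from the greatest.h single-header test framework: TEST defines a case that ends in PASS(), SUITE groups cases, and GREATEST_MAIN_BEGIN/END drive the runner (RUN_REDIS_TEST is Ray's own wrapper around greatest's RUN_TEST). A minimal standalone example of the same scaffolding, assuming greatest.h is on the include path:

#include "greatest.h"

TEST always_passes(void) {
  ASSERT_EQ(2 + 2, 4);  // Fails the test and reports the location if false.
  PASS();
}

SUITE(example_suite) {
  RUN_TEST(always_passes);
}

GREATEST_MAIN_DEFS();

int main(int argc, char **argv) {
  GREATEST_MAIN_BEGIN();
  RUN_SUITE(example_suite);
  GREATEST_MAIN_END();
}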
diff --git a/src/local_scheduler/test/run_tests.sh b/src/local_scheduler/test/run_tests.sh
deleted file mode 100644
index 9c1d7be79b78..000000000000
--- a/src/local_scheduler/test/run_tests.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/usr/bin/env bash
-
-# This needs to be run in the build tree, which is normally ray/build.
-
-# Cause the script to exit if a single command fails.
-set -e
-
-LaunchRedis() {
-  port=$1
-  if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then
-    ./src/credis/redis/src/redis-server \
-      --loglevel warning \
-      --loadmodule ./src/credis/build/src/libmember.so \
-      --loadmodule ./src/common/redis_module/libray_redis_module.so \
-      --port $port &
-  else
-    ./src/common/thirdparty/redis/src/redis-server \
-      --loglevel warning \
-      --loadmodule ./src/common/redis_module/libray_redis_module.so \
-      --port $port &
-  fi
-}
-
-
-# Start the Redis shards.
-LaunchRedis 6379
-LaunchRedis 6380
-sleep 1s
-# Register the shard location with the primary shard.
-./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
-./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
-
-./src/plasma/plasma_store_server -s /tmp/plasma_store_socket_1 -m 100000000 &
-sleep 0.5s
-./src/local_scheduler/local_scheduler_tests
-./src/common/thirdparty/redis/src/redis-cli shutdown
-./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
-killall plasma_store_server
diff --git a/src/local_scheduler/test/run_valgrind.sh b/src/local_scheduler/test/run_valgrind.sh
deleted file mode 100644
index 6ff1dbe33c62..000000000000
--- a/src/local_scheduler/test/run_valgrind.sh
+++ /dev/null
@@ -1,41 +0,0 @@
-#!/usr/bin/env bash
-
-# This needs to be run in the build tree, which is normally ray/build.
-
-set -x
-
-# Cause the script to exit if a single command fails.
-set -e
-
-LaunchRedis() {
-  port=$1
-  if [[ "${RAY_USE_NEW_GCS}" = "on" ]]; then
-    ./src/credis/redis/src/redis-server \
-      --loglevel warning \
-      --loadmodule ./src/credis/build/src/libmember.so \
-      --loadmodule ./src/common/redis_module/libray_redis_module.so \
-      --port $port &
-  else
-    ./src/common/thirdparty/redis/src/redis-server \
-      --loglevel warning \
-      --loadmodule ./src/common/redis_module/libray_redis_module.so \
-      --port $port &
-  fi
-}
-
-
-# Start the Redis shards.
-LaunchRedis 6379
-LaunchRedis 6380
-sleep 1s
-
-# Register the shard location with the primary shard.
-./src/common/thirdparty/redis/src/redis-cli set NumRedisShards 1
-./src/common/thirdparty/redis/src/redis-cli rpush RedisShards 127.0.0.1:6380
-
-./src/plasma/plasma_store_server -s /tmp/plasma_store_socket_1 -m 100000000 &
-sleep 0.5s
-valgrind --track-origins=yes --leak-check=full --show-leak-kinds=all --leak-check-heuristics=stdstring --error-exitcode=1 ./src/local_scheduler/local_scheduler_tests
-./src/common/thirdparty/redis/src/redis-cli shutdown
-./src/common/thirdparty/redis/src/redis-cli -p 6380 shutdown
-killall plasma_store_server
diff --git a/src/plasma/CMakeLists.txt b/src/plasma/CMakeLists.txt
deleted file mode 100644
index 5037a54da3d7..000000000000
--- a/src/plasma/CMakeLists.txt
+++ /dev/null
@@ -1,61 +0,0 @@
-cmake_minimum_required(VERSION 3.4)
-
-project(plasma)
-
-include_directories(${CMAKE_CURRENT_LIST_DIR})
-include_directories(${CMAKE_CURRENT_LIST_DIR}/thirdparty)
-
-set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} --std=c99 -O3")
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11 -O3 -Werror -Wall")
-
-if(UNIX AND NOT APPLE)
-  link_libraries(rt)
-endif()
-
-include_directories("${ARROW_INCLUDE_DIR}")
-
-set(PLASMA_FBS_SRC "${CMAKE_CURRENT_LIST_DIR}/format/plasma.fbs" "${CMAKE_CURRENT_LIST_DIR}/format/common.fbs")
-set(OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/format/)
-
-set(PLASMA_FBS_OUTPUT_FILES
-    "${OUTPUT_DIR}/plasma_generated.h"
-    "${OUTPUT_DIR}/common_generated.h")
-
-add_custom_target(gen_plasma_fbs DEPENDS ${PLASMA_FBS_OUTPUT_FILES})
-add_dependencies(gen_plasma_fbs arrow_ep)
-
-# Copy the fbs files from the Arrow project to a local directory.
-add_custom_command(
-  OUTPUT ${PLASMA_FBS_SRC}
-  COMMAND mkdir -p ${CMAKE_CURRENT_LIST_DIR}/format/
-  COMMAND cp ${ARROW_SOURCE_DIR}/cpp/src/plasma/format/plasma.fbs ${CMAKE_CURRENT_LIST_DIR}/format/
-  COMMAND cp ${ARROW_SOURCE_DIR}/cpp/src/plasma/format/common.fbs ${CMAKE_CURRENT_LIST_DIR}/format/
-  COMMENT "Copying ${PLASMA_FBS_SRC} to local"
-  VERBATIM)
-
-# Compile the flatbuffers.
-add_custom_command(
-  OUTPUT ${PLASMA_FBS_OUTPUT_FILES}
-  # The --gen-object-api flag generates a C++ class MessageT for each
-  # flatbuffers message Message, which can be used to store deserialized
-  # messages in data structures. This is currently used for ObjectInfo for
-  # example.
-  COMMAND ${FLATBUFFERS_COMPILER} -c -o ${OUTPUT_DIR} ${PLASMA_FBS_SRC} --gen-object-api --scoped-enums
-  DEPENDS ${PLASMA_FBS_SRC}
-  COMMENT "Running flatc compiler on ${PLASMA_FBS_SRC}"
-  VERBATIM)
-
-include_directories("${FLATBUFFERS_INCLUDE_DIR}")
-
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
-
-add_executable(plasma_manager
-  plasma_manager.cc)
-add_dependencies(plasma_manager gen_plasma_fbs)
-
-target_link_libraries(plasma_manager common ${PLASMA_STATIC_LIB} ray_static ${ARROW_STATIC_LIB} -lpthread ${Boost_SYSTEM_LIBRARY})
-
-define_test(client_tests "")
-define_test(manager_tests "" plasma_manager.cc)
-target_link_libraries(manager_tests ${Boost_SYSTEM_LIBRARY})
-add_dependencies(manager_tests gen_plasma_fbs)
diff --git a/src/plasma/doc/plasma-doxy-config b/src/plasma/doc/plasma-doxy-config
deleted file mode 100644
index 9c291f838883..000000000000
--- a/src/plasma/doc/plasma-doxy-config
+++ /dev/null
@@ -1,2473 +0,0 @@
-# Doxyfile 1.8.13
-
-# This file describes the settings to be used by the documentation system
-# doxygen (www.doxygen.org) for a project.
-#
-# All text after a double hash (##) is considered a comment and is placed in
-# front of the TAG it is preceding.
-#
-# All text after a single hash (#) is considered a comment and will be ignored.
-# The format is:
-# TAG = value [value, ...]
-# For lists, items can also be appended using:
-# TAG += value [value, ...]
-# Values that contain spaces should be placed between quotes (\" \").
-
-#---------------------------------------------------------------------------
-# Project related configuration options
-#---------------------------------------------------------------------------
-
-# This tag specifies the encoding used for all characters in the config file
-# that follow. The default is UTF-8 which is also the encoding used for all text
-# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv
-# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv
-# for the list of possible encodings.
-# The default value is: UTF-8.
-
-DOXYFILE_ENCODING = UTF-8
-
-# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by
-# double-quotes, unless you are using Doxywizard) that should identify the
-# project for which the documentation is generated. This name is used in the
-# title of most generated pages and in a few other places.
-# The default value is: My Project.
-
-PROJECT_NAME = "Plasma"
-
-# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
-# could be handy for archiving the generated documentation or if some version
-# control system is used.
-
-PROJECT_NUMBER =
-
-# Using the PROJECT_BRIEF tag one can provide an optional one line description
-# for a project that appears at the top of each page and should give viewer a
-# quick idea about the purpose of the project. Keep the description short.
-
-PROJECT_BRIEF =
-
-# With the PROJECT_LOGO tag one can specify a logo or an icon that is included
-# in the documentation. The maximum height of the logo should not exceed 55
-# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy
-# the logo to the output directory.
-
-PROJECT_LOGO =
-
-# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path
-# into which the generated documentation will be written. If a relative path is
-# entered, it will be relative to the location where doxygen was started. If
-# left blank the current directory will be used.
-
-OUTPUT_DIRECTORY =
-
-# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub-
-# directories (in 2 levels) under the output directory of each output format and
-# will distribute the generated files over these directories. Enabling this
-# option can be useful when feeding doxygen a huge amount of source files, where
-# putting all generated files in the same directory would otherwise causes
-# performance problems for the file system.
-# The default value is: NO.
-
-CREATE_SUBDIRS = NO
-
-# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII
-# characters to appear in the names of generated files. If set to NO, non-ASCII
-# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode
-# U+3044.
-# The default value is: NO.
-
-ALLOW_UNICODE_NAMES = NO
-
-# The OUTPUT_LANGUAGE tag is used to specify the language in which all
-# documentation generated by doxygen is written. Doxygen will use this
-# information to generate all constant output in the proper language.
-# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese,
-# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States),
-# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian,
-# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages),
-# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian,
-# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian,
-# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish,
-# Ukrainian and Vietnamese.
-# The default value is: English.
-
-OUTPUT_LANGUAGE = English
-
-# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member
-# descriptions after the members that are listed in the file and class
-# documentation (similar to Javadoc). Set to NO to disable this.
-# The default value is: YES.
-
-BRIEF_MEMBER_DESC = YES
-
-# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief
-# description of a member or function before the detailed description
-#
-# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the
-# brief descriptions will be completely suppressed.
-# The default value is: YES.
-
-REPEAT_BRIEF = YES
-
-# This tag implements a quasi-intelligent brief description abbreviator that is
-# used to form the text in various listings. Each string in this list, if found
-# as the leading text of the brief description, will be stripped from the text
-# and the result, after processing the whole list, is used as the annotated
-# text. Otherwise, the brief description is used as-is. If left blank, the
-# following values are used ($name is automatically replaced with the name of
-# the entity):The $name class, The $name widget, The $name file, is, provides,
-# specifies, contains, represents, a, an and the.
-
-ABBREVIATE_BRIEF = "The $name class" \
-                   "The $name widget" \
-                   "The $name file" \
-                   is \
-                   provides \
-                   specifies \
-                   contains \
-                   represents \
-                   a \
-                   an \
-                   the
-
-# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then
-# doxygen will generate a detailed section even if there is only a brief
-# description.
-# The default value is: NO.
-
-ALWAYS_DETAILED_SEC = NO
-
-# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all
-# inherited members of a class in the documentation of that class as if those
-# members were ordinary class members. Constructors, destructors and assignment
-# operators of the base classes will not be shown.
-# The default value is: NO.
-
-INLINE_INHERITED_MEMB = NO
-
-# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path
-# before files name in the file list and in the header files. If set to NO the
-# shortest path that makes the file name unique will be used
-# The default value is: YES.
-
-FULL_PATH_NAMES = YES
-
-# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path.
-# Stripping is only done if one of the specified strings matches the left-hand
-# part of the path. The tag can be used to show relative paths in the file list.
-# If left blank the directory from which doxygen is run is used as the path to
-# strip.
-#
-# Note that you can specify absolute paths here, but also relative paths, which
-# will be relative from the directory where doxygen is started.
-# This tag requires that the tag FULL_PATH_NAMES is set to YES.
-
-STRIP_FROM_PATH =
-
-# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the
-# path mentioned in the documentation of a class, which tells the reader which
-# header file to include in order to use a class. If left blank only the name of
-# the header file containing the class definition is used. Otherwise one should
-# specify the list of include paths that are normally passed to the compiler
-# using the -I flag.
-
-STRIP_FROM_INC_PATH =
-
-# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but
-# less readable) file names. This can be useful is your file systems doesn't
-# support long names like on DOS, Mac, or CD-ROM.
-# The default value is: NO.
-
-SHORT_NAMES = NO
-
-# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the
-# first line (until the first dot) of a Javadoc-style comment as the brief
-# description. If set to NO, the Javadoc-style will behave just like regular Qt-
-# style comments (thus requiring an explicit @brief command for a brief
-# description.)
-# The default value is: NO.
-
-JAVADOC_AUTOBRIEF = NO
-
-# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first
-# line (until the first dot) of a Qt-style comment as the brief description. If
-# set to NO, the Qt-style will behave just like regular Qt-style comments (thus
-# requiring an explicit \brief command for a brief description.)
-# The default value is: NO.
-
-QT_AUTOBRIEF = NO
-
-# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a
-# multi-line C++ special comment block (i.e. a block of //! or /// comments) as
-# a brief description. This used to be the default behavior. The new default is
-# to treat a multi-line C++ comment block as a detailed description. Set this
-# tag to YES if you prefer the old behavior instead.
-#
-# Note that setting this tag to YES also means that rational rose comments are
-# not recognized any more.
-# The default value is: NO.
-
-MULTILINE_CPP_IS_BRIEF = NO
-
-# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the
-# documentation from any documented member that it re-implements.
-# The default value is: YES.
-
-INHERIT_DOCS = YES
-
-# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new
-# page for each member. If set to NO, the documentation of a member will be part
-# of the file/class/namespace that contains it.
-# The default value is: NO.
-
-SEPARATE_MEMBER_PAGES = NO
-
-# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen
-# uses this value to replace tabs by spaces in code fragments.
-# Minimum value: 1, maximum value: 16, default value: 4.
-
-TAB_SIZE = 2
-
-# This tag can be used to specify a number of aliases that act as commands in
-# the documentation. An alias has the form:
-# name=value
-# For example adding
-# "sideeffect=@par Side Effects:\n"
-# will allow you to put the command \sideeffect (or @sideeffect) in the
-# documentation, which will result in a user-defined paragraph with heading
-# "Side Effects:". You can put \n's in the value part of an alias to insert
-# newlines.
-
-ALIASES =
-
-# This tag can be used to specify a number of word-keyword mappings (TCL only).
-# A mapping has the form "name=value". For example adding "class=itcl::class"
-# will allow you to use the command class in the itcl::class meaning.
-
-TCL_SUBST =
-
-# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources
-# only. Doxygen will then generate output that is more tailored for C. For
-# instance, some of the names that are used will be different. The list of all
-# members will be omitted, etc.
-# The default value is: NO.
-
-OPTIMIZE_OUTPUT_FOR_C = NO
-
-# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or
-# Python sources only. Doxygen will then generate output that is more tailored
-# for that language. For instance, namespaces will be presented as packages,
-# qualified scopes will look different, etc.
-# The default value is: NO.
-
-OPTIMIZE_OUTPUT_JAVA = NO
-
-# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran
-# sources. Doxygen will then generate output that is tailored for Fortran.
-# The default value is: NO.
-
-OPTIMIZE_FOR_FORTRAN = NO
-
-# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL
-# sources. Doxygen will then generate output that is tailored for VHDL.
-# The default value is: NO.
-
-OPTIMIZE_OUTPUT_VHDL = NO
-
-# Doxygen selects the parser to use depending on the extension of the files it
-# parses. With this tag you can assign which parser to use for a given
-# extension. Doxygen has a built-in mapping, but you can override or extend it
-# using this tag. The format is ext=language, where ext is a file extension, and
-# language is one of the parsers supported by doxygen: IDL, Java, Javascript,
-# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran:
-# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran:
-# Fortran. In the latter case the parser tries to guess whether the code is fixed
-# or free formatted code; this is the default for Fortran type files), VHDL. For
-# instance to make doxygen treat .inc files as Fortran files (default is PHP),
-# and .f files as C (default is Fortran), use: inc=Fortran f=C.
-#
-# Note: For files without extension you can use no_extension as a placeholder.
-#
-# Note that for custom extensions you also need to set FILE_PATTERNS otherwise
-# the files are not read by doxygen.
-
-EXTENSION_MAPPING =
-
-# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments
-# according to the Markdown format, which allows for more readable
-# documentation. See http://daringfireball.net/projects/markdown/ for details.
-# The output of markdown processing is further processed by doxygen, so you can
-# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in
-# case of backward compatibility issues.
-# The default value is: YES.
-
-MARKDOWN_SUPPORT = YES
-
-# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up
-# to that level are automatically included in the table of contents, even if
-# they do not have an id attribute.
-# Note: This feature currently applies only to Markdown headings.
-# Minimum value: 0, maximum value: 99, default value: 0.
-# This tag requires that the tag MARKDOWN_SUPPORT is set to YES.
-
-TOC_INCLUDE_HEADINGS = 0
-
-# When enabled doxygen tries to link words that correspond to documented
-# classes, or namespaces to their corresponding documentation. Such a link can
-# be prevented in individual cases by putting a % sign in front of the word or
-# globally by setting AUTOLINK_SUPPORT to NO.
-# The default value is: YES.
-
-AUTOLINK_SUPPORT = YES
-
-# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want
-# to include (a tag file for) the STL sources as input, then you should set this
-# tag to YES in order to let doxygen match function declarations and
-# definitions whose arguments contain STL classes (e.g. func(std::string);
-# versus func(std::string) {}). This also makes the inheritance and collaboration
-# diagrams that involve STL classes more complete and accurate.
-# The default value is: NO.
-
-BUILTIN_STL_SUPPORT = NO
-
-# If you use Microsoft's C++/CLI language, you should set this option to YES to
-# enable parsing support.
-# The default value is: NO.
-
-CPP_CLI_SUPPORT = NO
-
-# Set the SIP_SUPPORT tag to YES if your project consists of sip (see:
-# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen
-# will parse them like normal C++ but will assume all classes use public instead
-# of private inheritance when no explicit protection keyword is present.
-# The default value is: NO.
-
-SIP_SUPPORT = NO
-
-# For Microsoft's IDL there are propget and propput attributes to indicate
-# getter and setter methods for a property. Setting this option to YES will make
-# doxygen replace the get and set methods by a property in the documentation.
-# This will only work if the methods are indeed getting or setting a simple
-# type. If this is not the case, or you want to show the methods anyway, you
-# should set this option to NO.
-# The default value is: YES.
-
-IDL_PROPERTY_SUPPORT = YES
-
-# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC
-# tag is set to YES then doxygen will reuse the documentation of the first
-# member in the group (if any) for the other members of the group. By default
-# all members of a group must be documented explicitly.
-# The default value is: NO.
-
-DISTRIBUTE_GROUP_DOC = NO
-
-# If one adds a struct or class to a group and this option is enabled, then also
-# any nested class or struct is added to the same group. By default this option
-# is disabled and one has to add nested compounds explicitly via \ingroup.
-# The default value is: NO.
-
-GROUP_NESTED_COMPOUNDS = NO
-
-# Set the SUBGROUPING tag to YES to allow class member groups of the same type
-# (for instance a group of public functions) to be put as a subgroup of that
-# type (e.g. under the Public Functions section). Set it to NO to prevent
-# subgrouping. Alternatively, this can be done per class using the
-# \nosubgrouping command.
-# The default value is: YES.
-
-SUBGROUPING = YES
-
-# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions
-# are shown inside the group in which they are included (e.g. using \ingroup)
-# instead of on a separate page (for HTML and Man pages) or section (for LaTeX
-# and RTF).
-#
-# Note that this feature does not work in combination with
-# SEPARATE_MEMBER_PAGES.
-# The default value is: NO.
-
-INLINE_GROUPED_CLASSES = NO
-
-# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions
-# with only public data fields or simple typedef fields will be shown inline in
-# the documentation of the scope in which they are defined (i.e. file,
-# namespace, or group documentation), provided this scope is documented. If set
-# to NO, structs, classes, and unions are shown on a separate page (for HTML and
-# Man pages) or section (for LaTeX and RTF).
-# The default value is: NO.
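-#
-# For example, a simple aggregate such as the following (an illustrative type,
-# not one defined in this project) would then be documented inline in the file
-# that defines it:
-#
-#   typedef struct { int x; int y; } point_t;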
- -INLINE_SIMPLE_STRUCTS = NO - -# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or -# enum is documented as struct, union, or enum with the name of the typedef. So -# typedef struct TypeS {} TypeT, will appear in the documentation as a struct -# with name TypeT. When disabled the typedef will appear as a member of a file, -# namespace, or class. And the struct will be named TypeS. This can typically be -# useful for C code in case the coding convention dictates that all compound -# types are typedef'ed and only the typedef is referenced, never the tag name. -# The default value is: NO. - -TYPEDEF_HIDES_STRUCT = NO - -# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This -# cache is used to resolve symbols given their name and scope. Since this can be -# an expensive process and often the same symbol appears multiple times in the -# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small -# doxygen will become slower. If the cache is too large, memory is wasted. The -# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range -# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 -# symbols. At the end of a run doxygen will report the cache usage and suggest -# the optimal cache size from a speed point of view. -# Minimum value: 0, maximum value: 9, default value: 0. - -LOOKUP_CACHE_SIZE = 0 - -#--------------------------------------------------------------------------- -# Build related configuration options -#--------------------------------------------------------------------------- - -# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in -# documentation are documented, even if no documentation was available. Private -# class members and static file members will be hidden unless the -# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. -# Note: This will also disable the warnings about undocumented members that are -# normally produced when WARNINGS is set to YES. -# The default value is: NO. - -EXTRACT_ALL = YES - -# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will -# be included in the documentation. -# The default value is: NO. - -EXTRACT_PRIVATE = NO - -# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal -# scope will be included in the documentation. -# The default value is: NO. - -EXTRACT_PACKAGE = NO - -# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be -# included in the documentation. -# The default value is: NO. - -EXTRACT_STATIC = NO - -# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined -# locally in source files will be included in the documentation. If set to NO, -# only classes defined in header files are included. Does not have any effect -# for Java sources. -# The default value is: YES. - -EXTRACT_LOCAL_CLASSES = YES - -# This flag is only useful for Objective-C code. If set to YES, local methods, -# which are defined in the implementation section but not in the interface are -# included in the documentation. If set to NO, only methods in the interface are -# included. -# The default value is: NO. - -EXTRACT_LOCAL_METHODS = NO - -# If this flag is set to YES, the members of anonymous namespaces will be -# extracted and appear in the documentation as a namespace called -# 'anonymous_namespace{file}', where file will be replaced with the base name of -# the file that contains the anonymous namespace. 
By default anonymous namespaces
-# are hidden.
-# The default value is: NO.
-
-EXTRACT_ANON_NSPACES = NO
-
-# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all
-# undocumented members inside documented classes or files. If set to NO these
-# members will be included in the various overviews, but no documentation
-# section is generated. This option has no effect if EXTRACT_ALL is enabled.
-# The default value is: NO.
-
-HIDE_UNDOC_MEMBERS = NO
-
-# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all
-# undocumented classes that are normally visible in the class hierarchy. If set
-# to NO, these classes will be included in the various overviews. This option
-# has no effect if EXTRACT_ALL is enabled.
-# The default value is: NO.
-
-HIDE_UNDOC_CLASSES = NO
-
-# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend
-# (class|struct|union) declarations. If set to NO, these declarations will be
-# included in the documentation.
-# The default value is: NO.
-
-HIDE_FRIEND_COMPOUNDS = NO
-
-# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any
-# documentation blocks found inside the body of a function. If set to NO, these
-# blocks will be appended to the function's detailed documentation block.
-# The default value is: NO.
-
-HIDE_IN_BODY_DOCS = NO
-
-# The INTERNAL_DOCS tag determines if documentation that is typed after a
-# \internal command is included. If the tag is set to NO then the documentation
-# will be excluded. Set it to YES to include the internal documentation.
-# The default value is: NO.
-
-INTERNAL_DOCS = NO
-
-# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file
-# names in lower-case letters. If set to YES, upper-case letters are also
-# allowed. This is useful if you have classes or files whose names only differ
-# in case and if your file system supports case sensitive file names. Windows
-# and Mac users are advised to set this option to NO.
-# The default value is: system dependent.
-
-CASE_SENSE_NAMES = NO
-
-# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with
-# their full class and namespace scopes in the documentation. If set to YES, the
-# scope will be hidden.
-# The default value is: NO.
-
-HIDE_SCOPE_NAMES = NO
-
-# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will
-# append additional text to a page's title, such as Class Reference. If set to
-# YES the compound reference will be hidden.
-# The default value is: NO.
-
-HIDE_COMPOUND_REFERENCE= NO
-
-# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of
-# the files that are included by a file in the documentation of that file.
-# The default value is: YES.
-
-SHOW_INCLUDE_FILES = YES
-
-# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each
-# grouped member an include statement to the documentation, telling the reader
-# which file to include in order to use the member.
-# The default value is: NO.
-
-SHOW_GROUPED_MEMB_INC = NO
-
-# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include
-# files with double quotes in the documentation rather than with sharp brackets.
-# The default value is: NO.
-
-FORCE_LOCAL_INCLUDES = NO
-
-# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the
-# documentation for inline members.
-# The default value is: YES.
- -INLINE_INFO = YES - -# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the -# (detailed) documentation of file and class members alphabetically by member -# name. If set to NO, the members will appear in declaration order. -# The default value is: YES. - -SORT_MEMBER_DOCS = YES - -# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief -# descriptions of file, namespace and class members alphabetically by member -# name. If set to NO, the members will appear in declaration order. Note that -# this will also influence the order of the classes in the class list. -# The default value is: NO. - -SORT_BRIEF_DOCS = NO - -# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the -# (brief and detailed) documentation of class members so that constructors and -# destructors are listed first. If set to NO the constructors will appear in the -# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. -# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief -# member documentation. -# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting -# detailed member documentation. -# The default value is: NO. - -SORT_MEMBERS_CTORS_1ST = NO - -# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy -# of group names into alphabetical order. If set to NO the group names will -# appear in their defined order. -# The default value is: NO. - -SORT_GROUP_NAMES = NO - -# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by -# fully-qualified names, including namespaces. If set to NO, the class list will -# be sorted only by class name, not including the namespace part. -# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. -# Note: This option applies only to the class list, not to the alphabetical -# list. -# The default value is: NO. - -SORT_BY_SCOPE_NAME = NO - -# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper -# type resolution of all parameters of a function it will reject a match between -# the prototype and the implementation of a member function even if there is -# only one candidate or it is obvious which candidate to choose by doing a -# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still -# accept a match between prototype and implementation in such cases. -# The default value is: NO. - -STRICT_PROTO_MATCHING = NO - -# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo -# list. This list is created by putting \todo commands in the documentation. -# The default value is: YES. - -GENERATE_TODOLIST = YES - -# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test -# list. This list is created by putting \test commands in the documentation. -# The default value is: YES. - -GENERATE_TESTLIST = YES - -# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug -# list. This list is created by putting \bug commands in the documentation. -# The default value is: YES. - -GENERATE_BUGLIST = YES - -# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) -# the deprecated list. This list is created by putting \deprecated commands in -# the documentation. -# The default value is: YES. - -GENERATE_DEPRECATEDLIST= YES - -# The ENABLED_SECTIONS tag can be used to enable conditional documentation -# sections, marked by \if ... \endif and \cond -# ... \endcond blocks. 
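-#
-# For example, documentation wrapped in \if internal ... \endif would only be
-# generated when this tag lists that section label (an illustrative label, not
-# one used by this project):
-#
-#   ENABLED_SECTIONS = internal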
-
-ENABLED_SECTIONS =
-
-# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the
-# initial value of a variable or macro / define can have for it to appear in the
-# documentation. If the initializer consists of more lines than specified here
-# it will be hidden. Use a value of 0 to hide initializers completely. The
-# appearance of the value of individual variables and macros / defines can be
-# controlled using \showinitializer or \hideinitializer command in the
-# documentation regardless of this setting.
-# Minimum value: 0, maximum value: 10000, default value: 30.
-
-MAX_INITIALIZER_LINES = 30
-
-# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at
-# the bottom of the documentation of classes and structs. If set to YES, the
-# list will mention the files that were used to generate the documentation.
-# The default value is: YES.
-
-SHOW_USED_FILES = YES
-
-# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This
-# will remove the Files entry from the Quick Index and from the Folder Tree View
-# (if specified).
-# The default value is: YES.
-
-SHOW_FILES = YES
-
-# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces
-# page. This will remove the Namespaces entry from the Quick Index and from the
-# Folder Tree View (if specified).
-# The default value is: YES.
-
-SHOW_NAMESPACES = YES
-
-# The FILE_VERSION_FILTER tag can be used to specify a program or script that
-# doxygen should invoke to get the current version for each file (typically from
-# the version control system). Doxygen will invoke the program by executing (via
-# popen()) the command <command> <input-file>, where <command> is the value of
-# the FILE_VERSION_FILTER tag, and <input-file> is the name of an input file
-# provided by doxygen. Whatever the program writes to standard output is used as
-# the file version. For an example see the documentation.
-
-FILE_VERSION_FILTER =
-
-# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed
-# by doxygen. The layout file controls the global structure of the generated
-# output files in an output format independent way. To create the layout file
-# that represents doxygen's defaults, run doxygen with the -l option. You can
-# optionally specify a file name after the option, if omitted DoxygenLayout.xml
-# will be used as the name of the layout file.
-#
-# Note that if you run doxygen from a directory containing a file called
-# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE
-# tag is left empty.
-
-LAYOUT_FILE =
-
-# The CITE_BIB_FILES tag can be used to specify one or more bib files containing
-# the reference definitions. This must be a list of .bib files. The .bib
-# extension is automatically appended if omitted. This requires the bibtex tool
-# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info.
-# For LaTeX the style of the bibliography can be controlled using
-# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the
-# search path. See also \cite for info on how to create references.
-
-CITE_BIB_FILES =
-
-#---------------------------------------------------------------------------
-# Configuration options related to warning and progress messages
-#---------------------------------------------------------------------------
-
-# The QUIET tag can be used to turn on/off the messages that are generated to
-# standard output by doxygen.
If QUIET is set to YES this implies that the -# messages are off. -# The default value is: NO. - -QUIET = NO - -# The WARNINGS tag can be used to turn on/off the warning messages that are -# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES -# this implies that the warnings are on. -# -# Tip: Turn warnings on while writing the documentation. -# The default value is: YES. - -WARNINGS = YES - -# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate -# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag -# will automatically be disabled. -# The default value is: YES. - -WARN_IF_UNDOCUMENTED = YES - -# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for -# potential errors in the documentation, such as not documenting some parameters -# in a documented function, or documenting parameters that don't exist or using -# markup commands wrongly. -# The default value is: YES. - -WARN_IF_DOC_ERROR = YES - -# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that -# are documented, but have no documentation for their parameters or return -# value. If set to NO, doxygen will only warn about wrong or incomplete -# parameter documentation, but not about the absence of documentation. -# The default value is: NO. - -WARN_NO_PARAMDOC = NO - -# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when -# a warning is encountered. -# The default value is: NO. - -WARN_AS_ERROR = NO - -# The WARN_FORMAT tag determines the format of the warning messages that doxygen -# can produce. The string should contain the $file, $line, and $text tags, which -# will be replaced by the file and line number from which the warning originated -# and the warning text. Optionally the format may contain $version, which will -# be replaced by the version of the file (if it could be obtained via -# FILE_VERSION_FILTER) -# The default value is: $file:$line: $text. - -WARN_FORMAT = "$file:$line: $text" - -# The WARN_LOGFILE tag can be used to specify a file to which warning and error -# messages should be written. If left blank the output is written to standard -# error (stderr). - -WARN_LOGFILE = - -#--------------------------------------------------------------------------- -# Configuration options related to the input files -#--------------------------------------------------------------------------- - -# The INPUT tag is used to specify the files and/or directories that contain -# documented source files. You may enter file names like myfile.cpp or -# directories like /usr/src/myproject. Separate the files or directories with -# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING -# Note: If this tag is empty the current directory is searched. - -INPUT = ../src - -# This tag can be used to specify the character encoding of the source files -# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses -# libiconv (or the iconv built into libc) for the transcoding. See the libiconv -# documentation (see: http://www.gnu.org/software/libiconv) for the list of -# possible encodings. -# The default value is: UTF-8. - -INPUT_ENCODING = UTF-8 - -# If the value of the INPUT tag contains directories, you can use the -# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and -# *.h) to filter out the source-files in the directories. 
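-#
-# For example, a project that only wants C sources and headers scanned could
-# narrow this down to the following (an illustrative value; this configuration
-# keeps the full default list below):
-#
-#   FILE_PATTERNS = *.c *.h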
-# -# Note that for custom extensions or not directly supported extensions you also -# need to set EXTENSION_MAPPING for the extension otherwise the files are not -# read by doxygen. -# -# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, -# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, -# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, -# *.m, *.markdown, *.md, *.mm, *.dox, *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, -# *.f, *.for, *.tcl, *.vhd, *.vhdl, *.ucf and *.qsf. - -FILE_PATTERNS = *.c \ - *.cc \ - *.cxx \ - *.cpp \ - *.c++ \ - *.java \ - *.ii \ - *.ixx \ - *.ipp \ - *.i++ \ - *.inl \ - *.idl \ - *.ddl \ - *.odl \ - *.h \ - *.hh \ - *.hxx \ - *.hpp \ - *.h++ \ - *.cs \ - *.d \ - *.php \ - *.php4 \ - *.php5 \ - *.phtml \ - *.inc \ - *.m \ - *.markdown \ - *.md \ - *.mm \ - *.dox \ - *.py \ - *.pyw \ - *.f90 \ - *.f95 \ - *.f03 \ - *.f08 \ - *.f \ - *.for \ - *.tcl \ - *.vhd \ - *.vhdl \ - *.ucf \ - *.qsf - -# The RECURSIVE tag can be used to specify whether or not subdirectories should -# be searched for input files as well. -# The default value is: NO. - -RECURSIVE = NO - -# The EXCLUDE tag can be used to specify files and/or directories that should be -# excluded from the INPUT source files. This way you can easily exclude a -# subdirectory from a directory tree whose root is specified with the INPUT tag. -# -# Note that relative paths are relative to the directory from which doxygen is -# run. - -EXCLUDE = ../src/utarray.h ../src/uthash.h - -# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or -# directories that are symbolic links (a Unix file system feature) are excluded -# from the input. -# The default value is: NO. - -EXCLUDE_SYMLINKS = NO - -# If the value of the INPUT tag contains directories, you can use the -# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude -# certain files from those directories. -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories for example use the pattern */test/* - -EXCLUDE_PATTERNS = - -# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names -# (namespaces, classes, functions, etc.) that should be excluded from the -# output. The symbol name can be a fully qualified name, a word, or if the -# wildcard * is used, a substring. Examples: ANamespace, AClass, -# AClass::ANamespace, ANamespace::*Test -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories use the pattern */test/* - -EXCLUDE_SYMBOLS = - -# The EXAMPLE_PATH tag can be used to specify one or more files or directories -# that contain example code fragments that are included (see the \include -# command). - -EXAMPLE_PATH = - -# If the value of the EXAMPLE_PATH tag contains directories, you can use the -# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and -# *.h) to filter out the source-files in the directories. If left blank all -# files are included. - -EXAMPLE_PATTERNS = * - -# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be -# searched for input files to be used with the \include or \dontinclude commands -# irrespective of the value of the RECURSIVE tag. -# The default value is: NO. 
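-#
-# For example, with EXAMPLE_PATH = examples and EXAMPLE_RECURSIVE = YES, a
-# \include snippet.c command would also find examples/nested/snippet.c
-# (illustrative paths; this configuration leaves EXAMPLE_PATH empty).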
-
-EXAMPLE_RECURSIVE = NO
-
-# The IMAGE_PATH tag can be used to specify one or more files or directories
-# that contain images that are to be included in the documentation (see the
-# \image command).
-
-IMAGE_PATH =
-
-# The INPUT_FILTER tag can be used to specify a program that doxygen should
-# invoke to filter for each input file. Doxygen will invoke the filter program
-# by executing (via popen()) the command:
-#
-#   <filter> <input-file>
-#
-# where <filter> is the value of the INPUT_FILTER tag, and <input-file> is the
-# name of an input file. Doxygen will then use the output that the filter
-# program writes to standard output. If FILTER_PATTERNS is specified, this tag
-# will be ignored.
-#
-# Note that the filter must not add or remove lines; it is applied before the
-# code is scanned, but not when the output code is generated. If lines are added
-# or removed, the anchors will not be placed correctly.
-#
-# Note that for custom extensions or not directly supported extensions you also
-# need to set EXTENSION_MAPPING for the extension otherwise the files are not
-# properly processed by doxygen.
-
-INPUT_FILTER =
-
-# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern
-# basis. Doxygen will compare the file name with each pattern and apply the
-# filter if there is a match. The filters are a list of the form: pattern=filter
-# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how
-# filters are used. If the FILTER_PATTERNS tag is empty or if none of the
-# patterns match the file name, INPUT_FILTER is applied.
-#
-# Note that for custom extensions or not directly supported extensions you also
-# need to set EXTENSION_MAPPING for the extension otherwise the files are not
-# properly processed by doxygen.
-
-FILTER_PATTERNS =
-
-# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
-# INPUT_FILTER) will also be used to filter the input files that are used for
-# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES).
-# The default value is: NO.
-
-FILTER_SOURCE_FILES = NO
-
-# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file
-# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and
-# it is also possible to disable source filtering for a specific pattern using
-# *.ext= (so without naming a filter).
-# This tag requires that the tag FILTER_SOURCE_FILES is set to YES.
-
-FILTER_SOURCE_PATTERNS =
-
-# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that
-# is part of the input, its contents will be placed on the main page
-# (index.html). This can be useful if you have a project hosted on, for
-# instance, GitHub and want to reuse the introduction page also for the
-# doxygen output.
-
-USE_MDFILE_AS_MAINPAGE =
-
-#---------------------------------------------------------------------------
-# Configuration options related to source browsing
-#---------------------------------------------------------------------------
-
-# If the SOURCE_BROWSER tag is set to YES then a list of source files will be
-# generated. Documented entities will be cross-referenced with these sources.
-#
-# Note: To get rid of all source code in the generated output, make sure that
-# also VERBATIM_HEADERS is set to NO.
-# The default value is: NO.
-
-SOURCE_BROWSER = NO
-
-# Setting the INLINE_SOURCES tag to YES will include the body of functions,
-# classes and enums directly into the documentation.
-# The default value is: NO.
-
-INLINE_SOURCES = NO
-
-# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any
-# special comment blocks from generated source code fragments. Normal C, C++ and
-# Fortran comments will always remain visible.
-# The default value is: YES.
-
-STRIP_CODE_COMMENTS = YES
-
-# If the REFERENCED_BY_RELATION tag is set to YES then for each documented
-# function all documented functions referencing it will be listed.
-# The default value is: NO.
-
-REFERENCED_BY_RELATION = NO
-
-# If the REFERENCES_RELATION tag is set to YES then for each documented function
-# all documented entities called/used by that function will be listed.
-# The default value is: NO.
-
-REFERENCES_RELATION = NO
-
-# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set
-# to YES then the hyperlinks from functions in REFERENCES_RELATION and
-# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will
-# link to the documentation.
-# The default value is: YES.
-
-REFERENCES_LINK_SOURCE = YES
-
-# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the
-# source code will show a tooltip with additional information such as prototype,
-# brief description and links to the definition and documentation. Since this
-# will make the HTML file larger and loading of large files a bit slower, you
-# can opt to disable this feature.
-# The default value is: YES.
-# This tag requires that the tag SOURCE_BROWSER is set to YES.
-
-SOURCE_TOOLTIPS = YES
-
-# If the USE_HTAGS tag is set to YES then the references to source code will
-# point to the HTML generated by the htags(1) tool instead of doxygen built-in
-# source browser. The htags tool is part of GNU's global source tagging system
-# (see http://www.gnu.org/software/global/global.html). You will need version
-# 4.8.6 or higher.
-#
-# To use it do the following:
-# - Install the latest version of global
-# - Enable SOURCE_BROWSER and USE_HTAGS in the config file
-# - Make sure the INPUT points to the root of the source tree
-# - Run doxygen as normal
-#
-# Doxygen will invoke htags (and that will in turn invoke gtags), so these
-# tools must be available from the command line (i.e. in the search path).
-#
-# The result: instead of the source browser generated by doxygen, the links to
-# source code will now point to the output of htags.
-# The default value is: NO.
-# This tag requires that the tag SOURCE_BROWSER is set to YES.
-
-USE_HTAGS = NO
-
-# If the VERBATIM_HEADERS tag is set to YES then doxygen will generate a
-# verbatim copy of the header file for each class for which an include is
-# specified. Set to NO to disable this.
-# See also: Section \class.
-# The default value is: YES.
-
-VERBATIM_HEADERS = YES
-
-#---------------------------------------------------------------------------
-# Configuration options related to the alphabetical class index
-#---------------------------------------------------------------------------
-
-# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all
-# compounds will be generated. Enable this if the project contains a lot of
-# classes, structs, unions or interfaces.
-# The default value is: YES.
-
-ALPHABETICAL_INDEX = YES
-
-# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in
-# which the alphabetical index list will be split.
-# Minimum value: 1, maximum value: 20, default value: 5.
-# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
-
-COLS_IN_ALPHA_INDEX = 5
-
-# In case all classes in a project start with a common prefix, all classes will
-# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag
-# can be used to specify a prefix (or a list of prefixes) that should be ignored
-# while generating the index headers.
-# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
-
-IGNORE_PREFIX =
-
-#---------------------------------------------------------------------------
-# Configuration options related to the HTML output
-#---------------------------------------------------------------------------
-
-# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output.
-# The default value is: YES.
-
-GENERATE_HTML = YES
-
-# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a
-# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
-# it.
-# The default directory is: html.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_OUTPUT = html
-
-# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
-# generated HTML page (for example: .htm, .php, .asp).
-# The default value is: .html.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_FILE_EXTENSION = .html
-
-# The HTML_HEADER tag can be used to specify a user-defined HTML header file for
-# each generated HTML page. If the tag is left blank doxygen will generate a
-# standard header.
-#
-# To get valid HTML, the header file must include any scripts and style sheets
-# that doxygen needs, which depend on the configuration options used (e.g. the
-# setting GENERATE_TREEVIEW). It is highly recommended to start with a
-# default header using
-# doxygen -w html new_header.html new_footer.html new_stylesheet.css
-# YourConfigFile
-# and then modify the file new_header.html. See also section "Doxygen usage"
-# for information on how to generate the default header that doxygen normally
-# uses.
-# Note: The header is subject to change so you typically have to regenerate the
-# default header when upgrading to a newer version of doxygen. For a description
-# of the possible markers and block names see the documentation.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_HEADER =
-
-# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each
-# generated HTML page. If the tag is left blank doxygen will generate a standard
-# footer. See HTML_HEADER for more information on how to generate a default
-# footer and what special commands can be used inside the footer. See also
-# section "Doxygen usage" for information on how to generate the default footer
-# that doxygen normally uses.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-HTML_FOOTER =
-
-# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style
-# sheet that is used by each HTML page. It can be used to fine-tune the look of
-# the HTML output. If left blank doxygen will generate a default style sheet.
-# See also section "Doxygen usage" for information on how to generate the style
-# sheet that doxygen normally uses.
-# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as
-# it is more robust and this tag (HTML_STYLESHEET) will in the future become
-# obsolete.
-# This tag requires that the tag GENERATE_HTML is set to YES.
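-#
-# For example, the recommended alternative would be (an illustrative file name,
-# not part of this configuration):
-#
-#   HTML_EXTRA_STYLESHEET = custom.css
-#
-# with custom.css placed next to this Doxyfile, overriding only the style rules
-# it redefines.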
- -HTML_STYLESHEET = - -# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined -# cascading style sheets that are included after the standard style sheets -# created by doxygen. Using this option one can overrule certain style aspects. -# This is preferred over using HTML_STYLESHEET since it does not replace the -# standard style sheet and is therefore more robust against future updates. -# Doxygen will copy the style sheet files to the output directory. -# Note: The order of the extra style sheet files is of importance (e.g. the last -# style sheet in the list overrules the setting of the previous ones in the -# list). For an example see the documentation. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_EXTRA_STYLESHEET = - -# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or -# other source files which should be copied to the HTML output directory. Note -# that these files will be copied to the base HTML output directory. Use the -# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these -# files. In the HTML_STYLESHEET file, use the file name only. Also note that the -# files will be copied as-is; there are no commands or markers available. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_EXTRA_FILES = - -# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen -# will adjust the colors in the style sheet and background images according to -# this color. Hue is specified as an angle on a colorwheel, see -# http://en.wikipedia.org/wiki/Hue for more information. For instance the value -# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 -# purple, and 360 is red again. -# Minimum value: 0, maximum value: 359, default value: 220. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE_HUE = 220 - -# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors -# in the HTML output. For a value of 0 the output will use grayscales only. A -# value of 255 will produce the most vivid colors. -# Minimum value: 0, maximum value: 255, default value: 100. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE_SAT = 100 - -# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the -# luminance component of the colors in the HTML output. Values below 100 -# gradually make the output lighter, whereas values above 100 make the output -# darker. The value divided by 100 is the actual gamma applied, so 80 represents -# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not -# change the gamma. -# Minimum value: 40, maximum value: 240, default value: 80. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_COLORSTYLE_GAMMA = 80 - -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_TIMESTAMP = NO - -# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML -# documentation will contain sections that can be hidden and shown after the -# page has loaded. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. 
- -HTML_DYNAMIC_SECTIONS = NO - -# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries -# shown in the various tree structured indices initially; the user can expand -# and collapse entries dynamically later on. Doxygen will expand the tree to -# such a level that at most the specified number of entries are visible (unless -# a fully collapsed tree already exceeds this amount). So setting the number of -# entries 1 will produce a full collapsed tree by default. 0 is a special value -# representing an infinite number of entries and will result in a full expanded -# tree by default. -# Minimum value: 0, maximum value: 9999, default value: 100. -# This tag requires that the tag GENERATE_HTML is set to YES. - -HTML_INDEX_NUM_ENTRIES = 100 - -# If the GENERATE_DOCSET tag is set to YES, additional index files will be -# generated that can be used as input for Apple's Xcode 3 integrated development -# environment (see: http://developer.apple.com/tools/xcode/), introduced with -# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a -# Makefile in the HTML output directory. Running make will produce the docset in -# that directory and running make install will install the docset in -# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at -# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html -# for more information. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_DOCSET = NO - -# This tag determines the name of the docset feed. A documentation feed provides -# an umbrella under which multiple documentation sets from a single provider -# (such as a company or product suite) can be grouped. -# The default value is: Doxygen generated docs. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_FEEDNAME = "Doxygen generated docs" - -# This tag specifies a string that should uniquely identify the documentation -# set bundle. This should be a reverse domain-name style string, e.g. -# com.mycompany.MyDocSet. Doxygen will append .docset to the name. -# The default value is: org.doxygen.Project. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_BUNDLE_ID = org.doxygen.Project - -# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify -# the documentation publisher. This should be a reverse domain-name style -# string, e.g. com.mycompany.MyDocSet.documentation. -# The default value is: org.doxygen.Publisher. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_PUBLISHER_ID = org.doxygen.Publisher - -# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. -# The default value is: Publisher. -# This tag requires that the tag GENERATE_DOCSET is set to YES. - -DOCSET_PUBLISHER_NAME = Publisher - -# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three -# additional HTML index files: index.hhp, index.hhc, and index.hhk. The -# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop -# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on -# Windows. -# -# The HTML Help Workshop contains a compiler that can convert all HTML output -# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML -# files are now used as the Windows 98 help format, and will replace the old -# Windows help format (.hlp) on all Windows platforms in the future. 
Compressed -# HTML files also contain an index, a table of contents, and you can search for -# words in the documentation. The HTML workshop also contains a viewer for -# compressed HTML files. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_HTMLHELP = NO - -# The CHM_FILE tag can be used to specify the file name of the resulting .chm -# file. You can add a path in front of the file if the result should not be -# written to the html output directory. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -CHM_FILE = - -# The HHC_LOCATION tag can be used to specify the location (absolute path -# including file name) of the HTML help compiler (hhc.exe). If non-empty, -# doxygen will try to run the HTML help compiler on the generated index.hhp. -# The file has to be specified with full path. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -HHC_LOCATION = - -# The GENERATE_CHI flag controls if a separate .chi index file is generated -# (YES) or that it should be included in the master .chm file (NO). -# The default value is: NO. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -GENERATE_CHI = NO - -# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) -# and project file content. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -CHM_INDEX_ENCODING = - -# The BINARY_TOC flag controls whether a binary table of contents is generated -# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it -# enables the Previous and Next buttons. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -BINARY_TOC = NO - -# The TOC_EXPAND flag can be set to YES to add extra items for group members to -# the table of contents of the HTML help documentation and to the tree view. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTMLHELP is set to YES. - -TOC_EXPAND = NO - -# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and -# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that -# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help -# (.qch) of the generated HTML documentation. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_QHP = NO - -# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify -# the file name of the resulting .qch file. The path specified is relative to -# the HTML output folder. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QCH_FILE = - -# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help -# Project output. For more information please see Qt Help Project / Namespace -# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace). -# The default value is: org.doxygen.Project. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_NAMESPACE = org.doxygen.Project - -# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt -# Help Project output. For more information please see Qt Help Project / Virtual -# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual- -# folders). -# The default value is: doc. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_VIRTUAL_FOLDER = doc - -# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom -# filter to add. 
For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_CUST_FILTER_NAME = - -# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the -# custom filter to add. For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_CUST_FILTER_ATTRS = - -# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this -# project's filter section matches. Qt Help Project / Filter Attributes (see: -# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes). -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHP_SECT_FILTER_ATTRS = - -# The QHG_LOCATION tag can be used to specify the location of Qt's -# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the -# generated .qhp file. -# This tag requires that the tag GENERATE_QHP is set to YES. - -QHG_LOCATION = - -# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be -# generated, together with the HTML files, they form an Eclipse help plugin. To -# install this plugin and make it available under the help contents menu in -# Eclipse, the contents of the directory containing the HTML and XML files needs -# to be copied into the plugins directory of eclipse. The name of the directory -# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. -# After copying Eclipse needs to be restarted before the help appears. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -GENERATE_ECLIPSEHELP = NO - -# A unique identifier for the Eclipse help plugin. When installing the plugin -# the directory name containing the HTML and XML files should also have this -# name. Each documentation set should have its own identifier. -# The default value is: org.doxygen.Project. -# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. - -ECLIPSE_DOC_ID = org.doxygen.Project - -# If you want full control over the layout of the generated HTML pages it might -# be necessary to disable the index and replace it with your own. The -# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top -# of each HTML page. A value of NO enables the index and the value YES disables -# it. Since the tabs in the index contain the same information as the navigation -# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. -# The default value is: NO. -# This tag requires that the tag GENERATE_HTML is set to YES. - -DISABLE_INDEX = NO - -# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index -# structure should be generated to display hierarchical information. If the tag -# value is set to YES, a side panel will be generated containing a tree-like -# index structure (just like the one that is generated for HTML Help). For this -# to work a browser that supports JavaScript, DHTML, CSS and frames is required -# (i.e. any modern browser). Windows users are probably better off using the -# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can -# further fine-tune the look of the index. 
As an example, the default style
-# sheet generated by doxygen has an example that shows how to put an image at
-# the root of the tree instead of the PROJECT_NAME. Since the tree basically has
-# the same information as the tab index, you could consider setting
-# DISABLE_INDEX to YES when enabling this option.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-GENERATE_TREEVIEW = NO
-
-# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that
-# doxygen will group on one line in the generated HTML documentation.
-#
-# Note that a value of 0 will completely suppress the enum values from appearing
-# in the overview section.
-# Minimum value: 0, maximum value: 20, default value: 4.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-ENUM_VALUES_PER_LINE = 4
-
-# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used
-# to set the initial width (in pixels) of the frame in which the tree is shown.
-# Minimum value: 0, maximum value: 1500, default value: 250.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-TREEVIEW_WIDTH = 250
-
-# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to
-# external symbols imported via tag files in a separate window.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-EXT_LINKS_IN_WINDOW = NO
-
-# Use this tag to change the font size of LaTeX formulas included as images in
-# the HTML documentation. When you change the font size after a successful
-# doxygen run you need to manually remove any form_*.png images from the HTML
-# output directory to force them to be regenerated.
-# Minimum value: 8, maximum value: 50, default value: 10.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-FORMULA_FONTSIZE = 10
-
-# Use the FORMULA_TRANSPARENT tag to determine whether or not the images
-# generated for formulas are transparent PNGs. Transparent PNGs are not
-# supported properly for IE 6.0, but are supported on all modern browsers.
-#
-# Note that when changing this option you need to delete any form_*.png files in
-# the HTML output directory before the changes have effect.
-# The default value is: YES.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-FORMULA_TRANSPARENT = YES
-
-# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
-# http://www.mathjax.org) which uses client side Javascript for the rendering
-# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX
-# installed or if you want the formulas to look prettier in the HTML output. When
-# enabled you may also need to install MathJax separately and configure the path
-# to it using the MATHJAX_RELPATH option.
-# The default value is: NO.
-# This tag requires that the tag GENERATE_HTML is set to YES.
-
-USE_MATHJAX = NO
-
-# When MathJax is enabled you can set the default output format to be used for
-# the MathJax output. See the MathJax site (see:
-# http://docs.mathjax.org/en/latest/output.html) for more details.
-# Possible values are: HTML-CSS (which is slower, but has the best
-# compatibility), NativeMML (i.e. MathML) and SVG.
-# The default value is: HTML-CSS.
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_FORMAT = HTML-CSS
-
-# When MathJax is enabled you need to specify the location relative to the HTML
-# output directory using the MATHJAX_RELPATH option. The destination directory
-# should contain the MathJax.js script. For instance, if the mathjax directory
-# is located at the same level as the HTML output directory, then
-# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
-# Content Delivery Network so you can quickly see the result without installing
-# MathJax. However, it is strongly recommended to install a local copy of
-# MathJax from http://www.mathjax.org before deployment.
-# The default value is: http://cdn.mathjax.org/mathjax/latest.
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest
-
-# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
-# extension names that should be enabled during MathJax rendering. For example
-# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_EXTENSIONS =
-
-# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
-# of code that will be used on startup of the MathJax code. See the MathJax site
-# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an
-# example see the documentation.
-# This tag requires that the tag USE_MATHJAX is set to YES.
-
-MATHJAX_CODEFILE =
-
-# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
-# the HTML output. The underlying search engine uses javascript and DHTML and
-# should work on any modern browser. Note that when using HTML help
-# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
-# there is already a search function so this one should typically be disabled.
-# For large projects the javascript based search engine can be slow, then
-# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to
-# search using the keyboard; to jump to the search box use <access key> + S
-# (what the <access key> is depends on the OS and browser, but it is typically
-# <CTRL>, <ALT>/