{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "rA5Mubike7OJ" }, "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "fY0a3LRYfHUl" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "iNz7xXMSsAQa" }, "source": [ "# Parameter server training with ParameterServerStrategy" ] }, { "cell_type": "markdown", "metadata": { "id": "jHyqRIqxsJuc" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View source on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "6v4D6QfcfTrm" }, "source": [ "## Overview\n", "\n", "[Parameter server training](https://www.usenix.org/system/files/conference/osdi14/osdi14-paper-li_mu.pdf) is a common data-parallel method to scale up model training on multiple machines.\n", "\n", "A parameter server training cluster consists of _workers_ and _parameter servers_. Variables are created on parameter servers and they are read and updated by workers in each step. By default, workers read and update these variables independently without synchronizing with each other. This is why sometimes parameter server-style training is called _asynchronous training_.\n", "\n", "In TensorFlow 2, parameter server training is powered by the `tf.distribute.ParameterServerStrategy` class, which distributes the training steps to a cluster that scales up to thousands of workers (accompanied by parameter servers)." ] }, { "cell_type": "markdown", "metadata": { "id": "W1LGfTdgOF-J" }, "source": [ "### Supported training methods\n", "\n", "There are two main supported training methods:\n", "\n", "- The Keras `Model.fit` API: if you prefer a high-level abstraction and handling of training. This is generally recommended if you are training a `tf.keras.Model`.\n", "- A custom training loop: if you prefer to define the details of your training loop (you can refer to guides on [Custom training](../customization/custom_training_walkthrough.ipynb), [Writing a training loop from scratch\n", "](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch) and [Custom training loop with Keras and MultiWorkerMirroredStrategy](multi_worker_with_ctl.ipynb) for more details)." ] }, { "cell_type": "markdown", "metadata": { "id": "FjbULGvV7NRz" }, "source": [ "### A cluster with jobs and tasks\n", "\n", "Regardless of the API of choice (`Model.fit` or a custom training loop), distributed training in TensorFlow 2 involves a `'cluster'` with several `'jobs'`, and each of the jobs may have one or more `'tasks'`.\n", "\n", "When using parameter server training, it is recommended to have:\n", "\n", "- One _coordinator_ job (which has the job name `chief`)\n", "- Multiple _worker_ jobs (job name `worker`)\n", "- Multiple _parameter server_ jobs (job name `ps`)\n", "\n", "The _coordinator_ creates resources, dispatches training tasks, writes checkpoints, and deals with task failures. The _workers_ and _parameter servers_ run `tf.distribute.Server` instances that listen for requests from the coordinator." ] }, { "cell_type": "markdown", "metadata": { "id": "oLV1FbpLtqtB" }, "source": [ "### Parameter server training with the `Model.fit` API\n", "\n", "Parameter server training with the `Model.fit` API requires the coordinator to use a `tf.distribute.ParameterServerStrategy` object. Similar to `Model.fit` usage with no strategy, or with other strategies, the workflow involves creating and compiling the model, preparing the callbacks, and calling `Model.fit`." 
] }, { "cell_type": "markdown", "metadata": { "id": "yJ5AosxFyfzk" }, "source": [ "### Parameter server training with a custom training loop\n", "\n", "With custom training loops, the `tf.distribute.coordinator.ClusterCoordinator` class is the key component used for the coordinator.\n", "\n", "- The `ClusterCoordinator` class needs to work in conjunction with a `tf.distribute.ParameterServerStrategy` object.\n", "- This `tf.distribute.Strategy` object is needed to provide the information of the cluster and is used to define a training step, as demonstrated in [Custom training with tf.distribute.Strategy](custom_training.ipynb).\n", "- The `ClusterCoordinator` object then dispatches the execution of these training steps to remote workers.\n", "\n", "The most important API provided by the `ClusterCoordinator` object is `schedule`:\n", "\n", "- The `schedule` API enqueues a `tf.function` and returns a future-like `RemoteValue` immediately.\n", "- The queued functions will be dispatched to remote workers in background threads and their `RemoteValue`s will be filled asynchronously.\n", "- Since `schedule` doesn’t require worker assignment, the `tf.function` passed in can be executed on any available worker.\n", "- If the worker it is executed on becomes unavailable before its completion, the function will be retried on another available worker.\n", "- Because of this fact and the fact that function execution is not atomic, a single function call may be executed more than once.\n", "\n", "In addition to dispatching remote functions, the `ClusterCoordinator` also helps\n", "to create datasets on all the workers and rebuild these datasets when a worker recovers from failure." ] }, { "cell_type": "markdown", "metadata": { "id": "MyDnWjmOje5-" }, "source": [ "## Tutorial setup\n", "\n", "The tutorial will branch into `Model.fit` and custom training loop paths, and you can choose the one that fits your needs. Sections other than \"Training with X\" are applicable to both paths." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0-V3LUcIs4a-" }, "outputs": [], "source": [ "!pip install portpicker" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GlI_NAVFae3J" }, "outputs": [], "source": [ "#@title\n", "import multiprocessing\n", "import os\n", "import random\n", "import portpicker\n", "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": { "id": "uvwgM2rzgzIC" }, "source": [ "## Cluster setup\n", "\n", "As mentioned above, a parameter server training cluster requires a coordinator task that runs your training program, one or several workers and parameter server tasks that run TensorFlow servers—`tf.distribute.Server`—and possibly an additional evaluation task that runs sidecar evaluation (refer to the [sidecar evaluation section](#sidecar_evaluation) below). The requirements to set them up are:\n", "\n", "- The coordinator task needs to know the addresses and ports of all other TensorFlow servers, except the evaluator.\n", "- The workers and parameter servers need to know which port they need to listen to. For the sake of simplicity, you can usually pass in the complete cluster information when creating TensorFlow servers on these tasks.\n", "- The evaluator task doesn’t have to know the setup of the training cluster. If it does, it should not attempt to connect to the training cluster.\n", "- Workers and parameter servers should have task types as `\"worker\"` and `\"ps\"`, respectively. 
The coordinator should use `\"chief\"` as the task type for legacy reasons.\n", "\n", "In this tutorial, you will create an in-process cluster so that the whole parameter server training can be run in Colab. You will learn how to set up [real clusters](#real_clusters) in a later section." ] }, { "cell_type": "markdown", "metadata": { "id": "7UNs7Lm2g19n" }, "source": [ "### In-process cluster\n", "\n", "You will start by creating several TensorFlow servers in advance and you will connect to them later. Note that this is only for the purpose of this tutorial's demonstration, and in real training the servers will be started on `\"worker\"` and `\"ps\"` machines." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FbrP5pXuaoVH" }, "outputs": [], "source": [ "def create_in_process_cluster(num_workers, num_ps):\n", " \"\"\"Creates and starts local servers and returns the cluster_resolver.\"\"\"\n", " worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]\n", " ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]\n", "\n", " cluster_dict = {}\n", " cluster_dict[\"worker\"] = [\"localhost:%s\" % port for port in worker_ports]\n", " if num_ps > 0:\n", " cluster_dict[\"ps\"] = [\"localhost:%s\" % port for port in ps_ports]\n", "\n", " cluster_spec = tf.train.ClusterSpec(cluster_dict)\n", "\n", " # Workers need some inter_ops threads to work properly.\n", " worker_config = tf.compat.v1.ConfigProto()\n", " if multiprocessing.cpu_count() < num_workers + 1:\n", " worker_config.inter_op_parallelism_threads = num_workers + 1\n", "\n", " for i in range(num_workers):\n", " tf.distribute.Server(\n", " cluster_spec,\n", " job_name=\"worker\",\n", " task_index=i,\n", " config=worker_config,\n", " protocol=\"grpc\")\n", "\n", " for i in range(num_ps):\n", " tf.distribute.Server(\n", " cluster_spec,\n", " job_name=\"ps\",\n", " task_index=i,\n", " protocol=\"grpc\")\n", "\n", " cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(\n", " cluster_spec, rpc_layer=\"grpc\")\n", " return cluster_resolver\n", "\n", "# Set the environment variable to allow reporting worker and ps failure to the\n", "# coordinator. This is a workaround and won't be necessary in the future.\n", "os.environ[\"GRPC_FAIL_FAST\"] = \"use_caller\"\n", "\n", "NUM_WORKERS = 3\n", "NUM_PS = 2\n", "cluster_resolver = create_in_process_cluster(NUM_WORKERS, NUM_PS)" ] }, { "cell_type": "markdown", "metadata": { "id": "pX_91OByt0J2" }, "source": [ "The in-process cluster setup is frequently used in unit testing, such as [here](https://github.com/tensorflow/tensorflow/blob/eb4c40fc91da260199fa2aed6fe67d36ad49fafd/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py#L447).\n", "\n", "Another option for local testing is to launch processes on the local machine—check out [Multi-worker training with Keras](multi_worker_with_keras.ipynb) for an example of this approach." ] }, { "cell_type": "markdown", "metadata": { "id": "zyby6M2Jqg6J" }, "source": [ "## Instantiate a ParameterServerStrategy\n", "\n", "Before you dive into the training code, let's instantiate a `tf.distribute.ParameterServerStrategy` object. Note that this is needed regardless of whether you are proceeding with `Model.fit` or a custom training loop. The `variable_partitioner` argument will be explained in the [Variable sharding section](#variable_sharding)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_YyEPgisrC35" }, "outputs": [], "source": [ "variable_partitioner = (\n", " tf.distribute.experimental.partitioners.MinSizePartitioner(\n", " min_shard_bytes=(256 << 10),\n", " max_shards=NUM_PS))\n", "\n", "strategy = tf.distribute.ParameterServerStrategy(\n", " cluster_resolver,\n", " variable_partitioner=variable_partitioner)" ] }, { "cell_type": "markdown", "metadata": { "id": "WlAQxuMDJ3k9" }, "source": [ "In order to use GPUs for training, allocate GPUs visible to each worker. `ParameterServerStrategy` will use all the available GPUs on each worker, with the restriction that all workers should have the same number of GPUs available." ] }, { "cell_type": "markdown", "metadata": { "id": "QMmBLsf6sEXh" }, "source": [ "### Variable sharding\n", "\n", "Variable sharding refers to splitting a variable into multiple smaller\n", "variables, which are called _shards_. Variable sharding may be useful to distribute the network load when accessing these shards. It is also useful to distribute computation and storage of a normal variable across multiple parameter servers, for example, when using very large embeddings\n", "that may not fit in a single machine's memory.\n", "\n", "To enable variable sharding, you can pass in a `variable_partitioner` when\n", "constructing a `ParameterServerStrategy` object. The `variable_partitioner` will\n", "be invoked every time when a variable is created and it is expected to return\n", "the number of shards along each dimension of the variable. Some out-of-box\n", "`variable_partitioner`s are provided such as\n", "`tf.distribute.experimental.partitioners.MinSizePartitioner`. It is recommended to use size-based partitioners like\n", "`tf.distribute.experimental.partitioners.MinSizePartitioner` to avoid\n", "partitioning small variables, which could have a negative impact on model training\n", "speed." ] }, { "cell_type": "markdown", "metadata": { "id": "1--SxlxtsOb7" }, "source": [ "When a `variable_partitioner` is passed in, and you create a variable directly\n", "under `Strategy.scope`, the variable will become a container type with a `variables`\n", "property, which provides access to the list of shards. In most cases, this\n", "container will be automatically converted to a Tensor by concatenating all the\n", "shards. As a result, it can be used as a normal variable. On the other hand,\n", "some TensorFlow methods such as `tf.nn.embedding_lookup` provide efficient\n", "implementation for this container type and in these methods automatic\n", "concatenation will be avoided.\n", "\n", "Refer to the API docs of `tf.distribute.ParameterServerStrategy` for more details." ] }, { "cell_type": "markdown", "metadata": { "id": "jlOq-O-26O1d" }, "source": [ "## Training with `Model.fit`\n", "\n", "\n", "Keras provides an easy-to-use training API via `Model.fit` that handles the training loop under the hood, with the flexibility of an overridable `train_step`, and callbacks which provide functionalities such as checkpoint saving or summary saving for TensorBoard. With `Model.fit`, the same training code can be used with other strategies with a simple swap of the strategy object." 
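,
 "\n",
 "\n",
 "In outline, the `Model.fit` path with this strategy looks roughly like the following sketch. The `dataset`, `model`, and `callbacks` objects referenced here are built step by step in the next few sections:\n",
 "\n",
 "```python\n",
 "with strategy.scope():\n",
 "  model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])\n",
 "  model.compile(tf.keras.optimizers.legacy.SGD(), loss=\"mse\", steps_per_execution=10)\n",
 "\n",
 "model.fit(dataset, epochs=5, steps_per_epoch=20, callbacks=callbacks)\n",
 "```"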
] }, { "cell_type": "markdown", "metadata": { "id": "oMZ9Cu5J6ZGi" }, "source": [ "### Input data\n", "\n", "Keras `Model.fit` with `tf.distribute.ParameterServerStrategy` can take input data in the form of a `tf.data.Dataset`, `tf.distribute.DistributedDataset`, or a `tf.keras.utils.experimental.DatasetCreator`, with `Dataset` being the recommended option for ease of use. If you encounter memory issues using `Dataset`, however, you may need to use `DatasetCreator` with a callable `dataset_fn` argument (refer to the `tf.keras.utils.experimental.DatasetCreator` API documentation for details).\n", "\n", "If you transform your dataset into a `tf.data.Dataset`, you should use `Dataset.shuffle` and `Dataset.repeat`, as demonstrated in the code example below.\n", "\n", "- Keras `Model.fit` with parameter server training assumes that each worker receives the same dataset, except when it is shuffled differently. Therefore, by calling `Dataset.shuffle`, you ensure more even iterations over the data.\n", "- Because workers do not synchronize, they may finish processing their datasets at different times. Therefore, the easiest way to define epochs with parameter server training is to use `Dataset.repeat`—which repeats a dataset indefinitely when called without an argument—and specify the `steps_per_epoch` argument in the `Model.fit` call.\n", "\n", "Refer to the \"Training workflows\" section of the [tf.data guide](../../guide/data.ipynb) for more details on `shuffle` and `repeat`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "shAo1CCS7wU1" }, "outputs": [], "source": [ "global_batch_size = 64\n", "\n", "x = tf.random.uniform((10, 10))\n", "y = tf.random.uniform((10,))\n", "\n", "dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()\n", "dataset = dataset.batch(global_batch_size)\n", "dataset = dataset.prefetch(2)" ] }, { "cell_type": "markdown", "metadata": { "id": "v_jhF70K7zON" }, "source": [ "If you instead create your dataset with `tf.keras.utils.experimental.DatasetCreator`, the code in `dataset_fn` will be invoked on the input device, which is usually the CPU, on each of the worker machines.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "w60PuWrWwBD4" }, "source": [ "### Model construction and compiling\n", "\n", "Now, you will create a `tf.keras.Model`—a trivial `tf.keras.models.Sequential` model for demonstration purposes—followed by a `Model.compile` call to incorporate components, such as an optimizer, metrics, and other parameters such as `steps_per_execution`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PhTHUYaD74vT" }, "outputs": [], "source": [ "with strategy.scope():\n", " model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])\n", "\n", " model.compile(tf.keras.optimizers.legacy.SGD(), loss=\"mse\", steps_per_execution=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "nWb_Ekm377YX" }, "source": [ "### Callbacks and training\n", "\n", " \n", "\n", "Before you call Keras `Model.fit` for the actual training, prepare any needed [callbacks](https://www.tensorflow.org/guide/keras/train_and_evaluate) for common tasks, such as:\n", "\n", "- `tf.keras.callbacks.ModelCheckpoint`: saves the model at a certain frequency, such as after every epoch.\n", "- `tf.keras.callbacks.BackupAndRestore`: provides fault tolerance by backing up the model and current epoch number, if the cluster experiences unavailability (such as abort or preemption). 
You can then restore the training state upon a restart from a job failure, and continue training from the beginning of the interrupted epoch.\n", "- `tf.keras.callbacks.TensorBoard`: periodically writes model logs in summary files that can be visualized in the TensorBoard tool.\n", "\n", "Note: Due to performance considerations, custom callbacks cannot have batch level callbacks overridden when used with `ParameterServerStrategy`. Please modify your custom callbacks to make them epoch level calls, and adjust `steps_per_epoch` to a suitable value. In addition, `steps_per_epoch` is a required argument for `Model.fit` when used with `ParameterServerStrategy`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3ddUvUZk7_wm" }, "outputs": [], "source": [ "working_dir = \"/tmp/my_working_dir\"\n", "log_dir = os.path.join(working_dir, \"log\")\n", "ckpt_filepath = os.path.join(working_dir, \"ckpt\")\n", "backup_dir = os.path.join(working_dir, \"backup\")\n", "\n", "callbacks = [\n", " tf.keras.callbacks.TensorBoard(log_dir=log_dir),\n", " tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_filepath),\n", " tf.keras.callbacks.BackupAndRestore(backup_dir=backup_dir),\n", "]\n", "\n", "model.fit(dataset, epochs=5, steps_per_epoch=20, callbacks=callbacks)" ] }, { "cell_type": "markdown", "metadata": { "id": "uWgP1h2z8B3j" }, "source": [ "### Direct usage with `ClusterCoordinator` (optional)\n", "\n", "Even if you choose the `Model.fit` training path, you can optionally instantiate a `tf.distribute.coordinator.ClusterCoordinator` object to schedule other functions you would like to be executed on the workers. Refer to the [Training with a custom training loop](#training_with_custom_training_loop) section for more details and examples." ] }, { "cell_type": "markdown", "metadata": { "id": "GxypEyIthR0z" }, "source": [ "## Training with a custom training loop\n", "\n", " \n", "\n", "Using custom training loops with `tf.distribute.Strategy` provides great flexibility to define training loops. With the `ParameterServerStrategy` defined above (as `strategy`), you will use a `tf.distribute.coordinator.ClusterCoordinator` to dispatch the execution of training steps to remote workers.\n", "\n", "Then, you will create a model, define a dataset, and define a step function, as you have done in the training loop with other `tf.distribute.Strategy`s. You can find more details in the [Custom training with tf.distribute.Strategy](custom_training.ipynb) tutorial.\n", "\n", "To ensure efficient dataset prefetching, use the recommended distributed dataset creation APIs mentioned in the [Dispatch training steps to remote workers](#dispatch_training_steps_to_remote_workers) section below. Also, make sure to call `Strategy.run` inside `worker_fn` to take full advantage of GPUs allocated to workers. The rest of the steps are the same for training with or without GPUs.\n", "\n", "Let’s create these components in the following steps:\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4QNkCtV8VivM" }, "source": [ "### Set up the data\n", "\n", "First, write a function that creates a dataset.\n", "\n", "If you would like to preprocess the data with [Keras preprocessing layers](https://www.tensorflow.org/guide/keras/preprocessing_layers) or [Tensorflow Transform layers](https://www.tensorflow.org/tfx/tutorials/transform/simple), create these layers **outside the `dataset_fn`** and **under `Strategy.scope`**, like you would do for any other Keras layers. 
This is because the `dataset_fn` will be wrapped into a `tf.function` and then executed on each worker to generate the data pipeline.\n", "\n", "If you don't follow the above procedure, creating the layers might create Tensorflow states which will be lifted out of the `tf.function` to the coordinator. Thus, accessing them on workers would incur repetitive RPC calls between coordinator and workers, and cause significant slowdown.\n", "\n", "Placing the layers under `Strategy.scope` will instead create them on all workers. Then, you will apply the transformation inside the `dataset_fn` via `tf.data.Dataset.map`. Refer to _Data preprocessing_ in the [Distributed input](input.ipynb) tutorial for more information on data preprocessing with distributed input." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2GUwATssauus" }, "outputs": [], "source": [ "feature_vocab = [\n", " \"avenger\", \"ironman\", \"batman\", \"hulk\", \"spiderman\", \"kingkong\", \"wonder_woman\"\n", "]\n", "label_vocab = [\"yes\", \"no\"]\n", "\n", "with strategy.scope():\n", " feature_lookup_layer = tf.keras.layers.StringLookup(\n", " vocabulary=feature_vocab,\n", " mask_token=None)\n", " label_lookup_layer = tf.keras.layers.StringLookup(\n", " vocabulary=label_vocab,\n", " num_oov_indices=0,\n", " mask_token=None)\n", "\n", " raw_feature_input = tf.keras.layers.Input(\n", " shape=(3,),\n", " dtype=tf.string,\n", " name=\"feature\")\n", " feature_id_input = feature_lookup_layer(raw_feature_input)\n", " feature_preprocess_stage = tf.keras.Model(\n", " {\"features\": raw_feature_input},\n", " feature_id_input)\n", "\n", " raw_label_input = tf.keras.layers.Input(\n", " shape=(1,),\n", " dtype=tf.string,\n", " name=\"label\")\n", " label_id_input = label_lookup_layer(raw_label_input)\n", "\n", " label_preprocess_stage = tf.keras.Model(\n", " {\"label\": raw_label_input},\n", " label_id_input)" ] }, { "cell_type": "markdown", "metadata": { "id": "Jgp8MX_7OR_A" }, "source": [ "Generate toy examples in a dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "chIY4fFANaFH" }, "outputs": [], "source": [ "def feature_and_label_gen(num_examples=200):\n", " examples = {\"features\": [], \"label\": []}\n", " for _ in range(num_examples):\n", " features = random.sample(feature_vocab, 3)\n", " label = [\"yes\"] if \"avenger\" in features else [\"no\"]\n", " examples[\"features\"].append(features)\n", " examples[\"label\"].append(label)\n", " return examples\n", "\n", "examples = feature_and_label_gen()" ] }, { "cell_type": "markdown", "metadata": { "id": "2AtZBya7OeyZ" }, "source": [ "Then, create the training dataset wrapped in a `dataset_fn`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Gs0QYRZoNbvw" }, "outputs": [], "source": [ "def dataset_fn(_):\n", " raw_dataset = tf.data.Dataset.from_tensor_slices(examples)\n", "\n", " train_dataset = raw_dataset.map(\n", " lambda x: (\n", " {\"features\": feature_preprocess_stage(x[\"features\"])},\n", " label_preprocess_stage(x[\"label\"])\n", " )).shuffle(200).batch(32).repeat()\n", " return train_dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "IT9PQexJiFtB" }, "source": [ "### Build the model\n", "\n", "Next, create the model and other objects. Make sure to create all variables under `Strategy.scope`." 
] }, {
 "cell_type": "code", "execution_count": null, "metadata": { "id": "Quxud1uEazeo" }, "outputs": [], "source": [
 "# These variables created under the `Strategy.scope` will be placed on parameter\n",
 "# servers in a round-robin fashion.\n",
 "with strategy.scope():\n",
 "  # Create the model. The input needs to be compatible with Keras processing layers.\n",
 "  model_input = tf.keras.layers.Input(\n",
 "      shape=(3,), dtype=tf.int64, name=\"model_input\")\n",
 "\n",
 "  emb_layer = tf.keras.layers.Embedding(\n",
 "      input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=16384)\n",
 "  emb_output = tf.reduce_mean(emb_layer(model_input), axis=1)\n",
 "  dense_output = tf.keras.layers.Dense(\n",
 "      units=1, activation=\"sigmoid\",\n",
 "      kernel_regularizer=tf.keras.regularizers.L2(1e-4),\n",
 "  )(emb_output)\n",
 "  model = tf.keras.Model({\"features\": model_input}, dense_output)\n",
 "\n",
 "  optimizer = tf.keras.optimizers.legacy.RMSprop(learning_rate=0.1)\n",
 "  accuracy = tf.keras.metrics.Accuracy()"
 ] }, {
 "cell_type": "markdown", "metadata": { "id": "iyuxiqCQU50m" }, "source": [
 "Let's confirm that the `MinSizePartitioner` defined earlier split all variables into two shards and that each shard was assigned to a different parameter server:"
 ] }, {
 "cell_type": "code", "execution_count": null, "metadata": { "id": "04r1nO4WVDO1" }, "outputs": [], "source": [
 "assert len(emb_layer.weights) == 2\n",
 "assert emb_layer.weights[0].shape == (4, 16384)\n",
 "assert emb_layer.weights[1].shape == (4, 16384)\n",
 "\n",
 "print(emb_layer.weights[0].device)\n",
 "print(emb_layer.weights[1].device)\n"
 ] }, {
 "cell_type": "markdown", "metadata": { "id": "lWhfXZLRiHyM" }, "source": [
 "### Define the training step\n",
 "Third, create the training step wrapped into a `tf.function`:"
 ] }, {
 "cell_type": "code", "execution_count": null, "metadata": { "id": "aNNVo0bFa1K9" }, "outputs": [], "source": [
 "@tf.function\n",
 "def step_fn(iterator):\n",
 "\n",
 "  def replica_fn(batch_data, labels):\n",
 "    with tf.GradientTape() as tape:\n",
 "      pred = model(batch_data, training=True)\n",
 "      per_example_loss = tf.keras.losses.BinaryCrossentropy(\n",
 "          reduction=tf.keras.losses.Reduction.NONE)(labels, pred)\n",
 "      loss = tf.nn.compute_average_loss(per_example_loss)\n",
 "      model_losses = model.losses\n",
 "      if model_losses:\n",
 "        loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))\n",
 "      gradients = tape.gradient(loss, model.trainable_variables)\n",
 "\n",
 "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
 "\n",
 "    actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)\n",
 "    accuracy.update_state(labels, actual_pred)\n",
 "    return loss\n",
 "\n",
 "  batch_data, labels = next(iterator)\n",
 "  losses = strategy.run(replica_fn, args=(batch_data, labels))\n",
 "  return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)"
 ] }, {
 "cell_type": "markdown", "metadata": { "id": "rvrYQUeYiLNy" }, "source": [
 "In the above training step function, calling `Strategy.run` and `Strategy.reduce` in the `step_fn` can support multiple GPUs per worker. If the workers have GPUs allocated, `Strategy.run` will distribute the datasets on multiple replicas (GPUs). Their parallel calls to `tf.nn.compute_average_loss()` compute the average of the loss across the replicas (GPUs) of one worker, independent of the total number of workers."
] }, { "cell_type": "markdown", "metadata": { "id": "GPJ3PV_L2zAY" }, "source": [ "### Dispatch training steps to remote workers\n", " \n", "\n", "After all the computations are defined by `ParameterServerStrategy`, you will use the `tf.distribute.coordinator.ClusterCoordinator` class to create resources and distribute the training steps to remote workers.\n", "\n", "Let’s first create a `ClusterCoordinator` object and pass in the strategy object:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DpcMlH7Pa3DB" }, "outputs": [], "source": [ "coordinator = tf.distribute.coordinator.ClusterCoordinator(strategy)" ] }, { "cell_type": "markdown", "metadata": { "id": "-xRIgKxciOSe" }, "source": [ "Then, create a per-worker dataset and an iterator using the `ClusterCoordinator.create_per_worker_dataset` API, which replicates the dataset to all workers. In the `per_worker_dataset_fn` below, wrapping the `dataset_fn` into `strategy.distribute_datasets_from_function` is recommended to allow efficient prefetching to GPUs seamlessly." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "h9DCvTJTa4Q2" }, "outputs": [], "source": [ "@tf.function\n", "def per_worker_dataset_fn():\n", " return strategy.distribute_datasets_from_function(dataset_fn)\n", "\n", "per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)\n", "per_worker_iterator = iter(per_worker_dataset)" ] }, { "cell_type": "markdown", "metadata": { "id": "i2pnOx78iRwW" }, "source": [ "The final step is to distribute the computation to remote workers using `ClusterCoordinator.schedule`:\n", "\n", "- The `schedule` method enqueues a `tf.function` and returns a future-like `RemoteValue` immediately. The queued functions will be dispatched to remote workers in background threads and the `RemoteValue` will be filled asynchronously.\n", "- The `join` method (`ClusterCoordinator.join`) can be used to wait until all scheduled functions are executed." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gmPvactfa6Eh" }, "outputs": [], "source": [ "num_epochs = 4\n", "steps_per_epoch = 5\n", "for i in range(num_epochs):\n", " accuracy.reset_states()\n", " for _ in range(steps_per_epoch):\n", " coordinator.schedule(step_fn, args=(per_worker_iterator,))\n", " # Wait at epoch boundaries.\n", " coordinator.join()\n", " print(\"Finished epoch %d, accuracy is %f.\" % (i, accuracy.result().numpy()))" ] }, { "cell_type": "markdown", "metadata": { "id": "WBn-gn-OP3DR" }, "source": [ "Here is how you can fetch the result of a `RemoteValue`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-15a2I_lQDO1" }, "outputs": [], "source": [ "loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))\n", "print(\"Final loss is %f\" % loss.fetch())" ] }, { "cell_type": "markdown", "metadata": { "id": "htY4QKc9iXg9" }, "source": [ "Alternatively, you can launch all steps and do something while waiting for\n", "completion:\n", "\n", "```python\n", "for _ in range(total_steps):\n", " coordinator.schedule(step_fn, args=(per_worker_iterator,))\n", "while not coordinator.done():\n", " time.sleep(10)\n", " # Do something like logging metrics or writing checkpoints.\n", "```\n", "\n", "For the complete training and serving workflow for this particular example, please check out this [test](https://github.com/keras-team/keras/blob/master/keras/integration_test/parameter_server_keras_preprocessing_test.py).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "kzNsj2GR3BGs" }, "source": [ "### More about dataset creation\n", "\n", "The dataset in the above code is created using the `ClusterCoordinator.create_per_worker_dataset` API. It creates one dataset per worker and returns a container object. You can call the `iter` method on it to create a per-worker iterator. The per-worker iterator contains one iterator per worker and the corresponding slice of a worker will be substituted in the input argument of the function passed to the `ClusterCoordinator.schedule` method before the function is executed on a particular worker.\n", "\n", "The `ClusterCoordinator.schedule` method assumes workers are equivalent and thus assumes the datasets on different workers are the same (except that they may be shuffled differently). Because of this, it is also recommended to repeat datasets, and schedule a finite number of steps instead of relying on receiving an `OutOfRangeError` from a dataset.\n", "\n", "Another important note is that `tf.data` datasets don’t support implicit serialization and deserialization across task boundaries. So it is important to create the whole dataset inside the function passed to `ClusterCoordinator.create_per_worker_dataset`. The `create_per_worker_dataset` API can also directly take a `tf.data.Dataset` or `tf.distribute.DistributedDataset` as input." ] }, { "cell_type": "markdown", "metadata": { "id": "LcfdI_M83lAM" }, "source": [ "## Evaluation\n", "\n", "The two main approaches to performing evaluation with `tf.distribute.ParameterServerStrategy` training are inline evaluation and sidecar evaluation. Each has its own pros and cons as described below. The inline evaluation method is recommended if you don't have a preference. For users using `Model.fit`, `Model.evaluate` uses inline (distributed) evaluation under the hood." 
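,
 "\n",
 "\n",
 "For example, on the `Model.fit` path, an evaluation call could look roughly like the following sketch, assuming the compiled `model` and the repeated `dataset` from the `Model.fit` section above (the `steps` argument bounds the evaluation because that dataset repeats indefinitely):\n",
 "\n",
 "```python\n",
 "eval_results = model.evaluate(dataset, steps=10, return_dict=True)\n",
 "print(eval_results)\n",
 "```"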
] }, { "cell_type": "markdown", "metadata": { "id": "oiG8EhcY3gA1" }, "source": [ "### Inline evaluation\n", "\n", "In this method, the coordinator alternates between training and evaluation, and thus it is called _inline evaluation_.\n", "\n", "There are several benefits of inline evaluation. For example:\n", "\n", "- It can support large evaluation models and evaluation datasets that a single task cannot hold.\n", "- The evaluation results can be used to make decisions for training the next epoch, for example, whether to stop training early.\n", "\n", "There are two ways to implement inline evaluation: direct evaluation and distributed evaluation.\n", "\n", "- **Direct evaluation**: For small models and evaluation datasets, the coordinator can run evaluation directly on the distributed model with the evaluation dataset on the coordinator:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WakiAakoaHVn" }, "outputs": [], "source": [ "eval_dataset = tf.data.Dataset.from_tensor_slices(\n", " feature_and_label_gen(num_examples=16)).map(\n", " lambda x: (\n", " {\"features\": feature_preprocess_stage(x[\"features\"])},\n", " label_preprocess_stage(x[\"label\"])\n", " )).batch(8)\n", "\n", "eval_accuracy = tf.keras.metrics.Accuracy()\n", "\n", "for batch_data, labels in eval_dataset:\n", " pred = model(batch_data, training=False)\n", " actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)\n", " eval_accuracy.update_state(labels, actual_pred)\n", "\n", "print(\"Evaluation accuracy: %f\" % eval_accuracy.result())" ] }, { "cell_type": "markdown", "metadata": { "id": "MKGHbdI7aGoJ" }, "source": [ "- **Distributed evaluation**: For large models or datasets that are infeasible to run directly on the coordinator, the coordinator task can distribute evaluation tasks to the workers via the `ClusterCoordinator.schedule`/`ClusterCoordinator.join` methods:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XcHNHJpDgEvK" }, "outputs": [], "source": [ "with strategy.scope():\n", " # Define the eval metric on parameter servers.\n", " eval_accuracy = tf.keras.metrics.Accuracy()\n", "\n", "@tf.function\n", "def eval_step(iterator):\n", " def replica_fn(batch_data, labels):\n", " pred = model(batch_data, training=False)\n", " actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)\n", " eval_accuracy.update_state(labels, actual_pred)\n", " batch_data, labels = next(iterator)\n", " strategy.run(replica_fn, args=(batch_data, labels))\n", "\n", "def eval_dataset_fn():\n", " return tf.data.Dataset.from_tensor_slices(\n", " feature_and_label_gen(num_examples=16)).map(\n", " lambda x: (\n", " {\"features\": feature_preprocess_stage(x[\"features\"])},\n", " label_preprocess_stage(x[\"label\"])\n", " )).shuffle(16).repeat().batch(8)\n", "\n", "per_worker_eval_dataset = coordinator.create_per_worker_dataset(eval_dataset_fn)\n", "per_worker_eval_iterator = iter(per_worker_eval_dataset)\n", "\n", "eval_steps_per_epoch = 2\n", "for _ in range(eval_steps_per_epoch):\n", " coordinator.schedule(eval_step, args=(per_worker_eval_iterator,))\n", "coordinator.join()\n", "print(\"Evaluation accuracy: %f\" % eval_accuracy.result())" ] }, { "cell_type": "markdown", "metadata": { "id": "cKrQktZX5z7a" }, "source": [ "#### Enabling exactly-once evaluation\n", "\n", "\n", "The `schedule` and `join` methods of `tf.distribute.coordinator.ClusterCoordinator` don’t support visitation guarantees or exactly-once semantics by default. 
In other words, in the above example there is no guarantee that all evaluation examples in a dataset will be evaluated exactly once; some may not be visited and some may be evaluated multiple times.\n", "\n", "Exactly-once evaluation may be preferred to reduce the variance of evaluation across epochs, and improve model selection done via early stopping, hyperparameter tuning, or other methods. There are different ways to enable exactly-once evaluation:\n", "\n", "- With a `Model.fit/.evaluate` workflow, it can be enabled by adding an argument to `Model.compile`. Refer to docs for the `pss_evaluation_shards` argument.\n", "- The `tf.data` service API can be used to provide exactly-once visitation for evaluation when using `ParameterServerStrategy` (refer to the _Dynamic Sharding_ section of the `tf.data.experimental.service` API documentation).\n", "- [Sidecar evaluation](#sidecar_evaluation) provides exactly-once evaluation by default, since the evaluation happens on a single machine. However this can be much slower than performing evaluation distributed across many workers.\n", "\n", "The first option, using `Model.compile`, is the suggested solution for most users.\n", "\n", "Exactly-once evaluation has some limitations:\n", "\n", "- It is not supported to write a custom distributed evaluation loop with an exactly-once visitation guarantee. File a GitHub issue if you need support for this.\n", "- It cannot automatically handle computation of metrics that use the `Layer.add_metric` API. These should be excluded from evaluation, or reworked into `Metric` objects." ] }, { "cell_type": "markdown", "metadata": { "id": "H40X-9Gs3i7_" }, "source": [ "### Sidecar evaluation\n", "\n", "\n", "Another method for defining and running an evaluation loop in `tf.distribute.ParameterServerStrategy` training is called _sidecar evaluation_, in which you create a dedicated evaluator task that repeatedly reads checkpoints and runs evaluation on the latest checkpoint (refer to [this guide](../../guide/checkpoint.ipynb) for more details on checkpointing). The coordinator and worker tasks do not spend any time on evaluation, so for a fixed number of iterations the overall training time should be shorter than using other evaluation methods. However, it requires an additional evaluator task and periodic checkpointing to trigger evaluation." ] }, { "cell_type": "markdown", "metadata": { "id": "HonyjnXK9-ys" }, "source": [ "To write an evaluation loop for sidecar evaluation, you have two\n", "options:\n", "\n", "1. Use the `tf.keras.utils.SidecarEvaluator` API.\n", "2. Create a custom evaluation loop.\n", "\n", "Refer to the `tf.keras.utils.SidecarEvaluator` API documentation for more details on option 1." ] }, { "cell_type": "markdown", "metadata": { "id": "U_c0EiwB88OG" }, "source": [ "Sidecar evaluation is supported only with a single task. This means:\n", "\n", "* It is guaranteed that each example is evaluated once. In the event the\n", " evaluator is preempted or restarted, it simply restarts the\n", " evaluation loop from the latest checkpoint, and the partial evaluation\n", " progress made before the restart is discarded.\n", "\n", "* However, running evaluation on a single task implies that a full evaluation\n", " can possibly take a long time.\n", "\n", "* If the size of the model is too large to fit into an evaluator's memory,\n", " single sidecar evaluation is not applicable." 
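,
 "\n",
 "\n",
 "As a rough sketch of option 1 (the `tf.keras.utils.SidecarEvaluator` API), the dedicated evaluator task could run something like the code below. The `eval_model`, `eval_data`, `checkpoint_dir`, and `log_dir` names are placeholders for illustration:\n",
 "\n",
 "```python\n",
 "tf.keras.utils.SidecarEvaluator(\n",
 "    model=eval_model,  # A compiled model that defines the metrics to report.\n",
 "    data=eval_data,  # The evaluation dataset.\n",
 "    checkpoint_dir=checkpoint_dir,  # Where the training cluster writes checkpoints.\n",
 "    steps=None,  # Evaluate until the evaluation dataset is exhausted.\n",
 "    max_evaluations=None,  # Keep evaluating as new checkpoints appear.\n",
 "    callbacks=[tf.keras.callbacks.TensorBoard(log_dir=log_dir)],\n",
 ").start()\n",
 "```"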
] }, { "cell_type": "markdown", "metadata": { "id": "VNJoWVc797B1" }, "source": [ "Another caveat is that the `tf.keras.utils.SidecarEvaluator` implementation, and the custom\n", "evaluation loop below, may skip some checkpoints because it always picks up the\n", "latest checkpoint available, and during an evaluation epoch, multiple\n", "checkpoints can be produced from the training cluster. You can write a custom\n", "evaluation loop that evaluates every checkpoint, but it is not covered in this\n", "tutorial. On the other hand, it may sit idle if checkpoints are produced less\n", "frequently than how long it takes to run evaluation." ] }, { "cell_type": "markdown", "metadata": { "id": "G5jopxBd85Ji" }, "source": [ "A custom evaluation loop provides more control over the details, such as choosing which checkpoint to evaluate, or providing any additional logic to run along with evaluation. The following is a possible custom sidecar evaluation loop:\n", "\n", "```python\n", "checkpoint_dir = ...\n", "eval_model = ...\n", "eval_data = ...\n", "checkpoint = tf.train.Checkpoint(model=eval_model)\n", "\n", "for latest_checkpoint in tf.train.checkpoints_iterator(\n", " checkpoint_dir):\n", " try:\n", " checkpoint.restore(latest_checkpoint).expect_partial()\n", " except (tf.errors.OpError,) as e:\n", " # checkpoint may be deleted by training when it is about to read it.\n", " continue\n", "\n", " # Optionally add callbacks to write summaries.\n", " eval_model.evaluate(eval_data)\n", "\n", " # Evaluation finishes when it has evaluated the last epoch.\n", " if latest_checkpoint.endswith('-{}'.format(train_epochs)):\n", " break\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "9TkNbtpPhFRQ" }, "source": [ "## Clusters in the real world\n", "\n", "\n", "Note: this section is not necessary for running the tutorial code in this page.\n", "\n", "In a real production environment, you will run all tasks in different processes on different machines. The simplest way to configure cluster information on each task is to set `\"TF_CONFIG\"` environment variables and use a `tf.distribute.cluster_resolver.TFConfigClusterResolver` to parse `\"TF_CONFIG\"`.\n", "\n", "For a general description of `\"TF_CONFIG\"` environment variables, refer to \"Setting up the `TF_CONFIG` environment variable\" in the [Distributed training](../../guide/distributed_training.ipynb) guide.\n", "\n", "If you start your training tasks using Kubernetes or other configuration templates, likely, these templates have already set `“TF_CONFIG\"` for you." ] }, { "cell_type": "markdown", "metadata": { "id": "n7AK9SJGt3tQ" }, "source": [ "### Set the `\"TF_CONFIG\"` environment variable\n", "\n", "Suppose you have 3 workers and 2 parameter servers. Then the `\"TF_CONFIG\"` of worker 1 can be:\n", "\n", "```python\n", "os.environ[\"TF_CONFIG\"] = json.dumps({\n", " \"cluster\": {\n", " \"worker\": [\"host1:port\", \"host2:port\", \"host3:port\"],\n", " \"ps\": [\"host4:port\", \"host5:port\"],\n", " \"chief\": [\"host6:port\"]\n", " },\n", " \"task\": {\"type\": \"worker\", \"index\": 1}\n", "})\n", "```\n", "\n", "The `\"TF_CONFIG\"` of the evaluator can be:\n", "\n", "```python\n", "os.environ[\"TF_CONFIG\"] = json.dumps({\n", " \"cluster\": {\n", " \"evaluator\": [\"host7:port\"]\n", " },\n", " \"task\": {\"type\": \"evaluator\", \"index\": 0}\n", "})\n", "```\n", "\n", "The `\"cluster\"` part in the above `\"TF_CONFIG\"` string for the evaluator is optional." 
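,
 "\n",
 "\n",
 "On the coordinator (`\"chief\"`) task of such a cluster, a sketch of consuming this configuration could look like the following; in this tutorial, the in-process `SimpleClusterResolver` plays the role that `TFConfigClusterResolver` plays in a real cluster:\n",
 "\n",
 "```python\n",
 "cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()\n",
 "strategy = tf.distribute.ParameterServerStrategy(cluster_resolver)\n",
 "```"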
] }, { "cell_type": "markdown", "metadata": { "id": "fZRjMS0pt1LM" }, "source": [ "### If you use the same binary for all tasks\n", "\n", "If you prefer to run all these tasks using a single binary, you will need to let your program branch into different roles at the very beginning:\n", "\n", "```python\n", "cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()\n", "if cluster_resolver.task_type in (\"worker\", \"ps\"):\n", " # Start a TensorFlow server and wait.\n", "elif cluster_resolver.task_type == \"evaluator\":\n", " # Run sidecar evaluation\n", "else:\n", " # Run the coordinator.\n", "```\n", "\n", "The following code starts a TensorFlow server and waits, useful for the `\"worker\"` and `\"ps\"` roles:\n", "\n", "```python\n", "# Set the environment variable to allow reporting worker and ps failure to the\n", "# coordinator. This is a workaround and won't be necessary in the future.\n", "os.environ[\"GRPC_FAIL_FAST\"] = \"use_caller\"\n", "\n", "server = tf.distribute.Server(\n", " cluster_resolver.cluster_spec(),\n", " job_name=cluster_resolver.task_type,\n", " task_index=cluster_resolver.task_id,\n", " protocol=cluster_resolver.rpc_layer or \"grpc\",\n", " start=True)\n", "server.join()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "ZWdYfK593eOL" }, "source": [ "## Handling task failure" ] }, { "cell_type": "markdown", "metadata": { "id": "Bl9eK5r13cOv" }, "source": [ "### Worker failure\n", "\n", "Both the `tf.distribute.coordinator.ClusterCoordinator` custom training loop and `Model.fit` approaches provide built-in fault tolerance for worker failure. Upon worker recovery, the `ClusterCoordinator` invokes dataset re-creation on the workers." ] }, { "cell_type": "markdown", "metadata": { "id": "aP0OHZ1-Ne-B" }, "source": [ "### Parameter server or coordinator failure\n", "\n", "However, when the coordinator sees a parameter server error, it will raise an `UnavailableError` or `AbortedError` immediately. You can restart the coordinator in this case. The coordinator itself can also become unavailable. Therefore, certain tooling is recommended in order to not lose the training progress:" ] }, { "cell_type": "markdown", "metadata": { "id": "f7m7Itoz8lsI" }, "source": [ "- For `Model.fit`, you should use a `BackupAndRestore` callback, which handles the progress saving and restoration automatically. See [Callbacks and training](#callbacks-and-training) section above for an example." ] }, { "cell_type": "markdown", "metadata": { "id": "-XlLyJp53Z8A" }, "source": [ "- For a custom training loop, you should checkpoint the model variables periodically and load model variables from a checkpoint, if any, before training starts. 
The training progress can be inferred approximately from `optimizer.iterations` if an optimizer is checkpointed:\n", "\n", "```python\n", "checkpoint_manager = tf.train.CheckpointManager(\n", " tf.train.Checkpoint(model=model, optimizer=optimizer),\n", " checkpoint_dir,\n", " max_to_keep=3)\n", "if checkpoint_manager.latest_checkpoint:\n", " checkpoint = checkpoint_manager.checkpoint\n", " checkpoint.restore(\n", " checkpoint_manager.latest_checkpoint).assert_existing_objects_matched()\n", "\n", "global_steps = int(optimizer.iterations.numpy())\n", "starting_epoch = global_steps // steps_per_epoch\n", "\n", "for _ in range(starting_epoch, num_epochs):\n", " for _ in range(steps_per_epoch):\n", " coordinator.schedule(step_fn, args=(per_worker_iterator,))\n", " coordinator.join()\n", " checkpoint_manager.save()\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "PlN1P7C53XK9" }, "source": [ "### Fetching a `RemoteValue`\n", "\n", "Fetching a `RemoteValue` is guaranteed to succeed if a function is executed successfully. This is because currently the return value is immediately copied to the coordinator after a function is executed. If there is any worker failure during the copy, the function will be retried on another available worker. Therefore, if you want to optimize for performance, you can schedule functions without a return value." ] }, { "cell_type": "markdown", "metadata": { "id": "iZcR_xNZ3UdU" }, "source": [ "## Error reporting\n", "\n", "Once the coordinator sees an error such as `UnavailableError` from parameter servers or other application errors such as an `InvalidArgument` from `tf.debugging.check_numerics`, it will cancel all pending and queued functions before raising the error. Fetching their corresponding `RemoteValue`s will raise a `CancelledError`.\n", "\n", "After an error is raised, the coordinator will not raise the same error or any error from cancelled functions." ] }, { "cell_type": "markdown", "metadata": { "id": "QfhbXH-j3NVw" }, "source": [ "## Performance improvement\n", "\n", "There are several possible reasons you may experience performance issues when you train with `tf.distribute.ParameterServerStrategy` and `tf.distribute.coordinator.ClusterCoordinator`.\n", "\n", "One common reason is that the parameter servers have unbalanced load and some heavily-loaded parameter servers have reached capacity. There can also be multiple root causes. Some simple methods to mitigate this issue are to:\n", "\n", "1. Shard your large model variables via specifying a `variable_partitioner` when constructing a `ParameterServerStrategy`.\n", "2. Avoid creating a hotspot variable that is required by all parameter servers in a single step, by both:\n", "\n", " 1) Using a constant learning rate or subclass `tf.keras.optimizers.schedules.LearningRateSchedule` in optimizers. This is because the default behavior is that the learning rate will become a variable placed on a particular parameter server, and requested by all other parameter servers in each step); and\n", "\n", " 2) Using a `tf.keras.optimizers.legacy.Optimizer` (the standard `tf.keras.optimizers.Optimizer`s could still lead to hotspot variables).\n", "3. Shuffle your large vocabularies before passing them to Keras preprocessing layers.\n", "\n", "Another possible reason for performance issues is the coordinator. The implementation of `schedule`/`join` is Python-based and thus may have threading overhead. Also, the latency between the coordinator and the workers can be large. 
If this is the case:\n",
 "\n",
 "- For `Model.fit`, you can set the `steps_per_execution` argument provided at `Model.compile` to a value larger than 1.\n",
 "\n",
 "- For a custom training loop, you can pack multiple steps into a single `tf.function`:\n",
 "\n",
 "```python\n",
 "steps_per_invocation = 10\n",
 "\n",
 "@tf.function\n",
 "def step_fn(iterator):\n",
 "  for _ in range(steps_per_invocation):\n",
 "    features, labels = next(iterator)\n",
 "    def replica_fn(features, labels):\n",
 "      ...\n",
 "\n",
 "    strategy.run(replica_fn, args=(features, labels))\n",
 "```\n",
 "\n",
 "As the library is optimized further, hopefully most users won't have to manually pack steps in the future.\n",
 "\n",
 "In addition, a small trick for performance improvement is to schedule functions without a return value, as explained in the [handling task failure section](#handling_task_failure) above."
 ] }, {
 "cell_type": "markdown", "metadata": { "id": "chu5F7M_JmVk" }, "source": [
 "## Known limitations\n",
 "\n",
 " \n",
 "\n",
 "Most of the known limitations are already covered in the above sections. This section provides a summary.\n",
 "\n",
 "### `ParameterServerStrategy` general\n",
 "\n",
 "- `os.environ[\"GRPC_FAIL_FAST\"] = \"use_caller\"` is needed on every task, including the coordinator, to make fault tolerance work properly.\n",
 "- Synchronous parameter server training is not supported.\n",
 "- It is usually necessary to pack multiple steps into a single function to achieve optimal performance.\n",
 "- It is not supported to load a saved_model via `tf.saved_model.load` containing sharded variables. Note that loading such a saved_model using TensorFlow Serving is expected to work (refer to the [serving tutorial](https://www.tensorflow.org/tfx/tutorials/serving/rest_simple) for details).\n",
 "- It is not supported to recover from a parameter server failure without restarting the coordinator task.\n",
 "- Creation of `tf.lookup.StaticHashTable`, commonly employed by some Keras preprocessing layers, such as `tf.keras.layers.IntegerLookup`, `tf.keras.layers.StringLookup`, and `tf.keras.layers.TextVectorization`, should be placed under `Strategy.scope`. Otherwise, resources will be placed on the coordinator, and lookup RPCs from workers to the coordinator will incur a performance penalty.\n"
 ] }, {
 "cell_type": "markdown", "metadata": { "id": "2MKBF0RPSvzB" }, "source": [
 "### `Model.fit` specifics\n",
 "\n",
 "- The `steps_per_epoch` argument is required in `Model.fit`. You can select a value that provides appropriate intervals in an epoch.\n",
 "- `ParameterServerStrategy` does not have support for custom callbacks that have batch-level calls for performance reasons. You should convert those calls into epoch-level calls with a suitably picked `steps_per_epoch`, so that they are called every `steps_per_epoch` number of steps. Built-in callbacks are not affected: their batch-level calls have been modified to be performant. Supporting batch-level calls for `ParameterServerStrategy` is being planned.\n",
 "- For the same reason, unlike other strategies, progress bars and metrics are logged only at epoch boundaries.\n",
 "- `run_eagerly` is not supported.\n"
 ] }, {
 "cell_type": "markdown", "metadata": { "id": "wvY-mg35Sx5L" }, "source": [
 "### Custom training loop specifics\n",
 "\n",
 "- `ClusterCoordinator.schedule` doesn't support visitation guarantees for a dataset in general, although a visitation guarantee for evaluation is possible through `Model.fit/.evaluate`. 
See [Enabling exactly-once evaluation](#exactly_once_evaluation).\n", "- When `ClusterCoordinator.create_per_worker_dataset` is used with a callable as input, the whole dataset must be created inside the function passed to it.\n", "- `tf.data.Options` is ignored in a dataset created by `ClusterCoordinator.create_per_worker_dataset`." ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "parameter_server_training.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }