
spark-tensorflow-distributor
This package helps users do distributed training with TensorFlow on their Spark clusters.
This package requires Python 3.6+, `tensorflow>=2.1.0`, and `pyspark>=3.0.0` to run.
To install `spark-tensorflow-distributor`, run:

```
pip install spark-tensorflow-distributor
```
The installation does not install PySpark because, for most users, PySpark is already installed. In addition, `tensorflow` is not installed, so that users may choose between the regular and CPU-only installations via `pip install tensorflow` or `pip install tensorflow-cpu`.
If you do not have PySpark installed, you can install it directly:

```
pip install 'pyspark>=3.0'
```
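To confirm the environment meets the requirements above, a quick sanity check (a minimal sketch; adjust to your setup) is:

```python
# Verify the documented requirements: Python 3.6+,
# tensorflow>=2.1.0, and pyspark>=3.0.0.
import sys
import pyspark
import tensorflow as tf

print(sys.version_info)      # expect (3, 6) or later
print(tf.__version__)        # expect 2.1.0 or later
print(pyspark.__version__)   # expect 3.0.0 or later
```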
Note also that to use many features of this package, you must set up Spark custom resource scheduling for GPUs on your cluster; see the Spark docs for details.
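As a rough illustration, GPU scheduling is typically enabled with Spark configuration like the sketch below. The amounts and the discovery script path are placeholder assumptions; the exact settings depend on your cluster manager, so treat the Spark documentation as authoritative.

```python
from pyspark.sql import SparkSession

# A minimal sketch of enabling GPU resource scheduling.
# All values below are placeholders, not recommendations.
spark = (
    SparkSession.builder
    # One GPU per executor and one GPU per Spark task (placeholder amounts).
    .config("spark.executor.resource.gpu.amount", "1")
    .config("spark.task.resource.gpu.amount", "1")
    # Script that reports the GPU addresses available on each worker; Spark
    # ships an example at examples/src/main/scripts/getGpusResources.sh.
    .config("spark.executor.resource.gpu.discoveryScript",
            "/path/to/getGpusResources.sh")
    .getOrCreate()
)
```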
For integration tests, first build the master and worker images and then run the test script:

```
docker-compose build --build-arg PYTHON_INSTALL_VERSION=3.7
./tests/integration/run.sh
```
For linting, run the following:

```
./tests/lint.sh
```
To use the autoformatter, run the following:

```
yapf --recursive --in-place spark_tensorflow_distributor
```
Run the following example code in a `pyspark` shell:
```python
from spark_tensorflow_distributor import MirroredStrategyRunner

# Adapted from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
def train():
    import tensorflow as tf
    import uuid

    BUFFER_SIZE = 10000
    BATCH_SIZE = 64

    def make_datasets():
        (mnist_images, mnist_labels), _ = \
            tf.keras.datasets.mnist.load_data(path=str(uuid.uuid4()) + 'mnist.npz')
        dataset = tf.data.Dataset.from_tensor_slices((
            tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
            tf.cast(mnist_labels, tf.int64))
        )
        dataset = dataset.repeat().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
        return dataset

    def build_and_compile_cnn_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax'),
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'],
        )
        return model

    train_datasets = make_datasets()
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_datasets = train_datasets.with_options(options)
    multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)

MirroredStrategyRunner(num_slots=8).run(train)
```
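Beyond `num_slots` (the total number of training slots across the cluster), `MirroredStrategyRunner` exposes a few other options. The sketch below assumes the `local_mode` and `use_gpu` parameters from the package's documented API; verify them against your installed version:

```python
# Run training in local mode on the driver, e.g. for quick debugging
# (local_mode is assumed from the package's documented API).
MirroredStrategyRunner(num_slots=2, local_mode=True).run(train)

# Run distributed training on CPU-only Spark tasks, which does not
# require GPU resource scheduling (use_gpu is likewise assumed).
MirroredStrategyRunner(num_slots=8, use_gpu=False).run(train)
```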