Chapter 9. PyTorch Distributed Machine Learning Approach

PyTorch is an open source machine learning library developed by Facebook's AI Research (FAIR) team and later donated to the Linux Foundation. It was designed to simplify the creation of artificial neural networks and to enable applications such as computer vision and natural language processing. The primary interface to PyTorch is Python, but it is built on a low-level C and C++ core. This is a very different approach from Spark, whose core is written in Scala and Java (JVM-based languages).
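To make this Python-first design concrete before we dive in, here is a minimal sketch (not part of the book's step-by-step example) of PyTorch's core workflow: tensors, a computation, and automatic differentiation. Every Python call below dispatches to the compiled C/C++ backend:

import torch

# Leaf tensors; requires_grad=True tells autograd to track operations on them.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
w = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)

# A toy squared-error loss; each operation runs in the C/C++ core.
loss = (w * x - torch.ones(3)).pow(2).sum()

# Reverse-mode autodiff populates .grad on the leaf tensors.
loss.backward()
print(w.grad)  # d(loss)/dw, computed by the autograd engine

Even though the heavy lifting happens in native code, the interface stays idiomatic Python, which is part of why PyTorch composes differently with JVM-based systems like Spark.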

In the previous chapters, you've learned about the building blocks of the machine learning workflow. We started with Spark, then expanded to explore TensorFlow's distributed training capabilities. In this chapter, we will turn our attention to PyTorch. The goal is to help you better understand what PyTorch is and how its distributed training works, from an architectural and conceptual perspective, so you can make better decisions when combining multiple frameworks in a distributed setting.

We will also go through a step-by-step example of working with distributed PyTorch, leveraging the work we did with Spark in Chapters 4 and 5 and with Petastorm in Chapter 7.

This chapter covers the following:

  • A quick overview of PyTorch basics

  • PyTorch distributed strategies for training models

  • How to load Parquet data into PyTorch

  • Putting it all together—from Petastorm to a model with PyTorch
