Spark Join Optimisation

Last time, we discussed how to find the bottleneck in the Spark UI, but in practice many performance issues in our daily Spark jobs come down to data skew. There are several situations where data skew occurs:

  1. When loading an RDD, the data can be unevenly distributed across partitions from the start. For example, we often use the date as one of the partition keys in our daily work, and some days might be holidays, which can cause a sudden spike in data volume.
  2. When different RDDs need to join, using the join key as the partition key can lead to uneven data distribution between partitions. This is a common cause in our daily troubleshooting.

The first problem is relatively easy to solve: a common and straightforward fix is to repartition directly. The second problem, which usually shows up in SQL join operations, is much trickier to deal with.
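As a quick sketch of the repartition fix for the first problem (the dataset path and column names here are hypothetical):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Minimal sketch: rebalance a skewed, date-partitioned dataset.
// The path and column names are made up for illustration.
val spark = SparkSession.builder().appName("repartition-demo").getOrCreate()
val events = spark.read.parquet("/path/to/events")

// Spread rows evenly across 200 partitions, regardless of how skewed
// the original date partitions were.
val balanced = events.repartition(200)

// If downstream work groups by a key, repartition by that key instead,
// so the shuffle happens only once.
val byUser = events.repartition(200, col("user_id"))
```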

To understand the second problem, we need to start from how joins work. Generally, all database join algorithms can be summarized into a few categories, and here we mainly discuss the three that we encounter most frequently in our daily work:

  1. HashJoin
  2. BroadcastJoin
  3. SortMergeJoin

HashJoin

The hash join principle described here comes from the official MySQL website. The example is select * from countries, persons where countries.id = persons.country_id. Generally, a hash join consists of three steps:

  1. Determine the Build Table (the small table) and the Probe Table (the large table). The Build Table is used to construct a hash table, and the Probe Table traverses all its keys to match them against the generated hash table.
  2. Build the hash table. Sequentially read data from the Build Table (countries), and for each row of data, hash the join key (countries.id) into the corresponding bucket to generate a record in the hash table. The data is cached in memory.
  3. Scan the Probe Table. Sequentially scan the data of the Probe Table (persons), map each record with the same hash function, and on a successful mapping check the join condition (countries.id = persons.country_id); if it matches, the two rows are joined together.
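To make these steps concrete, here is a toy hash join in plain Scala. It is only an illustration with made-up rows, not Spark’s or MySQL’s internal code:

```scala
// Toy hash join mirroring the three steps above.
case class Country(id: Int, name: String)
case class Person(countryId: Int, name: String)

val countries = Seq(Country(1, "FR"), Country(2, "DE"))                    // Build Table (small)
val persons   = Seq(Person(1, "Ada"), Person(2, "Max"), Person(1, "Eve")) // Probe Table (large)

// Step 2: build the hash table on the join key of the small table.
// countries.id is unique here, so a plain Map suffices; in general each
// bucket may hold several rows.
val hashTable: Map[Int, Country] = countries.map(c => c.id -> c).toMap

// Step 3: scan the probe table and look each row up in the hash table.
val joined: Seq[(Person, Country)] =
  persons.flatMap(p => hashTable.get(p.countryId).map(c => (p, c)))

joined.foreach { case (p, c) => println(s"${p.name} lives in ${c.name}") }
```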

BroadcastJoin

In the Spark world, if you had been the designer at the time, how would you do it? You would probably think: we can send this small table to every partition of the big table and hash-join locally, right? Yes, this is the principle of Spark’s BroadcastJoin. From this we can tell that if data skew occurs here, the partitioning of the large table itself must be unreasonable, because the small table is not partitioned at all: a full replica is sent to every partition of the big table. So a simple repartition on the big table can actually solve most problems. If your table is too large and the cost of repartitioning is too heavy, you may use our ultimate solution mentioned below, the salt key.
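In Spark SQL this is exposed through the broadcast hint. A minimal sketch, assuming hypothetical paths and column names:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

// Minimal sketch; paths and column names are hypothetical.
val spark = SparkSession.builder().appName("broadcast-demo").getOrCreate()
val orders    = spark.read.parquet("/data/orders")    // big table
val countries = spark.read.parquet("/data/countries") // small table

// Ship a full copy of the small table to every executor and hash-join
// locally, so the big table is never shuffled.
val joined = orders.join(broadcast(countries), orders("country_id") === countries("id"))
```

Spark also broadcasts automatically when the smaller side is below spark.sql.autoBroadcastJoinThreshold (10 MB by default).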

SortMergeJoin

But there is a premise for the broadcast join above: one of the tables must be small. In many cases, though, both of our tables are big. How does Spark handle this situation? As shown in the diagram below, it proceeds in three stages:

1. Shuffle stage: repartition the two large tables according to the join key. In this example, we partition by customer_id, and the data of both tables is spread across the entire cluster for parallel processing.

2. Sort stage: sort the data of the two tables within each partition. Spark already handles this by default, so we won’t expand on it here.

3. Merge stage: perform the join on the sorted data of the two partitioned tables. The operation itself is simple: traverse the two sorted sequences with two pointers, a pattern you should all know well from LeetCode.
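For reference, here is a toy version of that merge step over two key-sorted sequences; the data shapes are made up and this is not Spark’s actual implementation:

```scala
// Toy merge step over two key-sorted sequences. Runs of duplicate keys are
// handled by emitting the cross product of each matching run.
def mergeJoin[A, B](left: Vector[(Int, A)], right: Vector[(Int, B)]): Vector[(A, B)] = {
  val out = Vector.newBuilder[(A, B)]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val k = left(i)._1
    if (k < right(j)._1) i += 1      // left key too small, advance left
    else if (k > right(j)._1) j += 1 // right key too small, advance right
    else {
      // Keys match: find the end of the equal-key run on each side.
      val iEnd = left.indexWhere(_._1 != k, i) match { case -1 => left.length; case n => n }
      val jEnd = right.indexWhere(_._1 != k, j) match { case -1 => right.length; case n => n }
      for (x <- i until iEnd; y <- j until jEnd) out += ((left(x)._2, right(y)._2))
      i = iEnd
      j = jEnd
    }
  }
  out.result()
}

// mergeJoin(Vector(1 -> "a", 1 -> "b", 2 -> "c"), Vector(1 -> "x", 3 -> "y"))
// => Vector((a,x), (b,x))
```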

After walking through the principles of these classic joins, what should we do when we encounter join-related data skew?

Let’s start with the most primitive solution. Our most basic need is to make the data as evenly distributed as possible so the system can process it in parallel. Suppose we have two tables, a Hello table and a World table, joined with select * from hello join world on hello.id = world.id. (Excuse my poor drawing)

Assume that hello is the big table here and world is the small one. We can see that if we join directly, the three rows sharing the same id all end up in one task, while the other task only handles a single row. This is typical data skew. To solve it, we add two steps before the join:

1.  Prefix the id of each row in the hello table with a random choice of a, b, or c;

2.  Prefix each row of the world table with every one of a, b, and c, which means the table grows threefold.

Then we carry on with the join, and at this point we find that by default at least 4 tasks handle the data, each processing only one row, which is far more balanced. Moreover, every randomly prefixed row in hello can still find its match in world, because we expanded world.

At the end of the processing, we strip the prefix and reduce according to the previous normal logic.

This is the famous salt key technique. There is an implicit prerequisite here: the world table is small, which is why we could afford to expand it. But what if the world table is also big? We just need to be flexible: observe which keys are hot and expand only those. In the example above, we could expand only the world rows with id 1.
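Here is a sketch of the salt key technique in Spark SQL. The paths, table names, and columns are hypothetical, and the salt factor of 3 matches the a/b/c example above:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Sketch of key salting; paths, tables, and columns are hypothetical.
val spark = SparkSession.builder().appName("salt-demo").getOrCreate()
val hello = spark.read.parquet("/data/hello") // big table, skewed on id
val world = spark.read.parquet("/data/world") // small enough to expand 3x

val saltCount = 3

// Big table: append a random salt in [0, saltCount) to every join key.
val saltedHello = hello.withColumn(
  "salted_id",
  concat(col("id").cast("string"), lit("_"), (rand() * saltCount).cast("int").cast("string")))

// Small table: replicate every row once per salt value so each salted key
// on the big side still finds its match.
val saltedWorld = world
  .withColumn("salt", explode(array((0 until saltCount).map(lit): _*)))
  .withColumn("salted_id", concat(col("id").cast("string"), lit("_"), col("salt").cast("string")))

// Join on the salted key: the hot key is now spread over 3 partitions.
// Afterwards drop the salt columns and continue with the original logic.
val joined = saltedHello.join(saltedWorld, "salted_id").drop("salt", "salted_id")
```

If world were also big, we would salt only the hot keys (for example, id 1) and union the expanded part with the untouched rest.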

The techniques above are very practical in our daily work. Combined with the simple and practical repartition mentioned earlier, they can solve roughly 80% of the performance bottlenecks caused by data skew.