Split spark dataset based on number of rows
Question
I'm reading data from DynamoDB and storing it in a Spark Dataset like this:
// Building a dataset
Dataset<Row> citations = sparkSession.read()
    .option("tableName", "Covid19Citation")
    .option("region", "eu-west-1")
    .format("dynamodb")
    .load();
What I want is to split this dataset based on the number of rows. For example, if the dataset has more than 500 rows, I want to split it and save each of the resulting datasets as a separate file, so each dataset I save has at most 500 rows. E.g. if there were 1600 rows in the database, the output should be four XML files:
The first XML file contains 500 rows,
the second XML file also contains 500 rows,
the third XML file also contains 500 rows, and finally
the fourth XML file contains 100 rows.
This is what I tried so far, but it doesn't work:
List<Dataset<Row>> datasets = new ArrayList<>();
while (citations.count() > 0) {
    Dataset<Row> splitted = citations.limit(400);
    datasets.add(splitted);
    citations = citations.except(splitted);
}
System.out.println("datasets : " + datasets.size());
for (Dataset<Row> d : datasets) {
    d.coalesce(1)
        .write()
        .format("com.databricks.spark.xml")
        .option("rootTag", "citations")
        .option("rowTag", "citation")
        .mode("overwrite")
        .save("s3a://someoutputfolder/");
}
Any help would be highly appreciated.
Thanks
Answer 1
Score: 1
You can leverage:
- row_number and mod: to split the dataset into parts of 500 rows
- repartition: to produce one file per partition
- partitionBy: to write one XML file per partition
Here is an example in Scala / Parquet (but you can use xml as well):
val citations = spark.range(1, 2000000).selectExpr("id", "hash(id) value")

// calculate the number of buckets
val total = citations.count
val mod = (total.toFloat / 500).ceil.toInt

citations
  .withColumn("id", expr("row_number() over(order by monotonically_increasing_id())"))
  .withColumn("bucket", expr(f"mod(id, ${mod})"))
  .repartition('bucket)
  .write
  .partitionBy("bucket")
  .format("parquet")
  .mode("overwrite")
  .save("/tmp/foobar")
// now check the results
val resultDf = spark.read.format("parquet").load("/tmp/foobar")
// as a result, each bucket contains at most 500 rows
scala> resultDf.groupBy("bucket").count.show
+------+-----+
|bucket|count|
+------+-----+
| 1133| 500|
| 1771| 500|
| 1890| 500|
| 3207| 500|
| 3912| 500|
| 1564| 500|
| 2823| 500|
+------+-----+
// there is no file with more than 500 rows
scala> resultDf.groupBy("bucket").count.filter("count > 500").show
+------+-----+
|bucket|count|
+------+-----+
+------+-----+
// now check there is only one file per bucket
scala> spark.sparkContext.parallelize(resultDf.inputFiles).toDF
.withColumn("part", expr("regexp_extract(value,'(bucket=([0-9]+))')"))
.groupBy("part").count.withColumnRenamed("count", "nb_files")
.orderBy(desc("nb_files")).show(5)
+-----------+--------+
| part|nb_files|
+-----------+--------+
|bucket=3209| 1|
|bucket=1290| 1|
|bucket=3354| 1|
|bucket=2007| 1|
|bucket=2816| 1|
+-----------+--------+
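For readers who want to stay in Java, as in the original question, here is a minimal, untested sketch of the same row_number / mod / repartition / partitionBy idea, writing XML with spark-xml instead of Parquet. The class name, method name, parameter names and output path are placeholders of my own, and it assumes the com.databricks spark-xml package used in the question's own write code supports partitionBy in the version at hand:

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.expr;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class CitationSplitter {

    // Writes the dataset as XML files holding at most maxRowsPerFile rows each.
    public static void writeInChunks(Dataset<Row> citations, int maxRowsPerFile, String outputPath) {
        // Number of buckets needed so that no bucket exceeds maxRowsPerFile rows
        long total = citations.count();
        int buckets = (int) Math.max(1, Math.ceil((double) total / maxRowsPerFile));

        citations
            // Consecutive global row number (1, 2, 3, ...)
            .withColumn("row_num", expr("row_number() over (order by monotonically_increasing_id())"))
            // Spread the row numbers evenly across the buckets
            .withColumn("bucket", expr("mod(row_num, " + buckets + ")"))
            .drop("row_num")
            // One Spark partition per bucket value, so each bucket becomes a single file
            .repartition(col("bucket"))
            .write()
            .partitionBy("bucket")
            .format("com.databricks.spark.xml")
            .option("rootTag", "citations")
            .option("rowTag", "citation")
            .mode("overwrite")
            .save(outputPath);
    }
}

Each bucket=N subdirectory under outputPath should then contain a single XML file with at most maxRowsPerFile rows.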