Get correlation matrix for array in a column

Question

In short, the goal is to compute a correlation matrix: cross the id column across the different days, fill the matrix with the intersection counts, and put 0 where a label is crossed with itself. The expected output looks like this:

+---+-----+---+---+---+---+---+---+
|day|label| t1| t2| t3| t4| t5| t6|
+---+-----+---+---+---+---+---+---+
|  1|   t1|  0|  0|  2|  0|  0|  2|
|  1|   t2|  0|  0|  2|  1|  0|  0|
|  2|   t3|  2|  2|  0|  0|  0|  1|
|  3|   t4|  0|  1|  0|  0|  0|  0|
|  3|   t5|  0|  0|  1|  0|  0|  0|
|  3|   t6|  2|  1|  1|  0|  0|  0|
+---+-----+---+---+---+---+---+---+

This can be computed with Spark DataFrames. One possible approach is to build each tN column by joining every row against that label's id array:

from pyspark.sql import functions as F

# For each label tN, join every row against that label's id array and day,
# then record the size of the id intersection; the diagonal and same-day
# pairs get 0.
result = sdf
for i in range(1, 7):
    other = (sdf.filter(F.col("label") == f"t{i}")
                .select(F.col("id").alias("other_id"), F.col("day").alias("other_day")))
    result = (result.crossJoin(other)
                    .withColumn(f"t{i}",
                                F.when((F.col("label") != f"t{i}") & (F.col("day") != F.col("other_day")),
                                       F.size(F.array_intersect("id", "other_id")))
                                 .otherwise(0))
                    .drop("other_id", "other_day"))

# Select the output columns
result.select("day", "label", "t1", "t2", "t3", "t4", "t5", "t6").show()

This produces the correlation matrix in the expected format. Make sure the cluster has enough memory to handle large datasets, to avoid memory issues.

The original question:

I have a dataframe:

data = [['t1', ['u1', 'u2', 'u3', 'u4', 'u5'], 1],
        ['t2', ['u1', 'u7', 'u8', 'u5'], 1],
        ['t3', ['u1', 'u2', 'u7', 'u11'], 2],
        ['t4', ['u8', 'u9'], 3],
        ['t5', ['u9', 'u22', 'u11'], 3],
        ['t6', ['u5', 'u11', 'u22', 'u4'], 3]]
sdf = spark.createDataFrame(data, schema=['label', 'id', 'day'])
sdf.show()
+-----+--------------------+---+
|label|                  id|day|
+-----+--------------------+---+
|   t1|[u1, u2, u3, u4, u5]|  1|
|   t2|    [u1, u7, u8, u5]|  1|
|   t3|   [u1, u2, u7, u11]|  2|
|   t4|            [u8, u9]|  3|
|   t5|      [u9, u22, u11]|  3|
|   t6|  [u5, u11, u22, u4]|  3|
+-----+--------------------+---+

I want to calculate the correlation matrix (actually my dataframe is much larger):

I would like to cross the id column with every other day. That is, on day=1 I don't cross IDs within that day, and set 0 for such cases. I cross the first day with the second and third, and so on.

Moreover, if a label intersects with itself, the value is not 100 but 0 (the diagonal is 0).

In the matrix I would like to record the absolute size of the intersection (how many IDs overlap).
It should turn out something like this dataframe:

+---+-----+---+---+---+---+---+---+
|day|label| t1| t2| t3| t4| t5| t6|
+---+-----+---+---+---+---+---+---+
|  1|   t1|  0|  0|  2|  0|  0|  2|
|  1|   t2|  0|  0|  2|  1|  0|  0|
|  2|   t3|  2|  2|  0|  0|  0|  1|
|  3|   t4|  0|  1|  0|  0|  0|  0|
|  3|   t5|  0|  0|  1|  0|  0|  0|
|  3|   t6|  2|  1|  1|  0|  0|  0|
+---+-----+---+---+---+---+---+---+

And since I actually have a large dataset, I would like the solution not to require too much memory, so that the job does not fail.

Answer 1

Score: 5


First of all, you can use `explode` to flatten the lists of IDs:

```python
>>> from pyspark.sql.functions import explode
>>> from pyspark.sql.types import StructType, StructField, StringType, ArrayType
>>> schema = StructType([
...     StructField('label', StringType(), nullable=False),
...     StructField('ids', ArrayType(StringType(), containsNull=False), nullable=False),
...     StructField('day', StringType(), nullable=False),
... ])
>>> data = [
...     ['t1', ['u1', 'u2', 'u3', 'u4', 'u5'], 1],
...     ['t2', ['u1', 'u7', 'u8', 'u5'], 1],
...     ['t3', ['u1', 'u2', 'u7', 'u11'], 2],
...     ['t4', ['u8', 'u9'], 3],
...     ['t5', ['u9', 'u22', 'u11'], 3],
...     ['t6', ['u5', 'u11', 'u22', 'u4'], 3]
... ]
>>> id_lists_df = spark.createDataFrame(data, schema=schema)
>>> df = id_lists_df.select('label', 'day', explode('ids').alias('id'))
>>> df.show()
+-----+---+---+                                                                 
|label|day| id|
+-----+---+---+
|   t1|  1| u1|
|   t1|  1| u2|
|   t1|  1| u3|
|   t1|  1| u4|
|   t1|  1| u5|
|   t2|  1| u1|
|   t2|  1| u7|
|   t2|  1| u8|
|   t2|  1| u5|
|   t3|  2| u1|
|   t3|  2| u2|
|   t3|  2| u7|
|   t3|  2|u11|
|   t4|  3| u8|
|   t4|  3| u9|
|   t5|  3| u9|
|   t5|  3|u22|
|   t5|  3|u11|
|   t6|  3| u5|
|   t6|  3|u11|
+-----+---+---+
only showing top 20 rows
```

Then you can self-join the resulting data frame, filter out the unwanted rows (same day or label) and then proceed to the actual counting.

I have the impression that your matrix will contain lots of zeros.
Do you need a "physical" matrix or is a count per day and pair of labels sufficient?

If you don't need a "physical" matrix, you can use regular aggregations (group by day and labels and then count):

>>> df2 = df.withColumnRenamed('label', 'label2').withColumnRenamed('day', 'day2')
>>> counts = df.join(df2, on='id') \
...     .where(df.label != df2.label2) \
...     .where(df.day != df2.day2) \
...     .groupby(df.day, df.label, df2.label2) \
...     .count() \
...     .orderBy(df.label, df2.label2)
>>>
>>> counts.show()
+---+-----+------+-----+                                                        
|day|label|label2|count|
+---+-----+------+-----+
|  1|   t1|    t3|    2|
|  1|   t1|    t6|    2|
|  1|   t2|    t3|    2|
|  1|   t2|    t4|    1|
|  1|   t2|    t6|    1|
|  2|   t3|    t1|    2|
|  2|   t3|    t2|    2|
|  2|   t3|    t5|    1|
|  2|   t3|    t6|    1|
|  3|   t4|    t2|    1|
|  3|   t5|    t3|    1|
|  3|   t6|    t1|    2|
|  3|   t6|    t2|    1|
|  3|   t6|    t3|    1|
+---+-----+------+-----+

>>> counts.explain()
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Sort [label#0 ASC NULLS FIRST, label2#493 ASC NULLS FIRST], true, 0
   +- Exchange rangepartitioning(label#0 ASC NULLS FIRST, label2#493 ASC NULLS FIRST, 200), ENSURE_REQUIREMENTS, [plan_id=2010]
      +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[count(1)])
         +- Exchange hashpartitioning(day#2, label#0, label2#493, 200), ENSURE_REQUIREMENTS, [plan_id=2007]
            +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[partial_count(1)])
               +- Project [label#0, day#2, label2#493]
                  +- SortMergeJoin [id#7], [id#504], Inner, (NOT (label#0 = label2#493) AND NOT (day#2 = day2#497))
                     :- Sort [id#7 ASC NULLS FIRST], false, 0
                     :  +- Exchange hashpartitioning(id#7, 200), ENSURE_REQUIREMENTS, [plan_id=1999]
                     :     +- Generate explode(ids#1), [label#0, day#2], false, [id#7]
                     :        +- Filter (size(ids#1, true) > 0)
                     :           +- Scan ExistingRDD[label#0,ids#1,day#2]
                     +- Sort [id#504 ASC NULLS FIRST], false, 0
                        +- Exchange hashpartitioning(id#504, 200), ENSURE_REQUIREMENTS, [plan_id=2000]
                           +- Project [label#501 AS label2#493, day#503 AS day2#497, id#504]
                              +- Generate explode(ids#502), [label#501, day#503], false, [id#504]
                                 +- Filter (size(ids#502, true) > 0)
                                    +- Scan ExistingRDD[label#501,ids#502,day#503]

If you need "physical" matrices, you can work with MLlib as suggested in the first answer, or you can use pivot on label2 instead of using it as a grouping column:

>>> counts_pivoted = df.join(df2, on='id') \
...     .where(df.label != df2.label2) \
...     .where(df.day != df2.day2) \
...     .groupby(df.day, df.label) \
...     .pivot('label2') \
...     .count() \
...     .drop('label2') \
...     .orderBy('label') \
...     .fillna(0)
>>> counts_pivoted.show()
+---+-----+---+---+---+---+---+---+                                             
|day|label| t1| t2| t3| t4| t5| t6|
+---+-----+---+---+---+---+---+---+
|  1|   t1|  0|  0|  2|  0|  0|  2|
|  1|   t2|  0|  0|  2|  1|  0|  1|
|  2|   t3|  2|  2|  0|  0|  1|  1|
|  3|   t4|  0|  1|  0|  0|  0|  0|
|  3|   t5|  0|  0|  1|  0|  0|  0|
|  3|   t6|  2|  1|  1|  0|  0|  0|
+---+-----+---+---+---+---+---+---+
>>> counts_pivoted.explain()
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [day#2, label#0, coalesce(t1#574L, 0) AS t1#616L, coalesce(t2#575L, 0) AS t2#617L, coalesce(t3#576L, 0) AS t3#618L, coalesce(t4#577L, 0) AS t4#619L, coalesce(t5#578L, 0) AS t5#620L, coalesce(t6#579L, 0) AS t6#621L]
   +- Sort [label#0 ASC NULLS FIRST], true, 0
      +- Exchange rangepartitioning(label#0 ASC NULLS FIRST, 200), ENSURE_REQUIREMENTS, [plan_id=2744]
         +- Project [day#2, label#0, __pivot_count(1) AS count AS `count(1) AS count`#573[0] AS t1#574L, __pivot_count(1) AS count AS `count(1) AS count`#573[1] AS t2#575L, __pivot_count(1) AS count AS `count(1) AS count`#573[2] AS t3#576L, __pivot_count(1) AS count AS `count(1) AS count`#573[3] AS t4#577L, __pivot_count(1) AS count AS `count(1) AS count`#573[4] AS t5#578L, __pivot_count(1) AS count AS `count(1) AS count`#573[5] AS t6#579L]
            +- HashAggregate(keys=[day#2, label#0], functions=[pivotfirst(label2#493, count(1) AS count#559L, t1, t2, t3, t4, t5, t6, 0, 0)])
               +- Exchange hashpartitioning(day#2, label#0, 200), ENSURE_REQUIREMENTS, [plan_id=2740]
                  +- HashAggregate(keys=[day#2, label#0], functions=[partial_pivotfirst(label2#493, count(1) AS count#559L, t1, t2, t3, t4, t5, t6, 0, 0)])
                     +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[count(1)])
                        +- Exchange hashpartitioning(day#2, label#0, label2#493, 200), ENSURE_REQUIREMENTS, [plan_id=2736]
                           +- HashAggregate(keys=[day#2, label#0, label2#493], functions=[partial_count(1)])
                              +- Project [label#0, day#2, label2#493]
                                 +- SortMergeJoin [id#7], [id#543], Inner, (NOT (label#0 = label2#493) AND NOT (day#2 = day2#497))
                                    :- Sort [id#7 ASC NULLS FIRST], false, 0
                                    :  +- Exchange hashpartitioning(id#7, 200), ENSURE_REQUIREMENTS, [plan_id=2728]
                                    :     +- Generate explode(ids#1), [label#0, day#2], false, [id#7]
                                    :        +- Filter (size(ids#1, true) > 0)
                                    :           +- Scan ExistingRDD[label#0,ids#1,day#2]
                                    +- Sort [id#543 ASC NULLS FIRST], false, 0
                                       +- Exchange hashpartitioning(id#543, 200), ENSURE_REQUIREMENTS, [plan_id=2729]
                                          +- Project [label#540 AS label2#493, day#542 AS day2#497, id#543]
                                             +- Generate explode(ids#541), [label#540, day#542], false, [id#543]
                                                +- Filter (size(ids#541, true) > 0)
                                                   +- Scan ExistingRDD[label#540,ids#541,day#542]

The values are not completely identical to your example, but I assume that werner's comment is correct.

The pivot variant is probably less efficient. If the list of possible labels is available beforehand, you can save some time by passing it as the second argument of `pivot`.
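
For example, a minimal sketch reusing `df` and `df2` from above, assuming the six labels are known up front (the `known_labels` list here is hypothetical):

```python
# Passing the label values to pivot() saves Spark the extra pass it would
# otherwise need to collect the distinct values of label2.
known_labels = ['t1', 't2', 't3', 't4', 't5', 't6']

counts_pivoted = df.join(df2, on='id') \
    .where(df.label != df2.label2) \
    .where(df.day != df2.day2) \
    .groupby(df.day, df.label) \
    .pivot('label2', known_labels) \
    .count() \
    .orderBy('label') \
    .fillna(0)
```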

Answer 2

Score: 0

I'll try and write up a code sample soon, but here's some background you need in order to understand the solution.

  • You need to understand Spark's Vectors.
  • As you seem to be using strings, you need StringIndexer.
  • Maybe use VectorAssembler to convert your dataframe into vectors, or write a function to build sparse/dense vectors.
  • Then you can use the MLlib Correlation function (a rough sketch follows below).

I'll write something up soon.
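
For reference, a rough, untested sketch of that MLlib route, assuming the `sdf` dataframe from the question and using a pivot plus `VectorAssembler` in place of `StringIndexer`. Note that it produces a Pearson correlation matrix between the labels rather than the raw intersection counts asked for, and it does not apply the same-day or diagonal zeroing:

```python
# Hypothetical sketch of the MLlib route: one row per id, one 0/1 indicator
# column per label (does this id occur in that label's list?), then a
# label-by-label Pearson correlation over those indicator vectors.
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.sql import functions as F

labels = [r['label'] for r in sdf.select('label').distinct().orderBy('label').collect()]

indicators = (sdf.select('label', F.explode('id').alias('uid'))
                 .groupBy('uid')
                 .pivot('label', labels)
                 .count()
                 .fillna(0))

# Assemble the per-label indicator columns into a single vector column.
assembler = VectorAssembler(inputCols=labels, outputCol='features')
vectors = assembler.transform(indicators)

# Correlation.corr returns a one-row dataframe holding the matrix.
corr_matrix = Correlation.corr(vectors, 'features').head()[0]
print(corr_matrix.toArray())
```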


huangapple
  • Published on 2023-03-20 22:47:05
  • When reposting, please keep the link to the original: https://go.coder-hub.com/75791761.html