How to optimize this NumPy code to make it faster?

Question

This code works as expected and calculates the cosine similarity (stored in the dist columns) between every pair of embeddings, but it takes a long time. I have tens of thousands of records to check and am looking for a way to make it faster.

    import pandas as pd
    import numpy as np
    from numpy import dot
    from numpy.linalg import norm
    import ast

    df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv")
    for k, i in enumerate(df["embeddings"]):
        df["dist" + str(k)] = df.embeddings.apply(
            lambda x: dot(ast.literal_eval(x), ast.literal_eval(i))
            / (norm(ast.literal_eval(x)) * norm(ast.literal_eval(i)))
        )

Answer 1

Score: 2

Optimized way:

Instead of calling ast.literal_eval many times inside the loop, parse the input CSV once while loading it: use the converters option of pd.read_csv to turn the string representation of each array in the 'embeddings' column into a real NumPy array with numpy.fromstring.

    df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv",
                     delimiter=',', usecols=["embeddings"], quotechar='"',
                     converters={'embeddings': lambda s: np.fromstring(s[1:-1], sep=',')})

In my measurements, the numpy.fromstring converter runs 6 times faster than converters={'embeddings': ast.literal_eval}, although the latter is itself faster than your initial approach.
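
To make the converter step concrete, here is a minimal sketch that parses a tiny in-memory CSV both ways and checks that the results agree. The two-row csv_text and the helper name parse_embedding are made up for illustration; the real file is assumed to have the same layout (one quoted, bracketed, comma-separated list per row). It does not reproduce the timing comparison.

```python
import ast
import io

import numpy as np
import pandas as pd

# Hypothetical two-row CSV in the same shape as the question's file:
# one quoted "embeddings" column holding a bracketed, comma-separated list.
csv_text = 'embeddings\n"[0.1, 0.2, 0.3]"\n"[-1.5, 2.0, 4.25]"\n'


def parse_embedding(s):
    """Strip the surrounding brackets and parse the numbers in one call."""
    return np.fromstring(s[1:-1], sep=",")


# Parse with the NumPy-based converter ...
df_fast = pd.read_csv(io.StringIO(csv_text),
                      converters={"embeddings": parse_embedding})

# ... and with ast.literal_eval, which returns plain Python lists.
df_slow = pd.read_csv(io.StringIO(csv_text),
                      converters={"embeddings": ast.literal_eval})

# Both routes yield the same numbers; only the container type differs.
same = all(np.allclose(a, b)
           for a, b in zip(df_fast["embeddings"], df_slow["embeddings"]))
print(type(df_fast["embeddings"][0]).__name__, same)
```

Because the converter already returns arrays, the later dot and norm calls no longer need to re-parse any strings.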

Then, to avoid inserting a new dist<num> column into the dataframe on every iteration, replace the for loop with a single pd.concat call:

    df = pd.concat([df] + [df.embeddings.apply(
        lambda x: np.dot(x, arr)
        / (np.linalg.norm(x) * np.linalg.norm(arr))
    ).rename(f'dist{i}')
        for i, arr in enumerate(df["embeddings"])], axis=1)

The final result (the first 3 records):

    print(df.head(3))

                                              embeddings     dist0     dist1  \
    0  [-0.009409046731889248, 0.01787922903895378, -...  1.000000  0.824427
    1  [-0.0005574452807195485, -0.004265215713530779...  0.824427  1.000000
    2  [-0.024396933615207672, -0.0016798048745840788...  0.757717  0.762072

          dist2     dist3     dist4     dist5     dist6     dist7     dist8  \
    0  0.757717  0.761481  0.858895  0.844244  0.781320  0.830562  0.869494
    1  0.762072  0.768355  0.832918  0.813206  0.779384  0.822365  0.831671
    2  1.000000  0.775206  0.757655  0.756076  0.770092  0.766206  0.765154

          dist9    dist10    dist11    dist12    dist13    dist14    dist15  \
    0  0.824993  0.838671  0.863087  0.809240  0.839480  0.852663  0.812553
    1  0.843859  0.832757  0.846339  0.797901  0.833095  0.836512  0.794878
    2  0.756694  0.765615  0.760532  0.754582  0.759305  0.749540  0.758749

         dist16    dist17    dist18    dist19    dist20    dist21    dist22  \
    0  0.834376  0.851168  0.853374  0.831500  0.786812  0.840630  0.831902
    1  0.831882  0.829072  0.828278  0.773624  0.828781  0.814124  0.852540
    2  0.749419  0.750785  0.759364  0.776753  0.761560  0.770000  0.766988

         dist23    dist24    dist25    dist26    dist27    dist28    dist29  \
    0  0.836812  0.807903  0.822919  0.837386  0.767737  0.815725  0.807334
    1  0.799753  0.827189  0.812533  0.822119  0.788155  0.850503  0.843936
    2  0.761825  0.773761  0.764308  0.757826  0.755465  0.772704  0.766396

         dist30    dist31    dist32    dist33    dist34    dist35    dist36  \
    0  0.832612  0.835909  0.819697  0.853597  0.806614  0.805309  0.822521
    1  0.852967  0.842627  0.802803  0.860669  0.793716  0.787563  0.788239
    2  0.762748  0.763906  0.765716  0.756643  0.766686  0.772603  0.760913

         dist37    dist38    dist39    dist40    dist41    dist42    dist43  \
    0  0.831307  0.834015  0.821262  0.812144  0.853028  0.849498  0.830675
    1  0.845437  0.816868  0.833320  0.808172  0.835293  0.824654  0.856051
    2  0.760276  0.754683  0.765499  0.756421  0.755651  0.763656  0.754828

         dist44    dist45    dist46    dist47    dist48       ...    dist50    dist51  \
    0  0.861366  0.802735  0.789774  0.790563  0.827335       ...  0.820754  0.842522
    1  0.854080  0.827517  0.839423  0.828683  0.812323       ...  0.802451  0.829247
    2  0.760256  0.764869  0.754423  0.757319  0.774664       ...  0.747934  0.793632

         dist52    dist53    dist54    dist55    dist56    dist57    dist58  \
    0  0.827061  0.814656  0.813548  0.834271  0.818362  0.823394  0.828642
    1  0.814514  0.834007  0.784510  0.796033  0.821271  0.821276  0.814710
    2  0.759410  0.747319  0.783079  0.759875  0.742791  0.771096  0.759520

         dist59    dist60    dist61    dist62    dist63    dist64    dist65  \
    0  0.869624  0.840927  0.842052  0.859140  0.859804  0.840041  0.835204
    1  0.835696  0.845089  0.810699  0.853660  0.834497  0.828624  0.803920
    2  0.764160  0.758037  0.773802  0.762592  0.762257  0.751729  0.758366

         dist66    dist67    dist68    dist69    dist70    dist71    dist72  \
    0  0.816945  0.852561  0.815066  0.812858  0.844518  0.851627  0.838417
    1  0.821947  0.812763  0.765442  0.795368  0.848876  0.831772  0.828389
    2  0.759480  0.755786  0.762572  0.756787  0.769603  0.756226  0.750196

         dist73    dist74    dist75    dist76    dist77    dist78    dist79  \
    0  0.839868  0.846972  0.851668  0.860816  0.880957  0.845313  0.849569
    1  0.822491  0.810707  0.812499  0.816586  0.828081  0.826785  0.813240
    2  0.757696  0.746333  0.767805  0.759218  0.770810  0.766181  0.768756

         dist80    dist81    dist82    dist83    dist84    dist85    dist86  \
    0  0.870180  0.862554  0.866397  0.874742  0.899475  0.883464  0.879084
    1  0.848113  0.840173  0.814944  0.826645  0.848822  0.818360  0.809330
    2  0.763216  0.766606  0.762598  0.754603  0.767628  0.757145  0.774004

         dist87    dist88    dist89    dist90    dist91    dist92    dist93  \
    0  0.861392  0.874843  0.855589  0.851598  0.849689  0.854272  0.837288
    1  0.827020  0.839443  0.822301  0.831517  0.815193  0.827057  0.813251
    2  0.771965  0.768978  0.784956  0.768604  0.767573  0.759978  0.772354

         dist94    dist95    dist96    dist97    dist98    dist99
    0  0.841660  0.868675  0.867444  0.836115  0.829863  0.834038
    1  0.799496  0.837142  0.833741  0.791625  0.819392  0.807420
    2  0.767823  0.770422  0.756819  0.762370  0.774629  0.777811

    [3 rows x 101 columns]
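
As a possible further step (not part of the answer above, and not timed here), the per-column apply calls could be replaced by one matrix product: normalize every embedding to unit length, then multiply the matrix by its own transpose. The sketch below assumes the embeddings column already holds equal-length NumPy arrays, as it does after the converter step; the three 2-D vectors are made-up stand-ins for the real data.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for df["embeddings"] after the converter step:
# a column whose cells are 1-D NumPy arrays of equal length.
df = pd.DataFrame({"embeddings": [np.array([1.0, 0.0]),
                                  np.array([0.0, 2.0]),
                                  np.array([3.0, 3.0])]})

# Stack the per-row arrays into one (n_rows, n_dims) matrix.
X = np.vstack(df["embeddings"].tolist())

# Scale every row to unit length; the dot product of two unit vectors
# is their cosine similarity.
X_unit = X / np.linalg.norm(X, axis=1, keepdims=True)

# One matrix product yields all pairwise similarities at once.
sim = X_unit @ X_unit.T

# Attach the dist<num> columns in a single concat, as in the answer.
dist_cols = pd.DataFrame(sim, index=df.index,
                         columns=[f"dist{i}" for i in range(len(df))])
result = pd.concat([df, dist_cols], axis=1)
print(result.drop(columns="embeddings").round(6))
```

Note that the full similarity matrix has n_rows × n_rows entries, so for tens of thousands of records it may be necessary to compute it in row blocks rather than all at once.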

Answer 2

Score: 1

My solution first builds a single 2-D array holding all the embeddings, then computes the pairwise cosine similarities with scipy.spatial.distance.cdist. df3 is the result you expected.

    import pandas as pd
    import numpy as np
    from scipy.spatial import distance
    from numpy.linalg import norm
    import ast
    import time

    start = time.perf_counter()
    df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv")

    # Parse every embedding once into a (num_lines, num_columns) array.
    num_columns = len(ast.literal_eval(df["embeddings"][0]))
    num_lines = len(df["embeddings"])
    data_base_embeddings = np.zeros((num_lines, num_columns))
    columns_list = []
    for k, i in enumerate(df["embeddings"]):
        embeddings_temp = np.array(ast.literal_eval(i))
        data_base_embeddings[k, :] = embeddings_temp
        columns_list.append("dist" + str(k))

    # cdist evaluates the lambda for every pair of rows.
    dist_matrix = distance.cdist(data_base_embeddings, data_base_embeddings,
                                 lambda u, v: np.dot(u, v) / (norm(u) * norm(v)))
    df2 = pd.DataFrame(data=dist_matrix, columns=columns_list)
    df3 = df.join(df2)
    end = time.perf_counter()
    print("Running time: ", str(end - start), "s")

This gives the following running time, including loading the data:

> Running time: 2.0602212000012514 s

Excluding data loading:

> Running time: 0.7129389000001538 s
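
A possible variation, offered as a sketch rather than as what this answer measured: cdist also accepts metric="cosine", which avoids calling a Python lambda for every pair. That metric returns the cosine distance (1 minus the similarity), so it has to be subtracted from 1 to match the values produced by the lambda. The small matrix X below is made-up data.

```python
import numpy as np
from scipy.spatial import distance

# Hypothetical embeddings matrix: one row per record.
X = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [3.0, 3.0]])

# Built-in metric: returns the cosine *distance*, i.e. 1 - similarity.
cos_dist = distance.cdist(X, X, metric="cosine")

# Convert back to the similarity that the lambda in the answer computes.
cos_sim = 1.0 - cos_dist

# Cross-check against the explicit lambda version on this small input.
reference = distance.cdist(
    X, X, lambda u, v: np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))
print(np.allclose(cos_sim, reference))
```

Whether this is faster on the real data would need to be timed; no measurement is claimed here.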

huangapple
  • Posted on February 27, 2023, 17:25:42
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75578679.html