How to optimize this numpy code to make it faster?

Question

This code works as expected and calculates the cosine similarity between every pair of embeddings. But it takes a long time. I have tens of thousands of records to check, and I am looking for a way to make it faster.

import pandas as pd
import numpy as np
from numpy import dot
from numpy.linalg import norm

import ast

df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv")

for k, i in enumerate(df["embeddings"]):
    df["dist" + str(k)] = df.embeddings.apply(
        lambda x: dot(ast.literal_eval(x), ast.literal_eval(i))
        / (norm(ast.literal_eval(x)) * norm(ast.literal_eval(i)))
    )

Answer 1

Score: 2

Optimized way:

Instead of calling ast.literal_eval many times in a loop, parse the column once while loading the CSV file: pass the converters option to pd.read_csv so that each string in the 'embeddings' column is converted into a real NumPy array with numpy.fromstring.

df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv",
                 delimiter=',', usecols=["embeddings"], quotechar='"',
                 converters={'embeddings': lambda s: np.fromstring(s[1:-1], sep=',')})

In my measurements, numpy.fromstring runs about 6 times faster here than converters={'embeddings': ast.literal_eval}, although the latter is itself faster than your initial approach.
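To make the parsing step concrete, here is a minimal sketch of the converters idea on a tiny in-memory CSV. The two-row csv_text and the helper name parse_embedding are made up for illustration; the real file is assumed to store each embedding as a bracketed, comma-separated string in the same way.

```python
import ast
import io

import numpy as np
import pandas as pd

# Hypothetical two-row CSV standing in for the real file.
csv_text = 'embeddings\n"[0.1, 0.2, 0.3]"\n"[1.0, -2.0, 0.5]"\n'


def parse_embedding(s):
    # Drop the surrounding brackets, then parse the comma-separated floats.
    return np.fromstring(s[1:-1], sep=",")


# The converter runs once per cell while the file is being read.
df = pd.read_csv(io.StringIO(csv_text), converters={"embeddings": parse_embedding})

# Cross-check one row against what ast.literal_eval would give.
reference = np.array(ast.literal_eval("[1.0, -2.0, 0.5]"))
print(np.allclose(df["embeddings"][1], reference))
```

Each cell of df["embeddings"] is then already a NumPy array, so nothing has to be re-parsed inside the distance loop.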

Then, to avoid inserting a new dist<num> column into the dataframe on every iteration, replace the for loop with a single pd.concat:

df = pd.concat([df] + [df.embeddings.apply(
                        lambda x: np.dot(x, arr)
                                  / (np.linalg.norm(x) * np.linalg.norm(arr))
                        ).rename(f'dist{i}')
                        for i, arr in enumerate(df["embeddings"])], axis=1)

The final result (a fragment of first 3 records):

print(df.head(3))

                                          embeddings     dist0     dist1  \
0  [-0.009409046731889248, 0.01787922903895378, -...  1.000000  0.824427   
1  [-0.0005574452807195485, -0.004265215713530779...  0.824427  1.000000   
2  [-0.024396933615207672, -0.0016798048745840788...  0.757717  0.762072   

      dist2     dist3     dist4     dist5     dist6     dist7     dist8  \
0  0.757717  0.761481  0.858895  0.844244  0.781320  0.830562  0.869494   
1  0.762072  0.768355  0.832918  0.813206  0.779384  0.822365  0.831671   
2  1.000000  0.775206  0.757655  0.756076  0.770092  0.766206  0.765154   

      dist9    dist10    dist11    dist12    dist13    dist14    dist15  \
0  0.824993  0.838671  0.863087  0.809240  0.839480  0.852663  0.812553   
1  0.843859  0.832757  0.846339  0.797901  0.833095  0.836512  0.794878   
2  0.756694  0.765615  0.760532  0.754582  0.759305  0.749540  0.758749   

     dist16    dist17    dist18    dist19    dist20    dist21    dist22  \
0  0.834376  0.851168  0.853374  0.831500  0.786812  0.840630  0.831902   
1  0.831882  0.829072  0.828278  0.773624  0.828781  0.814124  0.852540   
2  0.749419  0.750785  0.759364  0.776753  0.761560  0.770000  0.766988   

     dist23    dist24    dist25    dist26    dist27    dist28    dist29  \
0  0.836812  0.807903  0.822919  0.837386  0.767737  0.815725  0.807334   
1  0.799753  0.827189  0.812533  0.822119  0.788155  0.850503  0.843936   
2  0.761825  0.773761  0.764308  0.757826  0.755465  0.772704  0.766396   

     dist30    dist31    dist32    dist33    dist34    dist35    dist36  \
0  0.832612  0.835909  0.819697  0.853597  0.806614  0.805309  0.822521   
1  0.852967  0.842627  0.802803  0.860669  0.793716  0.787563  0.788239   
2  0.762748  0.763906  0.765716  0.756643  0.766686  0.772603  0.760913   

     dist37    dist38    dist39    dist40    dist41    dist42    dist43  \
0  0.831307  0.834015  0.821262  0.812144  0.853028  0.849498  0.830675   
1  0.845437  0.816868  0.833320  0.808172  0.835293  0.824654  0.856051   
2  0.760276  0.754683  0.765499  0.756421  0.755651  0.763656  0.754828   

     dist44    dist45    dist46    dist47    dist48  ...    dist50    dist51  \
0  0.861366  0.802735  0.789774  0.790563  0.827335  ...  0.820754  0.842522   
1  0.854080  0.827517  0.839423  0.828683  0.812323  ...  0.802451  0.829247   
2  0.760256  0.764869  0.754423  0.757319  0.774664  ...  0.747934  0.793632   

     dist52    dist53    dist54    dist55    dist56    dist57    dist58  \
0  0.827061  0.814656  0.813548  0.834271  0.818362  0.823394  0.828642   
1  0.814514  0.834007  0.784510  0.796033  0.821271  0.821276  0.814710   
2  0.759410  0.747319  0.783079  0.759875  0.742791  0.771096  0.759520   

     dist59    dist60    dist61    dist62    dist63    dist64    dist65  \
0  0.869624  0.840927  0.842052  0.859140  0.859804  0.840041  0.835204   
1  0.835696  0.845089  0.810699  0.853660  0.834497  0.828624  0.803920   
2  0.764160  0.758037  0.773802  0.762592  0.762257  0.751729  0.758366   

     dist66    dist67    dist68    dist69    dist70    dist71    dist72  \
0  0.816945  0.852561  0.815066  0.812858  0.844518  0.851627  0.838417   
1  0.821947  0.812763  0.765442  0.795368  0.848876  0.831772  0.828389   
2  0.759480  0.755786  0.762572  0.756787  0.769603  0.756226  0.750196   

     dist73    dist74    dist75    dist76    dist77    dist78    dist79  \
0  0.839868  0.846972  0.851668  0.860816  0.880957  0.845313  0.849569   
1  0.822491  0.810707  0.812499  0.816586  0.828081  0.826785  0.813240   
2  0.757696  0.746333  0.767805  0.759218  0.770810  0.766181  0.768756   

     dist80    dist81    dist82    dist83    dist84    dist85    dist86  \
0  0.870180  0.862554  0.866397  0.874742  0.899475  0.883464  0.879084   
1  0.848113  0.840173  0.814944  0.826645  0.848822  0.818360  0.809330   
2  0.763216  0.766606  0.762598  0.754603  0.767628  0.757145  0.774004   

     dist87    dist88    dist89    dist90    dist91    dist92    dist93  \
0  0.861392  0.874843  0.855589  0.851598  0.849689  0.854272  0.837288   
1  0.827020  0.839443  0.822301  0.831517  0.815193  0.827057  0.813251   
2  0.771965  0.768978  0.784956  0.768604  0.767573  0.759978  0.772354   

     dist94    dist95    dist96    dist97    dist98    dist99  
0  0.841660  0.868675  0.867444  0.836115  0.829863  0.834038  
1  0.799496  0.837142  0.833741  0.791625  0.819392  0.807420  
2  0.767823  0.770422  0.756819  0.762370  0.774629  0.777811  

[3 rows x 101 columns]
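A further step, which is not part of the answer above, would be to drop the per-row apply entirely and compute all pairs with one matrix product. The sketch below assumes the embeddings all have the same length and that none of them is an all-zero vector; the function name cosine_similarity_matrix and the random test data are hypothetical.

```python
import numpy as np


def cosine_similarity_matrix(X):
    """Return the matrix S with S[i, j] = cos(X[i], X[j])."""
    X = np.asarray(X, dtype=float)
    # Length of every row, kept as a column so it broadcasts over the row.
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    unit = X / norms  # assumes no row is all zeros
    # The dot product of two unit vectors is their cosine similarity.
    return unit @ unit.T


# Hypothetical data: 5 embeddings of dimension 8.
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))
sim = cosine_similarity_matrix(X)
print(sim.shape)
```

With the dataframe loaded as above, X could be built with np.vstack(df["embeddings"].to_numpy()), and the dist<num> columns with pd.DataFrame(sim, columns=[f"dist{i}" for i in range(len(sim))]).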

Answer 2

Score: 1

My solution first collects all the embeddings into a single 2-D array and then computes the pairwise values with scipy.spatial.distance.cdist; df3 is the result you expected.

import pandas as pd
import numpy as np
from scipy.spatial import distance
from numpy.linalg import norm
import ast
import time

start = time.perf_counter()

df = pd.read_csv("https://testme162.s3.amazonaws.com/cosign_dist.csv")

# Pre-allocate one row per record and one column per embedding dimension.
num_columns = len(ast.literal_eval(df["embeddings"][0]))
num_lines = len(df["embeddings"])
data_base_embeddings = np.zeros((num_lines, num_columns))

columns_list = []

# Parse each embedding string once and store it as a row of the array.
for k, i in enumerate(df["embeddings"]):
    embeddings_temp = np.array(ast.literal_eval(i))
    data_base_embeddings[k, :] = embeddings_temp
    columns_list.append("dist" + str(k))

# Cosine similarity for every pair of rows, using a custom metric callable.
dist_matrix = distance.cdist(data_base_embeddings, data_base_embeddings, lambda u, v: np.dot(u, v) / (norm(u) * norm(v)))

df2 = pd.DataFrame(data=dist_matrix, columns=columns_list)

df3 = df.join(df2)

end = time.perf_counter()

print("Running time: ", str(end - start), "s")

This gives the following running time, including loading the data:

> Running time: 2.0602212000012514 s

Excluding data loading:

> Running time: 0.7129389000001538 s
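As a possible variation on this answer, cdist also accepts metric="cosine", which avoids calling a Python lambda for every pair. Note that SciPy's metric is the cosine distance, so the similarity is 1 minus the result. The small random array below is a hypothetical stand-in for data_base_embeddings.

```python
import numpy as np
from scipy.spatial.distance import cdist

# Hypothetical stand-in for data_base_embeddings: 4 embeddings of dimension 6.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 6))

# metric="cosine" returns the cosine *distance*, i.e. 1 - similarity.
similarity = 1.0 - cdist(X, X, metric="cosine")

# Cross-check one entry against the explicit formula used in the answer.
u, v = X[0], X[1]
expected = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
print(np.isclose(similarity[0, 1], expected))
```

The rest of the answer (building df2 from the matrix and joining it to df) would stay the same.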

huangapple
  • Posted on 2023-02-27 17:25:42
  • When reposting, please keep the link to this article: https://go.coder-hub.com/75578679.html