Improving a multi-Lorentzian fit in Python
Question
I'm currently trying to fit 10 Lorentzians to a spectrum with curve_fit from scipy. I defined a single function that is the sum of 10 Lorentzians plus a background offset. My code works, but I don't think this is the best way to do it: it isn't pretty, and it takes about 4 minutes to compute.
I'm looking for a better way to do multiple-peak fitting, because I want to improve this kind of work.
def lorentzian(x, amp1, cen1, wid1, amp2, cen2, wid2, amp3, cen3, wid3,\
               amp4, cen4, wid4, amp5, cen5, wid5, amp6, cen6, wid6,\
               amp7, cen7, wid7, amp8, cen8, wid8, amp9, cen9, wid9,\
               amp10, cen10, wid10, a):
    return (amp1*wid1**2/((x-cen1)**2+wid1**2)) +\
           (amp2*wid2**2/((x-cen2)**2+wid2**2)) +\
           (amp3*wid3**2/((x-cen3)**2+wid3**2)) +\
           (amp4*wid4**2/((x-cen4)**2+wid4**2)) +\
           (amp5*wid5**2/((x-cen5)**2+wid5**2)) +\
           (amp6*wid6**2/((x-cen6)**2+wid6**2)) +\
           (amp7*wid7**2/((x-cen7)**2+wid7**2)) +\
           (amp8*wid8**2/((x-cen8)**2+wid8**2)) +\
           (amp9*wid9**2/((x-cen9)**2+wid9**2)) +\
           (amp10*wid10**2/((x-cen10)**2+wid10**2)) + a

p0 = np.concatenate((_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, [a]))
popt, pcov = curve_fit(lorentzian, freq, filtre, p0=p0)
fit = lorentzian(freq, *popt)
That is my main function: I define the p0 values to help curve_fit, and then I plot the result.
I think there are better ways to define my function, and I would really like to hear some new ideas!
EDIT: Here is a link to my data and my entire .py file.
https://github.com/MaloBriend/ucl/tree/1039abe797e0031a876093c20513c02c841a4fae
Answer 1
Score: 1
Do not copy-and-paste an expression ten times; use vectorization.
This should not take four minutes! It should (and does) take less than a second when properly implemented.
You should plot on a semilog scale rather than a linear one.
The following is a semi-realistic demonstration of fitting where we do a first-pass peak finding with knowledge of the noise floor.
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal
from numpy.random import default_rng
from scipy.optimize import curve_fit

F_RESOLUTION = 1e6


def synthesize() -> tuple[np.ndarray, np.ndarray]:
    rand = default_rng(seed=0)
    f = np.arange(start=-1.7e9, stop=1.8e9, step=F_RESOLUTION)
    width = 5e6 + rand.uniform(low=-1e6, high=1e6)
    amp = np.array(( 0.005, 0.005, 0.26, 0.02, 0.02, 0.001, 0.02, 0.14, 0.02, 0.04))
    centre = np.array((-1.6e9, -1.3e9, -1e9, -0.5e9, -0.2e9, 0.3e9, 0.6e9, 0.9e9, 1.4e9, 1.7e9))
    pure = (
        amp*width**2 / ((f[:, np.newaxis] - centre)**2 + width**2)
    ).sum(axis=1)
    noisy = np.clip(
        pure + rand.uniform(low=-2e-4, high=2e-4, size=f.size),
        a_min=0, a_max=None,
    )
    return f, noisy


def get_peaks(f: np.ndarray, spectrum: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    peak_idx, _ = scipy.signal.find_peaks(
        x=spectrum,
        height=5e-4,
        distance=100e6/F_RESOLUTION,
    )
    peak_order = np.argsort(spectrum[peak_idx])
    peak_idx = peak_idx[peak_order][-10:]
    peak_freq = peak_idx*F_RESOLUTION + f.min()
    peak_amp = spectrum[peak_idx]
    return peak_freq, peak_amp


def lorentz(f: np.ndarray, *args: float) -> np.ndarray:
    amp, centre, width = np.array(args).reshape((3, -1))
    spectrum = (
        amp*width**2 / ((f[:, np.newaxis] - centre)**2 + width**2)
    )
    return spectrum.sum(axis=1)


def make_guess(amp0: np.ndarray, centre0: np.ndarray) -> np.ndarray:
    width0 = np.full_like(centre0, 4e6)
    return np.concatenate((amp0, centre0, width0))


def fit(f: np.ndarray, spectrum: np.ndarray, guess: np.ndarray) -> np.ndarray:
    n = guess.size//3
    params, _ = curve_fit(
        f=lorentz,
        xdata=f,
        ydata=spectrum,
        p0=guess,
        bounds=(
            ((0,)*n + (f.min(),)*n + (1e3,)*n),
            ((1,)*n + (f.max(),)*n + (1e8,)*n),
        ),
    )
    return params.reshape((3, -1))


def plot(f: np.ndarray, spectrum: np.ndarray, guess: np.ndarray, params: np.ndarray) -> plt.Figure:
    ax: plt.Axes
    fig, ax = plt.subplots()
    ax.scatter(f, spectrum, c='blue', s=1, label='Experiment')
    ax.semilogy(f, lorentz(f, guess), c='green', label='Guess')
    ax.semilogy(f, lorentz(f, params), c='orange', label='Fit')
    ax.set_xlabel('Frequency')
    ax.set_ylabel('Cavity transmission (V)')
    ax.set_ybound(lower=1e-5)
    ax.legend()
    return fig


def main() -> None:
    f, spectrum = synthesize()
    peak_freq, peak_amp = get_peaks(f, spectrum)
    guess = make_guess(peak_amp, peak_freq)
    params = fit(f, spectrum, guess)
    print('Parameter table (amplitude, centre, width):')
    print(params.T)
    plot(f, spectrum, guess, params)
    plt.show()


if __name__ == '__main__':
    main()
Parameter table (amplitude, centre, width):
[[ 9.91063242e-04 2.99997567e+08 5.11840090e+06]
[ 5.02728448e-03 -1.60009375e+09 5.38846566e+06]
[ 4.88591262e-03 -1.29998945e+09 5.52068207e+06]
[ 1.99150343e-02 1.40001165e+09 5.29239119e+06]
[ 1.99581243e-02 -4.99993715e+08 5.27934769e+06]
[ 1.98832569e-02 -2.00027858e+08 5.31880182e+06]
[ 1.99795151e-02 5.99996444e+08 5.28726278e+06]
[ 4.00437270e-02 1.70000304e+09 5.26566084e+06]
[ 1.40041881e-01 9.00000992e+08 5.27369662e+06]
[ 2.59989040e-01 -1.00000048e+09 5.27544352e+06]]
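The model in the question also carries a constant background offset a, which the vectorized lorentz function above leaves out. As a minimal sketch of how the same broadcasting approach might be extended, the function below treats the last parameter as the offset. The name lorentz_offset, the three-peak test data and the starting values are all made up for illustration, not taken from the question's data set.

```python
import numpy as np
from scipy.optimize import curve_fit


def lorentz_offset(f, *args):
    """Sum of n Lorentzians plus a constant offset (hypothetical helper).

    args is laid out as n amplitudes, n centres, n widths, then one offset.
    """
    *peak_args, offset = args
    amp, centre, width = np.array(peak_args).reshape((3, -1))
    # Broadcasting: f[:, np.newaxis] has shape (len(f), 1), centre has shape (n,),
    # so the quotient has shape (len(f), n); summing over axis 1 adds the peaks.
    peaks = amp * width**2 / ((f[:, np.newaxis] - centre)**2 + width**2)
    return peaks.sum(axis=1) + offset


# Made-up test data: three peaks on a 0.01 background, with a little noise.
rng = np.random.default_rng(0)
f = np.linspace(-10.0, 10.0, 2001)
true = np.array([1.0, 0.5, 0.8,    # amplitudes
                 -5.0, 0.0, 4.0,   # centres
                 0.3, 0.5, 0.4,    # widths
                 0.01])            # offset
y = lorentz_offset(f, *true) + rng.normal(scale=1e-3, size=f.size)

# Deliberately imperfect starting values, as a peak finder might supply.
guess = np.array([0.9, 0.6, 0.7, -4.9, 0.1, 4.1, 0.4, 0.4, 0.5, 0.0])
popt, pcov = curve_fit(lorentz_offset, f, y, p0=guess)
perr = np.sqrt(np.diag(pcov))  # one-sigma uncertainty for each parameter
```

Because the function takes *args, curve_fit cannot infer the number of parameters from its signature, so p0 has to be supplied. The diagonal of pcov gives a rough uncertainty for each fitted value, which the fit function above discards.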
Answer 2
Score: 0
If all you care about is getting a "good" fit, then you can use an interpolator to fit the data. I used scipy.interpolate.splrep with scipy.interpolate.splev since it allows for some smoothing (using the s keyword argument). I included a small amount of smoothing since the data still has lots of little bumps.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import splrep, splev
from scipy.signal import savgol_filter

plt.close("all")

file = "CSV1605FSR.csv"


def readcsv(file, header=11):
    data = pd.read_csv(file, header=header)
    time = np.array(data["Second"])
    ramp = np.array(data["Volt"])
    mode = np.array(data["Volt.1"])
    filtre = savgol_filter(mode, 301, 2)
    return time, ramp, mode, filtre


time, ramp, mode, filtre = readcsv(file)
tck = splrep(time, filtre, s=0.01)
interpolated = splev(time, tck)

fig, ax = plt.subplots(figsize=(13,8))
ax.plot(time, filtre)
ax.plot(time, interpolated)
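To see what the s argument does, here is a small sketch on synthetic data (a single made-up Lorentzian-shaped peak with Gaussian noise, not the CSV file above). With s=0 the spline passes through every sample; a larger s lets the sum of squared residuals grow to roughly s, which generally needs fewer knots. The choice s = (number of points) × (noise variance) is only a common rule of thumb and assumes the noise level is known.

```python
import numpy as np
from scipy.interpolate import splrep, splev

rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 500)
clean = 0.2 / (1.0 + ((x - 0.5) / 0.02)**2)  # one Lorentzian-shaped peak
sigma = 0.005                                 # assumed noise level
noisy = clean + rng.normal(scale=sigma, size=x.size)

# s=0 forces the cubic spline through every sample (pure interpolation).
tck_interp = splrep(x, noisy, s=0)

# A positive s allows a total squared residual of about s.
s = x.size * sigma**2
tck_smooth = splrep(x, noisy, s=s)

resid_interp = np.sum((splev(x, tck_interp) - noisy)**2)
resid_smooth = np.sum((splev(x, tck_smooth) - noisy)**2)

# tck[0] is the knot vector; the smoothing spline should use fewer knots.
n_knots_interp = len(tck_interp[0])
n_knots_smooth = len(tck_smooth[0])
```

Note that a spline only reproduces the shape of the curve. Unlike a Lorentzian fit, it yields no amplitudes, centres or widths for the individual peaks.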
If you just want the locations of the peaks, then you can use scipy.signal.find_peaks. You need to include some parameters so it doesn't pick up every little bump. Using height=0.001 says that the peaks should be at least 0.001 in height (measured from 0), width=300 says that each peak should be at least 300 points wide (by default the width is measured at half the peak's prominence, not at its base), and distance=5000 says that the peaks should be at least 5000 points apart. (I believe using just width or distance is sufficient for this problem.)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter, find_peaks

plt.close("all")

file = "CSV1605FSR.csv"


def readcsv(file, header=11):
    data = pd.read_csv(file, header=header)
    time = np.array(data["Second"])
    ramp = np.array(data["Volt"])
    mode = np.array(data["Volt.1"])
    filtre = savgol_filter(mode, 301, 2)
    return time, ramp, mode, filtre


time, ramp, mode, filtre = readcsv(file)
peaks = find_peaks(filtre, height=0.001, width=300, distance=5000)

fig, ax = plt.subplots(figsize=(13,8))
ax.plot(time, filtre)
ax.plot(time[peaks[0]], filtre[peaks[0]], "x")
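find_peaks returns a tuple of (indices, properties), which is why the code above indexes peaks[0]. The sketch below illustrates this on made-up data: three Lorentzian peaks with a small sinusoidal ripple standing in for the "little bumps". The centres, amplitudes, ripple and thresholds are invented for the example and are not tuned to the CSV file.

```python
import numpy as np
from scipy.signal import find_peaks

x = np.arange(3000)
centres = np.array([500, 1500, 2400])
amps = np.array([0.05, 0.2, 0.01])
width = 40.0

# Three Lorentzian peaks plus a small ripple that should not count as peaks.
y = (amps * width**2 / ((x[:, np.newaxis] - centres)**2 + width**2)).sum(axis=1)
y = y + 2e-4 * np.sin(x / 3.0)

# height drops low maxima; distance keeps only the tallest maximum in each
# neighbourhood of 300 samples.
idx, props = find_peaks(y, height=0.005, distance=300)
heights = props["peak_heights"]  # present because height= was given

# For comparison: without distance, ripple maxima on the flanks can pass too.
idx_all, _ = find_peaks(y, height=0.005)
```

The arrays idx and heights could then serve as starting centres and amplitudes for a curve_fit call, in the same way Answer 1 builds its initial guess.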