当前位置: 首页> 游戏> 游戏 > NeuralForecast 推理 - 最简单的推理方式

NeuralForecast 推理 - 最简单的推理方式

时间:2025/7/10 18:10:56来源:https://blog.csdn.net/flyfish1986/article/details/139413310 浏览次数:0次

NeuralForecast 推理 - 最简单的推理方式

flyfish

最简单的保存和加载模型代码

import pandas as pd
import numpy as npAirPassengers = np.array([112.0, 118.0, 132.0, 129.0, 121.0, 135.0, 148.0, 148.0, 136.0, 119.0],dtype=np.float32,
)AirPassengersDF = pd.DataFrame({"unique_id": np.ones(len(AirPassengers)),"ds": pd.date_range(start="1949-01-01", periods=len(AirPassengers), freq=pd.offsets.MonthEnd()),"y": AirPassengers,}
)Y_df = AirPassengersDF
Y_df = Y_df.reset_index(drop=True)
Y_df.head()
#Model Trainingfrom neuralforecast.core import NeuralForecast
from neuralforecast.models import NBEATShorizon = 2
models = [NBEATS(input_size=2 * horizon, h=horizon, max_steps=50)]nf = NeuralForecast(models=models, freq='M')
nf.fit(df=Y_df)#Save models
nf.save(path='./checkpoints/test_run/',model_index=None, overwrite=True,save_dataset=True)#Load models
nf2 = NeuralForecast.load(path='./checkpoints/test_run/')
Y_hat_df = nf2.predict().reset_index()
Y_hat_df.head()

简单的预测

import numpy as np
from neuralforecast.core import NeuralForecast
from neuralforecast.models import NBEATS# 新的输入数据
new_data = pd.DataFrame({"unique_id": [1.0, 1.0],"ds": pd.to_datetime(["1949-01-31", "1949-02-28"]),"y": [112.0, 118.0],}
)# 确保数据的顺序和索引是正确的
new_data = new_data.reset_index(drop=True)
print("New input data:")
print(new_data)# 加载已保存的模型
nf2 = NeuralForecast.load(path='./checkpoints/test_run/')# 使用已加载的模型进行预测
Y_hat_df = nf2.predict(df=new_data).reset_index()
print("Prediction results:")
print(Y_hat_df)

.reset_index() 的作用如下:

重置索引:将 DataFrame 的索引重置为默认的整数索引。默认情况下,DataFrame 的索引可以是行标签,但有时候需要将其重置为默认的整数索引。
转换索引为列:如果索引是有意义的数据,可以选择将索引转换为 DataFrame 的一列数据。

.reset_index() 方法有几个常用参数:
drop:布尔值。如果为 True,则会删除索引列而不是将其转换为数据列。
inplace:布尔值。如果为 True,则会在原地修改 DataFrame 而不是返回一个新的 DataFrame。

日期索引被重置为默认的整数索引,并且原来的索引变成了 DataFrame 的一列

示例代码

import pandas as pddata = {'value': [10, 20, 30, 40]
}
index = pd.date_range(start='2022-01-01', periods=4, freq='D')
df = pd.DataFrame(data, index=index)
print("Original DataFrame:")
print(df)df_reset = df.reset_index()
print("\nDataFrame after reset_index:")
print(df_reset)

结果

Original DataFrame:value
2022-01-01     10
2022-01-02     20
2022-01-03     30
2022-01-04     40DataFrame after reset_index:index  value
0 2022-01-01     10
1 2022-01-02     20
2 2022-01-03     30
3 2022-01-04     40
关键字:NeuralForecast 推理 - 最简单的推理方式

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: