Github项目推荐 | 基于PyTorch以用户为中心的可微概率推理包Brancher
A user-centered Python package for differentiable probabilistic inference.
Brancher是一个以用户为中心的Python包,用于可区分的概率推理。
Site:https://brancher.org/
Github项目地址:
Brancher允许使用随机变分推理来设计和训练可微分贝叶斯模型。 Brancher基于深度学习框架PyTorch。
Brancher的特点:
灵活:易于扩展的建模框架,GPU加速的PyTorch后端
集成:易于使用的现代工具箱,支持Pandas和Seaborn
直观:易于学习具有类似数学语法的符号界面
入门教程
通过以下教程在Google Colab中学习Brancher(更多内容即将推出!)
安装
安装PyTorch后,可以从PyPI安装Brancher:
或者直接克隆github项目:https://github.com/AI-DI/Brancher
建立概率模型
概率模型是象征性定义的。 随机变量可以创建如下:
a = NormalVariable(loc = 0., scale = 1., name = 'a')
b = NormalVariable(loc = 0., scale = 1., name = 'b')
可以使用算术和数学函数将随机变量链接在一起:
c = NormalVariable(loc = a**2 + BF.sin(b),
scale = BF.exp(b),
name = 'a')
通过这种方式,可以创建任意复杂的概率模型。 也可以使用PyTorch的所有深度学习工具来定义具有深度神经网络的概率模型。
示例:自回归建模
概率模型
概率模型是象征性地定义的:
T = 20driving_noise = 1.
measure_noise = 0.3x0 = NormalVariable(0., driving_noise, 'x0')
y0 = NormalVariable(x0, measure_noise, 'x0')
b = LogitNormalVariable(0.5, 1., 'b')
x = [x0]
y = [y0]
x_names = ["x0"]
y_names = ["y0"]for t in range(1,T):
x_names.append("x{}".format(t))
y_names.append("y{}".format(t))
x.append(NormalVariable(b*x[t-1], driving_noise, x_names[t]))
y.append(NormalVariable(x[t], measure_noise, y_names[t]))AR_model = ProbabilisticModel(x + y)
观察数据
一旦定义了概率模型,我们就可以决定观察哪个变量:
[yt.observe(data[yt][:, 0, :]) for yt in y]
自回归变分分布
变分分布可以是任意结构:
Qb = LogitNormalVariable(0.5, 0.5, "b", learnable=True)
logit_b_post = DeterministicVariable(0., 'logit_b_post', learnable=True)
Qx = [NormalVariable(0., 1., 'x0', learnable=True)]
Qx_mean = [DeterministicVariable(0., 'x0_mean', learnable=True)]for t in range(1, T):
Qx_mean.append(DeterministicVariable(0., x_names[t] + "_mean", learnable=True))
Qx.append(NormalVariable(BF.sigmoid(logit_b_post)*Qx[t-1] + Qx_mean[t], 1., x_names[t], learnable=True))
variational_posterior = ProbabilisticModel([Qb] + Qx)
model.set_posterior_model(variational_posterior)
推理
现在模型被激活了,我们可以使用随机梯度下降来执行近似推断:
number_iterations=500,
number_samples=300,
optimizer="SGD",
lr=0.001)
点这里查看本篇更多相关内容
最新评论
推荐文章
作者最新文章
你可能感兴趣的文章
Copyright Disclaimer: The copyright of contents (including texts, images, videos and audios) posted above belong to the User who shared or the third-party website which the User shared from. If you found your copyright have been infringed, please send a DMCA takedown notice to [email protected]. For more detail of the source, please click on the button "Read Original Post" below. For other communications, please send to [email protected].
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。
版权声明:以上内容为用户推荐收藏至CareerEngine平台,其内容(含文字、图片、视频、音频等)及知识版权均属用户或用户转发自的第三方网站,如涉嫌侵权,请通知[email protected]进行信息删除。如需查看信息来源,请点击“查看原文”。如需洽谈其它事宜,请联系[email protected]。