ML: PINN代码实例-TVD格式
本文尝试使用PINN并引入TVD（总变差不增）约束来求解一维瞬态对流（输运）方程 $\partial u/\partial t + c\,\partial u/\partial x = 0$（代码中 $c$ 即 speed）。岳子本身不在高校，没有发SCI的需求；这个代码用于讲课也不太合适，因此开源。据岳子所知，截至2025年2月，本代码是国际上首个基于libtorch的PINN-TVD开源代码。在阅读本文之前，请先阅读相关的前置文章（链接见原文），
否则难以理解代码的含义。相应的代码如下。由于篇幅所限，这里不对代码逐行介绍，建议系统学习OpenFOAM以及libtorch后再来理解。
// PINN + TVD, by dyfluid.com
// du/dt + speed*du/dx = 0
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>

#include <torch/torch.h>
class NN
:
public torch::nn::Module
{
torch::nn::Sequential net_;
public:
NN()
{
net_ = register_module
(
"net",
torch::nn::Sequential
(
torch::nn::Linear(2,20),
torch::nn::Tanh(),
torch::nn::Linear(20,30),
torch::nn::Tanh(),
torch::nn::Linear(30,30),
torch::nn::Tanh(),
torch::nn::Linear(30,30),
torch::nn::Tanh(),
torch::nn::Linear(30,1)
)
);
}
auto forward(torch::Tensor x)
{
return net_->forward(x);
}
};
using namespace std;
int main()
{
int Iter = 10000;
auto crit = torch::nn::MSELoss();
auto model = std::make_shared<NN>();
auto opti = std::make_shared<torch::optim::Adam>(model->parameters());
// Computational domain
double dx = 0.01;
double dt = 0.02;
double length = 1.0;
double endTime = 5.0;
double UleftFixedValue = 0.5;
double speed = 0.1;
auto x = torch::arange(0, length + dx, dx);
auto t = torch::arange(0, endTime + dt, dt);
auto meshAndTime =
torch::stack(torch::meshgrid({x, t})).reshape({2, -1}).transpose(0, 1);
meshAndTime.requires_grad_(true);
cout<< "x's shape is " << x.sizes() << endl;
cout<< "t's shape is " << t.sizes() << endl;
cout<< "meshAndTime's shape is " << meshAndTime.sizes() << endl;
auto left =
torch::stack
(
torch::meshgrid({x.index({0}), t})
).reshape({2, -1}).transpose(0, 1);
auto right =
torch::stack
(
torch::meshgrid({x.index({-1}), t})
).reshape({2, -1}).transpose(0, 1);
auto ic =
torch::stack
(
torch::meshgrid({x, t.index({0})})
).reshape({2, -1}).transpose(0, 1);
auto icbc = torch::cat({right, left, ic});
auto U_right = torch::zeros(right.size(0));
auto U_left = torch::full_like(U_right, UleftFixedValue);
auto U_ic = torch::zeros((int)(length/dx) + 1);
U_ic.index({torch::indexing::Slice(0, (int)(length/dx/2.0))}) =
UleftFixedValue;
auto U_icbc = torch::cat({U_right, U_left, U_ic}).unsqueeze(1);
for (int i = 0; i < Iter; i++)
{
opti->zero_grad();
auto U_icbcPred = model->forward(icbc);
auto loss_data = crit(U_icbcPred, U_icbc);
auto UIter = model->forward(meshAndTime);
auto ytrain = torch::full({UIter.size(0)}, 0.25).unsqueeze(1);
loss_data +=
crit(torch::abs(UIter - ytrain), torch::full_like(ytrain, 0.25));
auto dudxt =
torch::autograd::grad
(
{UIter},
{meshAndTime},
{torch::ones_like(UIter)},
true,
true
)[0];
auto dudx = dudxt.index({torch::indexing::Slice(), 0});
auto dudt = dudxt.index({torch::indexing::Slice(), 1});
auto loss_pde = crit(dudt, -speed*dudx);
auto Un = UIter.reshape({(int)(length/dx) + 1, (int)(endTime/dt) + 1});
auto tvn = torch::full_like(Un.slice(0, 0, 1), 0.0);
auto tvd = torch::zeros({1});
for (int j = 0; j < (length/dx); j++)
{
auto Unfirst = Un.slice(0, j, j + 1);
auto UnSecond = Un.slice(0, j + 1, j + 2);
tvn += torch::abs(Unfirst - UnSecond);
}
for (int j = 0; j < (endTime/dt); j++)
{
auto tvnt1 = tvn.index({torch::indexing::Slice(), j});
auto tvnt2 = tvn.index({torch::indexing::Slice(), j + 1});
tvd += torch::pow(torch::clamp_min(tvnt2 - tvnt1, 0.0), 2.0);
}
auto loss_tvd = crit(tvd, 0*tvd);
auto l = loss_pde + loss_data + 10*loss_tvd;
l.backward();
opti->step();
if (i % 100 == 0)
{
cout << "Epoch: " << i << ", loss:" << l.item<double>() << endl;
auto Ufinal =
UIter.reshape({(int)(length/dx) + 1, (int)(endTime/dt) + 1});
ofstream file("U-PINN-TVD");
for (int i = 0; i < Ufinal.size(0); i++)
{
for (int j = 0; j < Ufinal.size(1); j++)
{
file << Ufinal[i][j].item<float>() << " ";
}
file << "\n";
}
file.close();
}
}
cout<< "Done!" << endl;
return 0;
}
该代码运行后会自动输出结果文件U-PINN-TVD，可以通过下面的脚本绘图：
#!/bin/sh
# Plot the PINN (vanilla) and PINN+TVD profiles at three time levels into
# result.pdf. U-PINN-TVD is written by the C++ program (one line per x point,
# one column per time level); U-PINN-V is assumed to use the same layout.

# Run from this directory. dirname also handles "sh plot", where $0 has no
# slash and ${0%/*} would leave the script name itself.
cd "$(dirname "$0")" || exit 1

# Fail with a clear message instead of an obscure gnuplot error.
for f in U-PINN-V U-PINN-TVD
do
    if [ ! -f "$f" ]
    then
        echo "Missing data file: $f" >&2
        exit 1
    fi
done

# Mesh spacing dx of the solver: row number times DX gives the x coordinate.
# column(k) in the data files is the (k-1)-th time level.
DX=0.01

gnuplot<<EOF
set terminal pdfcairo enhanced color size 13cm,10cm font 'Verdana, 14'
set output "result.pdf"
set title "PINN VS PINN + TVD by libtorch, by dyfluid.com"
set xlabel "X-axis"
set ylabel "Y-axis"
set xrange [0.3:1]
set xtics -1,0.1,1
set key left bottom
plot 'U-PINN-V' using (\$0*${DX}):(column(50)) w p ps 0.5 pt 7 lc rgb 'red' title "PINN Vanilla" ,\
'U-PINN-TVD' using (\$0*${DX}):(column(50)) with lines title "PINN+TVD" lw 5 lc rgb 'red' ,\
'U-PINN-V' using (\$0*${DX}):(column(100)) w p ps 0.5 pt 7 lc rgb 'black' title "PINN Vanilla" ,\
'U-PINN-TVD' using (\$0*${DX}):(column(100)) with lines title "PINN+TVD" lw 5 lc rgb 'black',\
'U-PINN-V' using (\$0*${DX}):(column(150)) w p ps 0.5 pt 7 lc rgb 'blue' title "PINN Vanilla" ,\
'U-PINN-TVD' using (\$0*${DX}):(column(150)) with lines title "PINN+TVD" lw 5 lc rgb 'blue'
EOF
注意，生成图片时还需要原生PINN（PINN Vanilla）的结果文件U-PINN-V，
可点击下载。相应的结果如下：
