ML: PINN代码实例-TVD格式

本文尝试使用PINN并引入TVD(总变差不增)约束来求解一维瞬态对流(传输)方程。岳子本身不在高校,没有发SCI论文的需求;这个代码用于讲课也不太合适,因此将其开源。据岳子所知,截至2025年2月,本代码是国际上首个基于libtorch的PINN-TVD开源代码。在阅读本文之前,请先阅读:

否则可能无法理解代码的含义。相应的代码如下。由于篇幅较长,这里不对代码逐行介绍,建议先系统学习OpenFOAM以及libtorch后再来理解。

// PINN + TVD, by dyfluid.com
// du/dt + speed*du/dx = 0 

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>

#include <torch/torch.h>

class NN
:
    public torch::nn::Module 
{
    torch::nn::Sequential net_;

public:

    NN()
    {
        net_ = register_module
        (
            "net", 
            torch::nn::Sequential
            (
                torch::nn::Linear(2,20),
                torch::nn::Tanh(),
                torch::nn::Linear(20,30),
                torch::nn::Tanh(),
                torch::nn::Linear(30,30),
                torch::nn::Tanh(),
                torch::nn::Linear(30,30),
                torch::nn::Tanh(),
                torch::nn::Linear(30,1)
            )
        );
    }

    auto forward(torch::Tensor x)
    {
        return net_->forward(x);
    }
};

using namespace std;

int main()
{
    int Iter = 10000;
    auto crit = torch::nn::MSELoss();
    auto model = std::make_shared<NN>();
    auto opti = std::make_shared<torch::optim::Adam>(model->parameters());
   
    // Computational domain
    double dx = 0.01;
    double dt = 0.02;
    double length = 1.0;
    double endTime = 5.0;
    double UleftFixedValue = 0.5;
    double speed = 0.1;

    auto x = torch::arange(0, length + dx, dx);
    auto t = torch::arange(0, endTime + dt, dt);
    auto meshAndTime = 
        torch::stack(torch::meshgrid({x, t})).reshape({2, -1}).transpose(0, 1);
    meshAndTime.requires_grad_(true);
    cout<< "x's shape is " << x.sizes() << endl;
    cout<< "t's shape is " << t.sizes() << endl;
    cout<< "meshAndTime's shape is " << meshAndTime.sizes() << endl;

    auto left = 
        torch::stack
        (
            torch::meshgrid({x.index({0}), t})
        ).reshape({2, -1}).transpose(0, 1);

    auto right = 
        torch::stack
        (
            torch::meshgrid({x.index({-1}), t})
        ).reshape({2, -1}).transpose(0, 1);

    auto ic = 
        torch::stack
        (
            torch::meshgrid({x, t.index({0})})
        ).reshape({2, -1}).transpose(0, 1);

    auto icbc = torch::cat({right, left, ic});

    auto U_right = torch::zeros(right.size(0));
    auto U_left = torch::full_like(U_right, UleftFixedValue);
    auto U_ic = torch::zeros((int)(length/dx) + 1);
    U_ic.index({torch::indexing::Slice(0, (int)(length/dx/2.0))}) = 
        UleftFixedValue;
    auto U_icbc = torch::cat({U_right, U_left, U_ic}).unsqueeze(1);

    for (int i = 0; i < Iter; i++) 
    {
        opti->zero_grad();

        auto U_icbcPred = model->forward(icbc);
        auto loss_data = crit(U_icbcPred, U_icbc);
        auto UIter = model->forward(meshAndTime);
        
        auto ytrain = torch::full({UIter.size(0)}, 0.25).unsqueeze(1);
        loss_data += 
            crit(torch::abs(UIter - ytrain), torch::full_like(ytrain, 0.25));

        auto dudxt = 
            torch::autograd::grad
            (
                {UIter},
                {meshAndTime},
                {torch::ones_like(UIter)},
                true,
                true
            )[0];

        auto dudx = dudxt.index({torch::indexing::Slice(), 0});
        auto dudt = dudxt.index({torch::indexing::Slice(), 1});

        auto loss_pde = crit(dudt, -speed*dudx);

        auto Un = UIter.reshape({(int)(length/dx) + 1, (int)(endTime/dt) + 1});
        auto tvn = torch::full_like(Un.slice(0, 0, 1), 0.0);
        auto tvd = torch::zeros({1});
        for (int j = 0; j < (length/dx); j++)
        {
            auto Unfirst = Un.slice(0, j, j + 1);
            auto UnSecond = Un.slice(0, j + 1, j + 2);
            tvn += torch::abs(Unfirst - UnSecond);
        }
        for (int j = 0; j < (endTime/dt); j++)
        {
            auto tvnt1 = tvn.index({torch::indexing::Slice(), j});
            auto tvnt2 = tvn.index({torch::indexing::Slice(), j + 1});
            tvd += torch::pow(torch::clamp_min(tvnt2 - tvnt1, 0.0), 2.0);
        }
        auto loss_tvd = crit(tvd, 0*tvd);

        auto l = loss_pde + loss_data + 10*loss_tvd;

        l.backward();
        opti->step();

        if (i % 100 == 0) 
        {
            cout << "Epoch: " << i << ", loss:" << l.item<double>() << endl;

            auto Ufinal = 
                UIter.reshape({(int)(length/dx) + 1, (int)(endTime/dt) + 1});
            ofstream file("U-PINN-TVD");
            for (int i = 0; i < Ufinal.size(0); i++) 
            {
                for (int j = 0; j < Ufinal.size(1); j++) 
                {
                    file << Ufinal[i][j].item<float>() << " ";
                }
                file << "\n";
            }
            file.close();
        }
    }

    cout<< "Done!" << endl;
    return 0;
}

该代码运行时会自动将结果写入文件U-PINN-TVD。可以通过下面的脚本出图:

#!/bin/sh
# Plot snapshots of the vanilla PINN and the PINN+TVD solutions into result.pdf.
# Expects two data files next to this script:
#   U-PINN-V   - vanilla PINN result, produced by a separate run
#   U-PINN-TVD - PINN+TVD result
# Each file has one row per x node and one space-separated column per time level.

# Run from this directory. dirname also handles invocation without a path
# component (e.g. "sh plot.sh"), where ${0%/*} would expand to the script name.
cd "$(dirname "$0")" || exit 1

gnuplot<<EOF
set terminal pdfcairo enhanced color size 13cm,10cm font 'Verdana, 14'
set output "result.pdf"

# Must match dx and dt used by the solver that wrote the data files.
dx = 0.01
dt = 0.02

set title "PINN VS PINN + TVD by libtorch, by dyfluid.com"
set xlabel "X-axis"
set ylabel "Y-axis"

set xrange [0.3:1]
set xtics -1,0.1,1
set key left bottom

# x = row index (pseudo-column 0) times dx; data column k is time level k-1,
# i.e. t = (k-1)*dt, which is shown in each legend entry.
plot 'U-PINN-V'   using (\$0*dx):(column(50))  w p ps 0.5 pt 7 lc rgb 'red'   title sprintf("PINN Vanilla, t=%.2f", 49*dt), \
     'U-PINN-TVD' using (\$0*dx):(column(50))  with lines lw 5 lc rgb 'red'   title sprintf("PINN+TVD, t=%.2f", 49*dt), \
     'U-PINN-V'   using (\$0*dx):(column(100)) w p ps 0.5 pt 7 lc rgb 'black' title sprintf("PINN Vanilla, t=%.2f", 99*dt), \
     'U-PINN-TVD' using (\$0*dx):(column(100)) with lines lw 5 lc rgb 'black' title sprintf("PINN+TVD, t=%.2f", 99*dt), \
     'U-PINN-V'   using (\$0*dx):(column(150)) w p ps 0.5 pt 7 lc rgb 'blue'  title sprintf("PINN Vanilla, t=%.2f", 149*dt), \
     'U-PINN-TVD' using (\$0*dx):(column(150)) with lines lw 5 lc rgb 'blue'  title sprintf("PINN+TVD, t=%.2f", 149*dt)

EOF

注意,生成图片时还需要原生PINN(PINN Vanilla)的结果数据文件U-PINN-V,可点击下载。相应的结果如下:

_images/pinntvd.png