#include "LandslideTrain.h"

// NOTE(review): the second #include in the original had no target. The Qt
// headers below cover the Qt names used in this file — confirm against the
// project's intended include list.
#include <QAbstractButton>
#include <QApplication>
#include <QCheckBox>
#include <QDebug>
#include <QDialog>
#include <QDir>
#include <QFile>
#include <QFileDialog>
#include <QFileInfo>
#include <QIcon>
#include <QIntValidator>
#include <QMessageBox>
#include <QProcess>
#include <QPushButton>
#include <QSettings>
#include <QStringList>
#include <QThread>

namespace {

/// Returns the absolute path of "sld_config.ini" inside
/// "<application dir>/srsplugins/SldModel", or an empty string when either the
/// folder or the file does not exist.
QString existingConfigPath()
{
    QDir pluginsDir(qApp->applicationDirPath());
    if (!pluginsDir.cd("srsplugins\\SldModel"))
        return QString();

    const QString configPath = pluginsDir.absoluteFilePath("sld_config.ini");
    return QFile::exists(configPath) ? configPath : QString();
}

/// Shows a modal, icon-less error message box with the given text.
void showErrorBox(const QString& text)
{
    QMessageBox mess(QMessageBox::NoIcon, QString::fromLocal8Bit("错误"), text);
    mess.setWindowFlags(Qt::Drawer);
    mess.exec();
}

} // namespace

/// Constructs the plugin object. The UI is built lazily in CenterWidget().
LandslideTrain::LandslideTrain()
{
    //ui.setupUi(this);
}

/// Returns the panel name shown for this plugin.
QString LandslideTrain::PannelName()
{
    return QString::fromLocal8Bit("地质模块");
}

/// Returns the category name of this plugin.
QString LandslideTrain::CategoryName()
{
    return QString::fromLocal8Bit("地质模块");
}

/// Returns the English name of this plugin.
QString LandslideTrain::EnglishName()
{
    return QString::fromLocal8Bit("LandslideTrain");
}

/// Returns the Chinese name of this plugin.
QString LandslideTrain::ChineseName()
{
    return QString::fromLocal8Bit("模型构建");
}

/// Returns the descriptive text of this plugin.
QString LandslideTrain::Information()
{
    return QString::fromLocal8Bit("模型构建");
}

/// Returns the resource path of the plugin icon.
QString LandslideTrain::IconPath()
{
    return ":/LandslideTrain/resources/dem_vec.svg";
}

/// Creates (on first call) or re-activates the training dialog and returns it.
///
/// When the dialog already exists it is activated, raised and returned as is.
/// Otherwise a new QDialog is created, the UI is set up on it, all signal
/// connections are made, the style sheet and the path history are loaded, and
/// the dialog is shown.
///
/// @return The dialog widget (stored in myWidget).
QWidget* LandslideTrain::CenterWidget()
{
    //QString gdal_path = qApp->applicationDirPath().toLocal8Bit() + "/share/gdal";
    //QString pro_lib_path = qApp->applicationDirPath().toLocal8Bit() + "/share/proj";
    //qputenv("GDAL_DATA", gdal_path.toLocal8Bit());
    //qputenv("PROJ_LIB", pro_lib_path.toLocal8Bit());

    if (myWidget != nullptr) {
        qDebug() << "already have myWidget";
        // Activate the existing window and bring it to the front.
        myWidget->activateWindow();
        myWidget->raise();
        return myWidget;
    }

    // Keep a QDialog-typed local so QDialog signals can be connected
    // regardless of the declared type of myWidget.
    QDialog* dialog = new QDialog();
    myWidget = dialog;
    qDebug() << "new QDialog()";

    ui.setupUi(myWidget);
    myWidget->setWindowTitle(QString::fromLocal8Bit("模型训练"));
    myWidget->setWindowFlags(Qt::CustomizeWindowHint | Qt::WindowCloseButtonHint);
    myWidget->setWindowIcon(QIcon(":/LandslideTrain/resources/dem_vec.svg"));
    myWidget->setAttribute(Qt::WA_QuitOnClose, false);
    myWidget->setAttribute(Qt::WA_DeleteOnClose);

    // Persist the path history when the dialog finishes, i.e. while its child
    // widgets still exist. Doing this from QObject::destroyed (as before) reads
    // line edits that QWidget's destructor has already deleted.
    connect(dialog, &QDialog::finished, this, [this] {
        const QString configPath = existingConfigPath();
        if (!configPath.isEmpty())
            WriteConfigPaths(configPath);
    });

    // Tear down the worker once the dialog object goes away. This handler must
    // not touch `ui` or call methods on the dialog: it is being destroyed.
    connect(dialog, &QObject::destroyed, this, [this] {
        qDebug() << "----Landslide train window close----";
        if (mWorkThread != nullptr) {
            // Cancel first: the worker is scheduled for deletion when the
            // thread finishes, so it may no longer exist after wait() returns.
            if (mWorkObject != nullptr)
                mWorkObject->on_cancel();
            mWorkThread->requestInterruption();
            mWorkThread->quit();
            mWorkThread->wait();
            // Both objects are released through the QThread::finished ->
            // deleteLater connections made in startWorkThread(); no delete here.
            mWorkThread = nullptr;
            mWorkObject = nullptr;
        }
        myWidget = nullptr;
    });

    connect(ui.pbtInModel, &QPushButton::clicked, this, &LandslideTrain::chooseInModel);
    connect(ui.pbtInDataset, &QPushButton::clicked, this, &LandslideTrain::chooseInDataset);
    connect(ui.pbtInLabel, &QPushButton::clicked, this, &LandslideTrain::chooseInLabel);
    connect(ui.pbtOutResult, &QPushButton::clicked, this, &LandslideTrain::chooseResultPath);
    connect(ui.pushButton_ok, &QPushButton::clicked, this, &LandslideTrain::readAndStart);
    connect(ui.pushButton_cancel, &QPushButton::clicked, this, &LandslideTrain::pbCancel);

    // Parent the validators to the dialog so they are released with it instead
    // of accumulating on the plugin object each time the dialog is re-created.
    ui.lineEpoch1->setValidator(new QIntValidator(0, 999, myWidget));
    ui.lineEpoch2->setValidator(new QIntValidator(0, 999, myWidget));
    ui.lineEpoch1->setText("100");
    ui.lineEpoch2->setText("100");
    ui.pbtInModel->setFocus();

    // The OK button is only enabled while at least one stage is selected.
    connect(ui.checkBoxGenData, &QCheckBox::clicked, this, [this](bool checked) {
        executeGenData = checked;
        ui.pushButton_ok->setEnabled(executeGenData || executeTrain);
    });
    connect(ui.checkBoxTrain, &QCheckBox::clicked, this, [this](bool checked) {
        executeTrain = checked;
        ui.pushButton_ok->setEnabled(executeGenData || executeTrain);
    });

    ui.progressBar->setTextVisible(true);
    ui.progressBar->setRange(0, 100);

    // Apply the bundled style sheet when it can be opened read-only.
    QFile qssFile(":/LandslideTrain/LandslideTrain.qss");
    if (qssFile.open(QFile::ReadOnly)) {
        myWidget->setStyleSheet(QString::fromLatin1(qssFile.readAll()));
        qssFile.close();
    } else {
        qDebug() << "-- no qssFile";
    }

    // Restore the paths used last time, if a config file exists.
    const QString configPath = existingConfigPath();
    if (!configPath.isEmpty())
        ReadConfigHistoryPaths(configPath);

    myWidget->show();
    return myWidget;
}

/// Creates the worker thread and worker object and wires their signals.
///
/// Does nothing when a worker object already exists. Both the thread and the
/// worker are scheduled for deletion when the thread finishes.
void LandslideTrain::startWorkThread()
{
    if (mWorkObject != nullptr)
        return;

    mWorkThread = new QThread();
    mWorkObject = new WorkObject();
    mWorkObject->moveToThread(mWorkThread);

    connect(mWorkThread, &QThread::finished, mWorkThread, &QObject::deleteLater);
    connect(mWorkThread, &QThread::finished, mWorkObject, &QObject::deleteLater);
    // myWidget is the context object, so this connection ends with the dialog.
    connect(mWorkObject, &WorkObject::progress, myWidget, [this](double val) {
        ui.progressBar->SetDoubleFormatValue(QString::fromLocal8Bit("进度"), val);
    });
    connect(mWorkObject, &WorkObject::trainFinished, this, &LandslideTrain::finished);
    connect(this, &LandslideTrain::start, mWorkObject, &WorkObject::runTrainWork);

    mWorkThread->start();
}

/// Loads the previously used paths from the [SldTrain] group of the INI file
/// and puts them into the four path line edits.
///
/// @param strPath Path of the INI file to read.
void LandslideTrain::ReadConfigHistoryPaths(QString strPath)
{
    QSettings configIni(strPath, QSettings::IniFormat);
    // Use the same codec as WriteConfigPaths(); otherwise non-ASCII paths
    // written as UTF-8 are decoded incorrectly when read back.
    configIni.setIniCodec("utf-8");

    configIni.beginGroup("SldTrain");
    ui.lineInModel->setText(configIni.value("RetrainModel").toString());
    ui.lineInDataset->setText(configIni.value("SrcDom").toString());
    ui.lineInLabel->setText(configIni.value("SrcLabel").toString());
    ui.lineOutResult->setText(configIni.value("TrainResult").toString());
    configIni.endGroup();
}

/// Stores the current contents of the four path line edits in the [SldTrain]
/// group of the INI file. Empty fields leave the stored value untouched.
///
/// @param strPath Path of the INI file to write.
void LandslideTrain::WriteConfigPaths(QString strPath)
{
    QSettings configIni(strPath, QSettings::IniFormat);
    configIni.setIniCodec("utf-8");

    configIni.beginGroup("SldTrain");

    // Input model path.
    const QString modelPath = ui.lineInModel->text();
    if (!modelPath.isEmpty())
        configIni.setValue("RetrainModel", modelPath);

    // Input DOM (training data) path.
    const QString datasetPath = ui.lineInDataset->text();
    if (!datasetPath.isEmpty())
        configIni.setValue("SrcDom", datasetPath);

    // Input label path.
    const QString labelPath = ui.lineInLabel->text();
    if (!labelPath.isEmpty())
        configIni.setValue("SrcLabel", labelPath);

    // Output result path.
    const QString resultPath = ui.lineOutResult->text();
    if (!resultPath.isEmpty())
        configIni.setValue("TrainResult", resultPath);

    configIni.endGroup();
}

/// Validates the dialog inputs and, when they are acceptable, starts the
/// worker thread (if needed) and emits start() with the collected parameters.
///
/// Each failed check shows an error box and aborts. When only the training
/// stage is selected, the output folder must already contain matching
/// "Images" and "Labels" folders with the same number of .tif files.
void LandslideTrain::readAndStart()
{
    const QString inDataset = ui.lineInDataset->text();
    const QString inLabel = ui.lineInLabel->text();
    const QString inModel = ui.lineInModel->text();
    const QString outResult = ui.lineOutResult->text();

    if (inDataset.isEmpty() || inLabel.isEmpty() || inModel.isEmpty() || outResult.isEmpty()) {
        showErrorBox(QString::fromLocal8Bit("请检查输入输出路径"));
        return;
    }

    ui.progressBar->SetDoubleFormatValue(QString::fromLocal8Bit("进度"), 0);

    if (!QDir(inDataset).exists()) {
        showErrorBox(QString::fromLocal8Bit("训练数据文件夹不存在"));
        return;
    }
    if (!QDir(inLabel).exists()) {
        showErrorBox(QString::fromLocal8Bit("标签数据文件夹不存在"));
        return;
    }
    if (!QDir(outResult).exists()) {
        showErrorBox(QString::fromLocal8Bit("模型输出文件夹不存在"));
        return;
    }

    if (executeGenData && executeTrain)
        qDebug() << "executeGenData & executeTrain";
    if (executeGenData && !executeTrain)
        qDebug() << "executeGenData, not executeTrain";

    if (!executeGenData && executeTrain) {
        // Training without data generation needs an existing training set.
        const QString imageDirPath = outResult + "/Images";
        const QString labelDirPath = outResult + "/Labels";
        if (!QDir(imageDirPath).exists() || !QDir(labelDirPath).exists()) {
            showErrorBox(QString::fromLocal8Bit("未找到训练集文件夹\n确保Images和Labels在以下路径中: \n") + outResult);
            return;
        }

        // The two folders must be non-empty and hold the same number of files.
        const QStringList imgList = getAllFiles(imageDirPath, "tif");
        const QStringList labelList = getAllFiles(labelDirPath, "tif");
        if (imgList.isEmpty() || labelList.isEmpty()) {
            showErrorBox(QString::fromLocal8Bit("训练集文件夹内为空"));
            return;
        }
        if (imgList.size() != labelList.size()) {
            showErrorBox(QString::fromLocal8Bit("训练集Images、Labels文件夹内文件不匹配"));
            return;
        }
    }

    const QString epoch1 = ui.lineEpoch1->text();
    const QString epoch2 = ui.lineEpoch2->text();
    if (epoch1.isEmpty() || epoch2.isEmpty()) {
        showErrorBox(QString::fromLocal8Bit("请输入正确的训练轮数"));
        return;
    }

    if (mWorkThread == nullptr) {
        qDebug() << "--startThread";
        startWorkThread();
    }
    emit start(inModel, inDataset, inLabel, outResult, executeGenData, executeTrain, epoch1, epoch2);
}

/// Lists the regular files (symlinks excluded) directly inside a folder whose
/// suffix equals fileType, compared case-insensitively.
///
/// @param path     Folder to scan (not recursive).
/// @param fileType File suffix without the dot, e.g. "tif".
/// @return Absolute paths of the matching files; empty when the folder does
///         not exist or nothing matches.
QStringList LandslideTrain::getAllFiles(QString path, QString fileType)
{
    QDir dir(path);
    if (!dir.exists())
        return QStringList();

    dir.setFilter(QDir::Files | QDir::NoSymLinks);

    QStringList files;
    const QFileInfoList entries = dir.entryInfoList();
    for (const QFileInfo& entry : entries) {
        if (QString::compare(entry.suffix(), fileType, Qt::CaseInsensitive) == 0)
            files.append(entry.absoluteFilePath());
    }
    return files;
}

/// Slot run when the worker reports that training has finished: shows the
/// result location in a message box and then closes the dialog.
void LandslideTrain::finished()
{
    // The signal is delivered across threads; the dialog (and with it `ui`)
    // may already be gone by the time it arrives.
    if (myWidget == nullptr)
        return;

    QMessageBox mess(QMessageBox::NoIcon,
                     QString::fromLocal8Bit("运行结束"),
                     QString::fromLocal8Bit("结果文件生成路径\n") + ui.lineOutResult->text(),
                     QMessageBox::Ok, nullptr);
    mess.setWindowFlags(Qt::Drawer);
    if (QAbstractButton* okButton = mess.button(QMessageBox::Ok))
        okButton->setText(QString::fromLocal8Bit("确认"));
    mess.exec();

    pbCancel();
}

/// Closes the dialog. Because the dialog has WA_DeleteOnClose set, closing it
/// schedules its deletion, which in turn runs the teardown connected in
/// CenterWidget(). Deleting it directly here would destroy the dialog from
/// inside a slot triggered by one of its own child buttons.
void LandslideTrain::pbCancel()
{
    //qDebug() << "--pbtCancel";
    if (myWidget != nullptr)
        myWidget->close();
}

/// Lets the user pick the initial model file (*.pth) and stores the choice.
void LandslideTrain::chooseInModel()
{
    const QString modelFile = QFileDialog::getOpenFileName(
        ui.pbtInModel, QString::fromLocal8Bit("选择输入初始训练模型文件"), "", "*.pth");
    if (!modelFile.isEmpty())
        ui.lineInModel->setText(modelFile);
}

/// Lets the user pick the training data folder and stores the choice.
void LandslideTrain::chooseInDataset()
{
    const QString datasetDir = QFileDialog::getExistingDirectory(
        ui.pbtInDataset, QString::fromLocal8Bit("选择输入训练数据路径"), "");
    if (!datasetDir.isEmpty())
        ui.lineInDataset->setText(datasetDir);
}

/// Lets the user pick the label data folder and stores the choice.
void LandslideTrain::chooseInLabel()
{
    const QString labelDir = QFileDialog::getExistingDirectory(
        ui.pbtInLabel, QString::fromLocal8Bit("选择输入标签数据路径"), "");
    if (!labelDir.isEmpty())
        ui.lineInLabel->setText(labelDir);
}

/// Lets the user pick the output folder and stores the choice.
void LandslideTrain::chooseResultPath()
{
    const QString resultDir = QFileDialog::getExistingDirectory(
        ui.pbtOutResult, QString::fromLocal8Bit("选择输出模型文件路径"), "");
    if (!resultDir.isEmpty())
        ui.lineOutResult->setText(resultDir);
}

/// Launches "train_3c_landslide.exe" from "<application dir>/models/envs"
/// with the given parameters and forwards its standard output to on_read().
///
/// @param inModel  Value passed as --retrained_model.
/// @param dataset  Folder passed (with a trailing '/') as --dom_path.
/// @param label    Folder passed (with a trailing '/') as --label_path.
/// @param outModel Folder passed (with a trailing '/') as --save_model.
/// @param gen      Passed as --exe_mid True/False.
/// @param train    Passed as --exe_train True/False.
/// @param epoch1   Value passed as --epoch1.
/// @param epoch2   Value passed as --epoch2.
void WorkObject::runTrainWork(QString inModel, QString dataset, QString label, QString outModel, bool gen, bool train, QString epoch1, QString epoch2)
{
    QDir envDir(qApp->applicationDirPath());
    if (!envDir.cd("models\\envs")) {
        qDebug() << "no folder models\\envs";
        return;
    }

    const QString program = envDir.absoluteFilePath("train_3c_landslide.exe");

    // Pass the arguments as a list rather than one concatenated command line,
    // so that an install directory or user-chosen path containing spaces is
    // not split into several arguments.
    const QStringList arguments{
        QStringLiteral("--dom_path"), dataset + QLatin1Char('/'),
        QStringLiteral("--label_path"), label + QLatin1Char('/'),
        QStringLiteral("--retrained_model"), inModel,
        QStringLiteral("--save_model"), outModel + QLatin1Char('/'),
        QStringLiteral("--exe_mid"), gen ? QStringLiteral("True") : QStringLiteral("False"),
        QStringLiteral("--exe_train"), train ? QStringLiteral("True") : QStringLiteral("False"),
        QStringLiteral("--epoch1"), epoch1,
        QStringLiteral("--epoch2"), epoch2,
    };
    qDebug() << program << arguments;

    QProcess* process = new QProcess(this);
    // Remember the process immediately so on_cancel() can act even before the
    // executable has produced any output.
    mProcess = process;
    connect(process, &QProcess::readyReadStandardOutput, this, &WorkObject::on_read);
    process->start(program, arguments);
}

/// Slot run when the training process has written to standard output.
///
/// Each non-empty output line is parsed as a number. Values greater than zero
/// are emitted through progress(); a value of exactly 100 also releases the
/// process and emits trainFinished(). Anything else is only logged.
void WorkObject::on_read()
{
    QProcess* process = qobject_cast<QProcess*>(sender());
    if (process == nullptr)
        return;
    mProcess = process;

    const QString output = QString::fromLocal8Bit(process->readAllStandardOutput());
    // A single read may carry several lines; parse them one by one.
    const QStringList lines = output.split(QLatin1Char('\n'));
    for (const QString& rawLine : lines) {
        const QString line = rawLine.trimmed();
        if (line.isEmpty())
            continue;

        const double value = line.toDouble();
        if (value > 0) {
            qDebug() << "exe out:" << value;
            emit progress(value);
            // The text "100" converts to exactly 100.0, so == is reliable here.
            if (value == 100.0) {
                mProcess = nullptr;
                // This slot is running from one of the process's own signals,
                // so defer the deletion instead of deleting it here.
                process->deleteLater();
                emit trainFinished();
                return;
            }
        } else {
            qDebug() << "Unresolved exe out:" << line;
        }
    }
}

/// Force-kills every running "train_3c_landslide.exe" via taskkill, provided a
/// training process has been recorded in mProcess.
void WorkObject::on_cancel()
{
    if (mProcess == nullptr) {
        qDebug() << "--mProcess null";
        return;
    }

    QProcess::startDetached(QStringLiteral("taskkill"),
                            QStringList{ QStringLiteral("/f"), QStringLiteral("/im"), QStringLiteral("train_3c_landslide.exe") });
    qDebug() << "--kill Process";
}