diff --git a/client/platform/desktop/backend/native/common.ts b/client/platform/desktop/backend/native/common.ts index 1493aeecb..ea8bfbc0f 100644 --- a/client/platform/desktop/backend/native/common.ts +++ b/client/platform/desktop/backend/native/common.ts @@ -389,6 +389,7 @@ async function getPipelineList(settings: Settings): Promise { '^.*\\.svm', '^.*\\.lbl', '^.*\\.cfg', + '^.*\\.yaml', ].join('|')); const trainedPipelinePath = npath.join(settings.dataPath, PipelinesFolderName); const trainedExists = await fs.pathExists(trainedPipelinePath); diff --git a/client/platform/desktop/backend/native/viame.ts b/client/platform/desktop/backend/native/viame.ts index 73ea8514f..7683ec768 100644 --- a/client/platform/desktop/backend/native/viame.ts +++ b/client/platform/desktop/backend/native/viame.ts @@ -381,17 +381,30 @@ async function exportTrainedPipeline( throw new Error(isValid); } - const exportPipelinePath = npath.join(settings.viamePath, PipelineRelativeDir, 'convert_to_onnx.pipe'); + const exportPipelinePath = npath.join(settings.viamePath, PipelineRelativeDir, 'convert_model_to_onnx.pipe'); if (!fs.existsSync(npath.join(exportPipelinePath))) { throw new Error("Your VIAME version doesn't support ONNX export. You have to update it to a newer version to be able to export models."); } const modelPipelineDir = npath.parse(pipeline.pipe).dir; - let weightsPath: string; - if (fs.existsSync(npath.join(modelPipelineDir, 'yolo.weights'))) { - weightsPath = npath.join(modelPipelineDir, 'yolo.weights'); - } else { - throw new Error('Your pipeline has no trained weights (yolo.weights is missing)'); + const extensions = ['.weights', '.ckpt', '.pth']; + let weightsPath: string | undefined; + + const files = fs.readdirSync(modelPipelineDir); + + const foundExtension = extensions.find((ext) => + files.some((file) => file.toLowerCase().endsWith(ext)) + ); + + if (foundExtension) { + const fileName = files.find((file) => file.toLowerCase().endsWith(foundExtension)); + if (fileName) { + weightsPath = npath.join(modelPipelineDir, fileName); + } + } + + if (!weightsPath) { + throw new Error(`No weights path (${extensions.join(', ')}) found.`); } const jobWorkDir = await createCustomWorkingDirectory(settings, 'OnnxExport', pipeline.name); diff --git a/server/dive_tasks/tasks.py b/server/dive_tasks/tasks.py index 9b80c9712..cb78d5159 100644 --- a/server/dive_tasks/tasks.py +++ b/server/dive_tasks/tasks.py @@ -308,9 +308,20 @@ def export_trained_pipeline(self: Task, params: ExportTrainedPipelineJob): trained_pipeline_path = utils.make_directory(_working_directory_path / 'trained_pipeline') output_path = utils.make_directory(_working_directory_path / 'output') onnx_path = output_path / output_name - convert_to_onnx_pipeline_path = conf.viame_pipeline_path / "convert_to_onnx.pipe" + convert_to_onnx_pipeline_path = conf.viame_pipeline_path / "convert_model_to_onnx.pipe" gc.downloadFolderRecursive(input_folder_id, str(trained_pipeline_path)) + extensions = ['*.weights', '*.ckpt', '*.pth'] + model_file = None + + for ext in extensions: + found_files = list(trained_pipeline_path.glob(ext)) + if found_files: + model_file = found_files[0] + break + + if not model_file: + raise FileNotFoundError(f"No weights path ({extensions}) found.") # Convert pipeline to ONNX command = [ @@ -318,7 +329,7 @@ def export_trained_pipeline(self: Task, params: ExportTrainedPipelineJob): f"KWIVER_DEFAULT_LOG_LEVEL={shlex.quote(conf.kwiver_log_level)}", "kwiver runner", f"{shlex.quote(str(convert_to_onnx_pipeline_path))}", - f"-s onnx_convert:model_path={shlex.quote(str(trained_pipeline_path / 'yolo.weights'))}", + f"-s onnx_convert:model_path={shlex.quote(str(model_file))}", f"-s onnx_convert:onnx_model_prefix={shlex.quote(str(onnx_path))}" ] diff --git a/server/dive_utils/constants.py b/server/dive_utils/constants.py index f836548dc..06bd4787a 100644 --- a/server/dive_utils/constants.py +++ b/server/dive_utils/constants.py @@ -143,5 +143,5 @@ AddonsListURL = 'https://github.com/VIAME/VIAME/raw/main/cmake/download_viame_addons.csv' -TrainingModelExtensions = (".zip", ".pth", ".pt", ".py", ".weights", ".wt") +TrainingModelExtensions = (".zip", ".pth", ".pt", ".py", ".weights", ".wt", ".ckpt") MISALGINED_MARKER = "VideoMisaligned"