diff --git a/cv_utils.cpp b/cv_utils.cpp index 11b0191d..bda3349f 100644 --- a/cv_utils.cpp +++ b/cv_utils.cpp @@ -13,6 +13,15 @@ cv::Mat imreadRGB(const std::string &filename){ return cImg; } +cv::Mat imreadMask(const std::string &filename){ + cv::Mat mask = cv::imread(filename, cv::IMREAD_GRAYSCALE); + if (mask.empty()){ + std::cerr << "Cannot read mask " << filename << std::endl; + exit(1); + } + return mask; +} + void imwriteRGB(const std::string &filename, const cv::Mat &image){ cv::Mat rgb; cv::cvtColor(image, rgb, cv::COLOR_RGB2BGR); @@ -48,3 +57,22 @@ torch::Tensor imageToTensor(const cv::Mat &image){ return (img.toType(torch::kFloat32) / 255.0f); } +torch::Tensor maskToTensor(const cv::Mat &mask){ + torch::Tensor m = torch::from_blob(mask.data, { mask.rows, mask.cols, 1 }, torch::kU8); + + // Binary mask: threshold at 128, output 0.0 or 1.0 + return (m.toType(torch::kFloat32) / 255.0f).ge(0.5f).toType(torch::kFloat32); +} + +cv::Mat tensorToMask(const torch::Tensor &t){ + int h = t.sizes()[0]; + int w = t.sizes()[1]; + + cv::Mat mask(h, w, CV_8UC1); + torch::Tensor scaledTensor = (t.squeeze() * 255.0).toType(torch::kU8); + uint8_t* dataPtr = static_cast(scaledTensor.data_ptr()); + std::copy(dataPtr, dataPtr + (w * h), mask.data); + + return mask; +} + diff --git a/cv_utils.hpp b/cv_utils.hpp index 87c7c2ad..c107ae6d 100644 --- a/cv_utils.hpp +++ b/cv_utils.hpp @@ -7,10 +7,13 @@ #include cv::Mat imreadRGB(const std::string &filename); +cv::Mat imreadMask(const std::string &filename); void imwriteRGB(const std::string &filename, const cv::Mat &image); cv::Mat floatNxNtensorToMat(const torch::Tensor &t); torch::Tensor floatNxNMatToTensor(const cv::Mat &m); cv::Mat tensorToImage(const torch::Tensor &t); torch::Tensor imageToTensor(const cv::Mat &image); +torch::Tensor maskToTensor(const cv::Mat &mask); +cv::Mat tensorToMask(const torch::Tensor &t); #endif \ No newline at end of file diff --git a/input_data.cpp b/input_data.cpp index 53daebd9..7af56ef1 100644 
--- a/input_data.cpp +++ b/input_data.cpp @@ -94,6 +94,21 @@ void Camera::loadImage(float downscaleFactor){ fy = K[1][1].item(); cx = K[0][2].item(); cy = K[1][2].item(); + + // Load mask if path is set + if (!maskPath.empty()){ + + std::cout << "Loading mask " << maskPath << std::endl; + + cv::Mat cMask = imreadMask(maskPath); + + // Resize mask to match image dimensions + if (cMask.rows != height || cMask.cols != width){ + cv::resize(cMask, cMask, cv::Size(width, height), 0.0, 0.0, cv::INTER_NEAREST); + } + + mask = maskToTensor(cMask); + } } torch::Tensor Camera::getImage(int downscaleFactor){ @@ -116,6 +131,27 @@ torch::Tensor Camera::getImage(int downscaleFactor){ } } +torch::Tensor Camera::getMask(int downscaleFactor){ + if (!mask.numel()) return torch::Tensor(); // No mask available + + if (downscaleFactor <= 1) return mask; + + if (maskPyramids.find(downscaleFactor) != maskPyramids.end()){ + return maskPyramids[downscaleFactor]; + } + + // Rescale using nearest neighbor (preserve binary values) + cv::Mat cMask = tensorToMask(mask); + cv::resize(cMask, cMask, cv::Size(cMask.cols / downscaleFactor, cMask.rows / downscaleFactor), 0.0, 0.0, cv::INTER_NEAREST); + torch::Tensor t = maskToTensor(cMask); + maskPyramids[downscaleFactor] = t; + return t; +} + +bool Camera::hasMask() const { + return mask.numel() > 0; +} + bool Camera::hasDistortionParameters(){ return k1 != 0.0f || k2 != 0.0f || k3 != 0.0f || p1 != 0.0f || p2 != 0.0f; } diff --git a/input_data.hpp b/input_data.hpp index a05442be..6ddd7ef8 100644 --- a/input_data.hpp +++ b/input_data.hpp @@ -24,6 +24,7 @@ struct Camera{ float p2 = 0; torch::Tensor camToWorld; std::string filePath = ""; + std::string maskPath = ""; // Optional path to mask file CameraType cameraType = CameraType::Perspective; Camera(){}; @@ -37,12 +38,16 @@ struct Camera{ bool hasDistortionParameters(); std::vector undistortionParameters(); torch::Tensor getImage(int downscaleFactor); + torch::Tensor getMask(int downscaleFactor); + 
bool hasMask() const; void loadImage(float downscaleFactor); torch::Tensor K; torch::Tensor image; + torch::Tensor mask; // Optional mask tensor [H,W,1], 0=exclude, 1=include std::unordered_map imagePyramids; + std::unordered_map maskPyramids; // Cached downscaled masks }; struct Points{ diff --git a/model.cpp b/model.cpp index a88afa7d..0579de13 100644 --- a/model.cpp +++ b/model.cpp @@ -51,8 +51,31 @@ torch::Tensor psnr(const torch::Tensor& rendered, const torch::Tensor& gt){ return (10.f * torch::log10(1.0 / mse)); } -torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt){ - return torch::abs(gt - rendered).mean(); +//define these somewhere +float __mask_opacity_penalty_power = 2.0, + __mask_opacity_penalty_weight = 1.0; + +torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& alpha, const torch::Tensor& mask){ + torch::Tensor diff = torch::abs(gt - rendered), + maskPenalty = torch::zeros_like(diff); + + + if (mask.numel() > 0){ + torch::Tensor expandedMask = mask.expand_as(diff); + + //maskPenalty = torch::zeros_like(expandedMask) + + const torch::Tensor bg_mask = 1.0 - expandedMask, + penalty_weights = bg_mask.pow(__mask_opacity_penalty_power), + penalty = (expandedMask * penalty_weights).mean() * __mask_opacity_penalty_weight; // de penalty wordt bepaald aan de hand penalty gewicht * penalty_weights + + const float inv_pixels = __mask_opacity_penalty_weight / static_cast(expandedMask.numel()); //deel de opacity weight door het aantal elementen in de alpha lijst + maskPenalty = penalty_weights * inv_pixels; //Grad alpha wordt dan de 2d penalty kaart teruggebracht tot een + //loss = loss + penalty; + + // hier moet ik alpha in krijgen zodat ik met de inverse daarvan hier vervolgens het gradient mee kan maken + } + return (diff + maskPenalty).mean(); } void Model::setupOptimizers(){ @@ -80,7 +103,7 @@ void Model::releaseOptimizers(){ } -torch::Tensor Model::forward(Camera& cam, int step){ +tensor_list 
Model::forward(Camera& cam, int step){ const float scaleFactor = getDownscaleFactor(step); const float fx = cam.fx / scaleFactor; @@ -118,7 +141,7 @@ torch::Tensor Model::forward(Camera& cam, int step){ torch::Tensor numTilesHit; // GPU-only torch::Tensor cov2d; // CPU-only torch::Tensor camDepths; // CPU-only - torch::Tensor rgb; + torch::Tensor rgb, alpha; if (device == torch::kCPU){ auto p = ProjectGaussiansCPU::apply(means, @@ -171,7 +194,7 @@ torch::Tensor Model::forward(Camera& cam, int step){ xys.retain_grad(); if (radii.sum().item() == 0.0f) - return backgroundColor.repeat({height, width, 1}); + return { backgroundColor.repeat({height, width, 1}), torch::zeros({height, width, 1}) }; torch::Tensor viewDirs = means.detach() - T.transpose(0, 1).to(device); viewDirs = viewDirs / viewDirs.norm(2, {-1}, true); @@ -192,7 +215,7 @@ torch::Tensor Model::forward(Camera& cam, int step){ rgbs = torch::clamp_min(rgbs + 0.5f, 0.0f); if (device == torch::kCPU){ - rgb = RasterizeGaussiansCPU::apply( + auto rgba = RasterizeGaussiansCPU::apply( xys, radii, conics, @@ -203,9 +226,11 @@ torch::Tensor Model::forward(Camera& cam, int step){ height, width, backgroundColor); + rgb = rgba[0]; + alpha = rgba[1]; }else{ #if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS) - rgb = RasterizeGaussians::apply( + auto rgba = RasterizeGaussians::apply( xys, depths, radii, @@ -216,12 +241,15 @@ torch::Tensor Model::forward(Camera& cam, int step){ height, width, backgroundColor); + rgb = rgba[0]; + alpha = rgba[1]; #endif } rgb = torch::clamp_max(rgb, 1.0f); + alpha = torch::clamp_max(alpha, 1.0f); - return rgb; + return { rgb, alpha }; } void Model::optimizersZeroGrad(){ @@ -777,8 +805,8 @@ int Model::loadPly(const std::string &filename){ throw std::runtime_error("Invalid PLY file"); } -torch::Tensor Model::mainLoss(torch::Tensor &rgb, torch::Tensor >, float ssimWeight){ - torch::Tensor ssimLoss = 1.0f - ssim.eval(rgb, gt); - torch::Tensor l1Loss = l1(rgb, gt); +torch::Tensor 
Model::mainLoss(torch::Tensor &rgb, const torch::Tensor& alpha, torch::Tensor >, float ssimWeight, const torch::Tensor &mask){ + torch::Tensor ssimLoss = 1.0f - ssim.eval(rgb, gt, mask); + torch::Tensor l1Loss = l1(rgb, gt, alpha, mask); return (1.0f - ssimWeight) * l1Loss + ssimWeight * ssimLoss; } diff --git a/model.hpp b/model.hpp index fbf3edcc..cb1e9ad1 100644 --- a/model.hpp +++ b/model.hpp @@ -17,7 +17,7 @@ using namespace torch::autograd; torch::Tensor randomQuatTensor(long long n); torch::Tensor projectionMatrix(float zNear, float zFar, float fovX, float fovY, const torch::Device &device); torch::Tensor psnr(const torch::Tensor& rendered, const torch::Tensor& gt); -torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt); +torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& alpha, const torch::Tensor& mask = torch::Tensor()); struct Model{ Model(const InputData &inputData, int numCameras, @@ -63,7 +63,7 @@ struct Model{ void setupOptimizers(); void releaseOptimizers(); - torch::Tensor forward(Camera& cam, int step); + tensor_list forward(Camera& cam, int step); void optimizersZeroGrad(); void optimizersStep(); void schedulersStep(int step); @@ -74,7 +74,7 @@ struct Model{ void saveSplat(const std::string &filename); void saveDebugPly(const std::string &filename, int step); int loadPly(const std::string &filename); - torch::Tensor mainLoss(torch::Tensor &rgb, torch::Tensor >, float ssimWeight); + torch::Tensor mainLoss(torch::Tensor &rgb, const torch::Tensor& alpha, torch::Tensor >, float ssimWeight, const torch::Tensor &mask = torch::Tensor()); void addToOptimizer(torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &idcs, int nSamples); void removeFromOptimizer(torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &deletedMask); diff --git a/opensplat.cpp b/opensplat.cpp index b2826d13..0f925c73 100644 --- a/opensplat.cpp +++ b/opensplat.cpp 
@@ -42,6 +42,7 @@ int main(int argc, char *argv[]){ ("stop-screen-size-at", "Stop splitting gaussians that are larger than [split-screen-size] after these many steps", cxxopts::value()->default_value("4000")) ("split-screen-size", "Split gaussians that are larger than this percentage of screen space", cxxopts::value()->default_value("0.05")) ("colmap-image-path", "Override the default image path for COLMAP-based input", cxxopts::value()->default_value("")) + ("mask-dir", "Path to directory containing mask images (binary: 0=exclude, 1=include)", cxxopts::value()->default_value("")) #ifdef USE_VISUALIZATION ("has-visualization", "Show the visualization steps of training", cxxopts::value()->default_value("0")) #endif @@ -95,6 +96,7 @@ int main(int argc, char *argv[]){ const int stopScreenSizeAt = result["stop-screen-size-at"].as(); const float splitScreenSize = result["split-screen-size"].as(); const std::string colmapImageSourcePath = result["colmap-image-path"].as(); + const std::string maskDir = result["mask-dir"].as(); #ifdef USE_VISUALIZATION const bool hasVisualization = result["has-visualization"].as(); #endif @@ -121,6 +123,35 @@ int main(int argc, char *argv[]){ try{ InputData inputData = inputDataFromX(projectRoot, colmapImageSourcePath); + // Set mask paths if mask directory is provided + if (!maskDir.empty()){ + fs::path maskDirPath(maskDir); + if (!fs::exists(maskDirPath)){ + std::cerr << "Mask directory does not exist: " << maskDir << std::endl; + exit(1); + } + + for (Camera &cam : inputData.cameras){ + fs::path imagePath(cam.filePath); + std::string imageName = imagePath.filename().string(), + imageStem = imagePath.stem().string(); + + // Try common mask extensions + for(const std::string & name : { imageName, imageStem}) + for (const std::string &ext : {".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"}){ + fs::path maskPath = maskDirPath / (name + ext); + if (fs::exists(maskPath)){ + cam.maskPath = maskPath.string(); + break; + } + } + + if 
(cam.maskPath.empty()){
+                    std::cerr << "Warning: No mask found for " << cam.filePath << std::endl;
+                }
+            }
+        }
+
        parallel_for(inputData.cameras.begin(), inputData.cameras.end(), [&downScaleFactor](Camera &cam){
            cam.loadImage(downScaleFactor);
        });
@@ -153,11 +184,23 @@ int main(int argc, char *argv[]){

                model.optimizersZeroGrad();

-                torch::Tensor rgb = model.forward(cam, step);
-                torch::Tensor gt = cam.getImage(model.getDownscaleFactor(step));
-                gt = gt.to(device);
+                tensor_list rgba = model.forward(cam, step);
+                torch::Tensor gt = cam.getImage(model.getDownscaleFactor(step)),
+                              rgb = rgba[0],
+                              alpha = rgba[1];
+
+                gt = gt.to(device);
+
+                torch::Tensor mask;
+                if (cam.hasMask()){
+                    mask = cam.getMask(model.getDownscaleFactor(step));
+                    mask = mask.to(device);
+                }

-                torch::Tensor mainLoss = model.mainLoss(rgb, gt, ssimWeight);
+                // NOTE(review): argument order fixed — mainLoss is declared as
+                // (rgb, alpha, gt, ssimWeight, mask); passing (rgb, gt, alpha, ...)
+                // silently swapped the ground truth and the alpha channel.
+                torch::Tensor mainLoss = model.mainLoss(rgb, alpha, gt, ssimWeight, mask);
                mainLoss.backward();

                if (step % displayStep == 0) {
@@ -175,7 +218,9 @@ int main(int argc, char *argv[]){
                }

                if (!valRender.empty() && step % 10 == 0){
-                    torch::Tensor rgb = model.forward(*valCam, step);
+                    tensor_list rgba = model.forward(*valCam, step);
+                    torch::Tensor rgb = rgba[0],
+                                  alpha = rgba[1];
                    cv::Mat image = tensorToImage(rgb.detach().cpu());
                    cv::cvtColor(image, image, cv::COLOR_RGB2BGR);
                    cv::imwrite((fs::path(valRender) / (std::to_string(step) + ".png")).string(), image);
@@ -201,9 +246,15 @@ int main(int argc, char *argv[]){

    // Validate
    if (valCam != nullptr){
-        torch::Tensor rgb = model.forward(*valCam, numIters);
-        torch::Tensor gt = valCam->getImage(model.getDownscaleFactor(numIters)).to(device);
-        std::cout << valCam->filePath << " validation loss: " << model.mainLoss(rgb, gt, ssimWeight).item<float>() << std::endl;
+        tensor_list rgba = model.forward(*valCam, numIters);
+        torch::Tensor rgb = rgba[0],
+                      alpha = rgba[1];
+        torch::Tensor gt = valCam->getImage(model.getDownscaleFactor(numIters)).to(device);
+        torch::Tensor valMask;
+        if (valCam->hasMask()){
+
valMask = valCam->getMask(model.getDownscaleFactor(numIters)).to(device); + } + std::cout << valCam->filePath << " validation loss: " << model.mainLoss(rgb, alpha, gt, ssimWeight, valMask).item() << std::endl; } }catch(const std::exception &e){ std::cerr << e.what() << std::endl; diff --git a/rasterize_gaussians.cpp b/rasterize_gaussians.cpp index 736d4c02..c1e1a655 100644 --- a/rasterize_gaussians.cpp +++ b/rasterize_gaussians.cpp @@ -36,7 +36,7 @@ std::tuple(t); - - torch::Tensor finalTs = std::get<1>(t); + torch::Tensor outAlpha = std::get<1>(t); + torch::Tensor finalTs = std::get<2>(t); // Map of tile bin IDs - torch::Tensor finalIdx = std::get<2>(t); + torch::Tensor finalIdx = std::get<3>(t); ctx->saved_data["imgWidth"] = imgWidth; ctx->saved_data["imgHeight"] = imgHeight; ctx->save_for_backward({ gaussianIdsSorted, tileBins, xys, conics, colors, opacity, background, finalTs, finalIdx }); - return outImg; + return { outImg, outAlpha }; } tensor_list RasterizeGaussians::backward(AutogradContext *ctx, tensor_list grad_outputs) { torch::Tensor v_outImg = grad_outputs[0]; + torch::Tensor v_outAlpha = grad_outputs[1]; int imgHeight = ctx->saved_data["imgHeight"].toInt(); int imgWidth = ctx->saved_data["imgWidth"].toInt(); @@ -105,8 +106,6 @@ tensor_list RasterizeGaussians::backward(AutogradContext *ctx, tensor_list grad_ torch::Tensor finalTs = saved[7]; torch::Tensor finalIdx = saved[8]; - torch::Tensor v_outAlpha = torch::zeros_like(v_outImg.index({"...", 0})); - auto t = rasterize_backward_tensor(imgHeight, imgWidth, gaussianIdsSorted, tileBins, @@ -141,7 +140,7 @@ tensor_list RasterizeGaussians::backward(AutogradContext *ctx, tensor_list grad_ #endif -torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx, +tensor_list RasterizeGaussiansCPU::forward(AutogradContext *ctx, torch::Tensor xys, torch::Tensor radii, torch::Tensor conics, @@ -167,20 +166,21 @@ torch::Tensor RasterizeGaussiansCPU::forward(AutogradContext *ctx, ); // Final image torch::Tensor 
outImg = std::get<0>(t); - - torch::Tensor finalTs = std::get<1>(t); - std::vector *px2gid = std::get<2>(t); + torch::Tensor outAlpha = std::get<1>(t); + torch::Tensor finalTs = std::get<2>(t); + std::vector *px2gid = std::get<3>(t); ctx->saved_data["px2gid"] = reinterpret_cast(px2gid); ctx->saved_data["imgWidth"] = imgWidth; ctx->saved_data["imgHeight"] = imgHeight; ctx->save_for_backward({ xys, conics, colors, opacity, background, cov2d, camDepths, finalTs }); - return outImg; + return { outImg, outAlpha }; } tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list grad_outputs) { torch::Tensor v_outImg = grad_outputs[0]; + torch::Tensor v_outAlpha = grad_outputs[1]; int imgHeight = ctx->saved_data["imgHeight"].toInt(); int imgWidth = ctx->saved_data["imgWidth"].toInt(); const std::vector *px2gid = reinterpret_cast *>(ctx->saved_data["px2gid"].toInt()); @@ -195,8 +195,6 @@ tensor_list RasterizeGaussiansCPU::backward(AutogradContext *ctx, tensor_list gr torch::Tensor camDepths = saved[6]; torch::Tensor finalTs = saved[7]; - torch::Tensor v_outAlpha = torch::zeros_like(v_outImg.index({"...", 0})); - auto t = rasterize_backward_tensor_cpu(imgHeight, imgWidth, xys, conics, diff --git a/rasterize_gaussians.hpp b/rasterize_gaussians.hpp index a860cf2e..c8b7b5ee 100644 --- a/rasterize_gaussians.hpp +++ b/rasterize_gaussians.hpp @@ -22,7 +22,7 @@ std::tuple{ public: - static torch::Tensor forward(AutogradContext *ctx, + static tensor_list forward(AutogradContext *ctx, torch::Tensor xys, torch::Tensor depths, torch::Tensor radii, @@ -40,7 +40,7 @@ class RasterizeGaussians : public Function{ class RasterizeGaussiansCPU : public Function{ public: - static torch::Tensor forward(AutogradContext *ctx, + static tensor_list forward(AutogradContext *ctx, torch::Tensor xys, torch::Tensor radii, torch::Tensor conics, diff --git a/rasterizer/gsplat-cpu/bindings.h b/rasterizer/gsplat-cpu/bindings.h index bfcce484..0ffb6233 100644 --- 
a/rasterizer/gsplat-cpu/bindings.h +++ b/rasterizer/gsplat-cpu/bindings.h @@ -34,8 +34,9 @@ project_gaussians_forward_tensor_cpu( ); std::tuple< - torch::Tensor, - torch::Tensor, + torch::Tensor, //img + torch::Tensor, //alpha + torch::Tensor, //finalTs std::vector * > rasterize_forward_tensor_cpu( const int width, diff --git a/rasterizer/gsplat-cpu/gsplat_cpu.cpp b/rasterizer/gsplat-cpu/gsplat_cpu.cpp index 63102883..a32cb787 100644 --- a/rasterizer/gsplat-cpu/gsplat_cpu.cpp +++ b/rasterizer/gsplat-cpu/gsplat_cpu.cpp @@ -131,6 +131,7 @@ project_gaussians_forward_tensor_cpu( } std::tuple< + torch::Tensor, torch::Tensor, torch::Tensor, std::vector * @@ -161,6 +162,7 @@ std::tuple< torch::Device device = xys.device(); torch::Tensor outImg = torch::zeros({height, width, channels}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); + torch::Tensor outAlpha = torch::zeros({height, width}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); torch::Tensor finalTs = torch::ones({height, width}, torch::TensorOptions().dtype(torch::kFloat32).device(device)); torch::Tensor done = torch::zeros({height, width}, torch::TensorOptions().dtype(torch::kBool).device(device)); @@ -174,6 +176,7 @@ std::tuple< float *pOpacities = static_cast(opacities.data_ptr()); float *pOutImg = static_cast(outImg.data_ptr()); + float *pOutAlpha = static_cast(outAlpha.data_ptr()); float *pFinalTs = static_cast(finalTs.data_ptr()); bool *pDone = static_cast(done.data_ptr()); @@ -232,7 +235,7 @@ std::tuple< pOutImg[pixIdx * 3 + 0] += vis * pColors[gaussianId * 3 + 0]; pOutImg[pixIdx * 3 + 1] += vis * pColors[gaussianId * 3 + 1]; pOutImg[pixIdx * 3 + 2] += vis * pColors[gaussianId * 3 + 2]; - + pOutAlpha[pixIdx] += vis; pFinalTs[pixIdx] = nextT; px2gid[pixIdx].push_back(gaussianId); } @@ -253,7 +256,7 @@ std::tuple< } } - return std::make_tuple(outImg, finalTs, px2gid); + return std::make_tuple(outImg, outAlpha, finalTs, px2gid); } diff --git 
a/rasterizer/gsplat-metal/gsplat_metal.metal b/rasterizer/gsplat-metal/gsplat_metal.metal
index 2313a239..75b37841 100644
--- a/rasterizer/gsplat-metal/gsplat_metal.metal
+++ b/rasterizer/gsplat-metal/gsplat_metal.metal
@@ -445,6 +445,7 @@ kernel void nd_rasterize_forward_kernel(
    device float* final_Ts,
    device int* final_index,
    device float* out_img,
+   device float* out_alpha,
    constant float* background,
    constant uint2& blockDim,
    uint2 blockIdx [[threadgroup_position_in_grid]],
@@ -504,6 +505,9 @@ kernel void nd_rasterize_forward_kernel(
        const float vis = alpha * T;
        for (int c = 0; c < channels; ++c) {
            out_img[channels * pix_id + c] += colors[channels * g + c] * vis;
        }
+       // Accumulate once per gaussian — placed inside the colour loop this
+       // would overcount alpha by a factor of `channels`.
+       out_alpha[pix_id] += vis;
        T = next_T;
    }
diff --git a/rasterizer/gsplat-metal/gsplat_metal.mm b/rasterizer/gsplat-metal/gsplat_metal.mm
index 29849e78..34d6c829 100644
--- a/rasterizer/gsplat-metal/gsplat_metal.mm
+++ b/rasterizer/gsplat-metal/gsplat_metal.mm
@@ -577,6 +577,9 @@ void dispatchKernel(MetalContext* ctx, id<MTLComputePipelineState> cpso, MTLSize
    torch::Tensor out_img = torch::zeros(
        {img_height, img_width, channels}, xys.options().dtype(torch::kFloat32)
    );
+   torch::Tensor out_alpha = torch::zeros(
+       {img_height, img_width}, xys.options().dtype(torch::kFloat32)
+   );
    torch::Tensor final_Ts = torch::zeros(
        {img_height, img_width}, xys.options().dtype(torch::kFloat32)
    );
@@ -609,11 +612,12 @@ void dispatchKernel(MetalContext* ctx, id<MTLComputePipelineState> cpso, MTLSize
        EncodeArg::tensor(final_Ts),
        EncodeArg::tensor(final_idx),
        EncodeArg::tensor(out_img),
+       EncodeArg::tensor(out_alpha),
        EncodeArg::tensor(background),
        EncodeArg::array(block_size_dim2, sizeof(block_size_dim2))
    });

-   return std::make_tuple(out_img, final_Ts, final_idx);
+   return std::make_tuple(out_img, out_alpha, final_Ts, final_idx);
 }

 std::tuple<
@@ -648,6 +652,9 @@ void dispatchKernel(MetalContext* ctx, id<MTLComputePipelineState> cpso, MTLSize
    torch::Tensor out_img = torch::zeros(
        {img_height, img_width, channels}, xys.options().dtype(torch::kFloat32)
    );
+   torch::Tensor out_alpha = 
torch::zeros(
+       {img_height, img_width}, xys.options().dtype(torch::kFloat32)
+   );
    torch::Tensor final_Ts = torch::zeros(
        {img_height, img_width}, xys.options().dtype(torch::kFloat32)
    );
@@ -680,11 +687,12 @@ void dispatchKernel(MetalContext* ctx, id<MTLComputePipelineState> cpso, MTLSize
        EncodeArg::tensor(final_Ts),
        EncodeArg::tensor(final_idx),
        EncodeArg::tensor(out_img),
+       EncodeArg::tensor(out_alpha),
        EncodeArg::tensor(background),
        EncodeArg::array(block_size_dim2, sizeof(block_size_dim2))
    });

-   return std::make_tuple(out_img, final_Ts, final_idx);
+   return std::make_tuple(out_img, out_alpha, final_Ts, final_idx);
 }
diff --git a/rasterizer/gsplat/backward.cu b/rasterizer/gsplat/backward.cu
index c338dba2..6a1a9282 100644
--- a/rasterizer/gsplat/backward.cu
+++ b/rasterizer/gsplat/backward.cu
@@ -107,7 +108,9 @@ __global__ void nd_rasterize_backward_kernel(
            // update the running sum
            S[c] += rgbs[channels * g + c] * fac;
        }
-       v_alpha += T_final * ra * v_out_alpha;
+       // NOTE(review): keep this term. It was a no-op only while v_out_alpha
+       // was hard-coded to zeros; this patch now feeds the real alpha gradient
+       // (from the mask opacity penalty) in, so dropping it breaks backprop.
+       v_alpha += T_final * ra * v_out_alpha;
        // update v_opacity for this gaussian
        atomicAdd(&(v_opacity[g]), vis * v_alpha);
@@ -310,7 +311,7 @@ __global__ void rasterize_backward_kernel(
        v_alpha += (rgb.y * T - buffer.y * ra) * v_out.y;
        v_alpha += (rgb.z * T - buffer.z * ra) * v_out.z;

-       v_alpha += T_final * ra * v_out_alpha;
+       v_alpha += T_final * ra * v_out_alpha; // no longer a no-op: alpha now receives gradients
        // contribution from background pixel
        v_alpha += -T_final * ra * background.x * v_out.x;
        v_alpha += -T_final * ra * background.y * v_out.y;
diff --git a/rasterizer/gsplat/bindings.cu b/rasterizer/gsplat/bindings.cu
index 80d581ce..cf95f5de 100644
--- a/rasterizer/gsplat/bindings.cu
+++ b/rasterizer/gsplat/bindings.cu
@@ -377,6 +377,9 @@ rasterize_forward_tensor(
    torch::Tensor out_img = torch::zeros(
        {img_height, img_width, channels}, xys.options().dtype(torch::kFloat32)
    );
+   torch::Tensor out_alpha = torch::zeros(
+       {img_height, img_width}, xys.options().dtype(torch::kFloat32)
+   );
    torch::Tensor final_Ts = torch::zeros(
        {img_height, img_width}, xys.options().dtype(torch::kFloat32)
    );
@@ -396,10 +399,11 @@ rasterize_forward_tensor(
        final_Ts.contiguous().data_ptr<float>(),
        final_idx.contiguous().data_ptr<int>(),
        (float3 *)out_img.contiguous().data_ptr<float>(),
+       (float *)out_alpha.contiguous().data_ptr<float>(),
        *(float3 *)background.contiguous().data_ptr<float>()
    );

-   return std::make_tuple(out_img, final_Ts, final_idx);
+   return std::make_tuple(out_img, out_alpha, final_Ts, final_idx);
 }
@@ -446,6 +450,9 @@ nd_rasterize_forward_tensor(
    torch::Tensor out_img = torch::zeros(
        {img_height, img_width, channels}, xys.options().dtype(torch::kFloat32)
    );
+   torch::Tensor out_alpha = torch::zeros(
+       {img_height, img_width}, xys.options().dtype(torch::kFloat32)
+   );
    torch::Tensor final_Ts = torch::zeros(
        {img_height, img_width}, xys.options().dtype(torch::kFloat32)
    );
@@ -466,10 +473,11 @@ nd_rasterize_forward_tensor(
        final_Ts.contiguous().data_ptr<float>(),
        final_idx.contiguous().data_ptr<int>(),
        out_img.contiguous().data_ptr<float>(),
+       out_alpha.contiguous().data_ptr<float>(),
        background.contiguous().data_ptr<float>()
    );

-   return std::make_tuple(out_img, final_Ts, final_idx);
+   return std::make_tuple(out_img, out_alpha, final_Ts, final_idx);
 }
diff --git a/rasterizer/gsplat/forward.cu b/rasterizer/gsplat/forward.cu
index c831e766..d7120dbd 100644
--- a/rasterizer/gsplat/forward.cu
+++ 
b/rasterizer/gsplat/forward.cu
@@ -184,6 +184,7 @@ __global__ void nd_rasterize_forward(
    float* __restrict__ final_Ts,
    int* __restrict__ final_index,
    float* __restrict__ out_img,
+   float* __restrict__ out_alpha,
    const float* __restrict__ background
 ) {
    // current naive implementation where tile data loading is redundant
@@ -240,6 +241,9 @@ __global__ void nd_rasterize_forward(
            const float vis = alpha * T;
            for (int c = 0; c < channels; ++c) {
                out_img[channels * pix_id + c] += colors[channels * g + c] * vis;
            }
+           // Accumulate once per gaussian — placed inside the colour loop this
+           // would overcount alpha by a factor of `channels`.
+           out_alpha[pix_id] += vis;
            T = next_T;
        }
@@ -265,6 +267,7 @@ __global__ void rasterize_forward(
    float* __restrict__ final_Ts,
    int* __restrict__ final_index,
    float3* __restrict__ out_img,
+   float* __restrict__ out_alpha,
    const float3& __restrict__ background
 ) {
    // each thread draws one pixel, but also timeshares caching gaussians in a
@@ -306,7 +309,8 @@ __global__ void rasterize_forward(
    // each thread loads one gaussian at a time before rasterizing its
    // designated pixel
    int tr = block.thread_rank();
-   float3 pix_out = {0.f, 0.f, 0.f};
+   float3 pix_out = {0.f, 0.f, 0.f};
+   float alpha_out = 0.f;
    for (int b = 0; b < num_batches; ++b) {
        // resync all threads before beginning next batch
        // end early if entire tile is done
@@ -359,6 +363,7 @@ __global__ void rasterize_forward(
            pix_out.x = pix_out.x + c.x * vis;
            pix_out.y = pix_out.y + c.y * vis;
            pix_out.z = pix_out.z + c.z * vis;
+           alpha_out = alpha_out + vis;
            T = next_T;
            cur_idx = batch_start + t;
        }
@@ -367,13 +372,13 @@ __global__ void rasterize_forward(
    if (inside) {
        // add background
        final_Ts[pix_id] = T; // transmittance at last gaussian in this pixel
-       final_index[pix_id] =
-           cur_idx; // index of in bin of last gaussian in this pixel
+       final_index[pix_id] = cur_idx; // index of in bin of last gaussian in this pixel
        float3 final_color;
        final_color.x = pix_out.x + T * background.x;
        final_color.y = pix_out.y + T * background.y;
        final_color.z = pix_out.z + T * background.z;
        out_img[pix_id] = final_color;
+       
out_alpha[pix_id] = alpha_out; // I guess the background isn't really interesting for the alpha } } diff --git a/rasterizer/gsplat/forward.cuh b/rasterizer/gsplat/forward.cuh index c342d23e..bc55feb1 100644 --- a/rasterizer/gsplat/forward.cuh +++ b/rasterizer/gsplat/forward.cuh @@ -41,6 +41,7 @@ __global__ void rasterize_forward( float* __restrict__ final_Ts, int* __restrict__ final_index, float3* __restrict__ out_img, + float* __restrict__ out_alpha, const float3& __restrict__ background ); @@ -58,6 +59,7 @@ __global__ void nd_rasterize_forward( float* __restrict__ final_Ts, int* __restrict__ final_index, float* __restrict__ out_img, + float* __restrict__ out_alpha, const float* __restrict__ background ); @@ -104,6 +106,7 @@ __global__ void rasterize_forward( float* __restrict__ final_Ts, int* __restrict__ final_index, float3* __restrict__ out_img, + float* __restrict__ out_alpha, const float3& __restrict__ background ); @@ -120,5 +123,6 @@ __global__ void nd_rasterize_forward( float* __restrict__ final_Ts, int* __restrict__ final_index, float* __restrict__ out_img, + float* __restrict__ out_alpha, const float* __restrict__ background ); \ No newline at end of file diff --git a/simple_trainer.cpp b/simple_trainer.cpp index 9912c5d6..8f82a464 100644 --- a/simple_trainer.cpp +++ b/simple_trainer.cpp @@ -145,7 +145,7 @@ int main(int argc, char **argv){ torch::optim::Adam optimizer({rgbs, means, scales, opacities, quats}, learningRate); torch::nn::MSELoss mseLoss; - torch::Tensor outImg; + torch::Tensor outImg, outAlpha; for (size_t i = 0; i < iterations; i++){ if (device == torch::kCPU){ @@ -157,7 +157,7 @@ int main(int argc, char **argv){ height, width); - outImg = RasterizeGaussiansCPU::apply( + tensor_list rgba = RasterizeGaussiansCPU::apply( p[0], // xys p[1], // radii, p[2], // conics @@ -168,6 +168,8 @@ int main(int argc, char **argv){ height, width, background); + outImg = rgba[0]; + outAlpha = rgba[1]; }else{ #if defined(USE_HIP) || defined(USE_CUDA) || 
defined(USE_MPS) auto p = ProjectGaussians::apply(means, scales, 1, @@ -179,7 +181,7 @@ int main(int argc, char **argv){ width, tileBounds); - outImg = RasterizeGaussians::apply( + tensor_list rgba = RasterizeGaussians::apply( p[0], // xys p[1], // depths p[2], // radii, @@ -190,6 +192,8 @@ int main(int argc, char **argv){ height, width, background); + outImg = rgba[0]; + outAlpha = rgba[1]; #else throw std::runtime_error("GPU support not built, use --cpu"); #endif diff --git a/ssim.cpp b/ssim.cpp index f8a152e6..f1edd965 100644 --- a/ssim.cpp +++ b/ssim.cpp @@ -5,10 +5,18 @@ using namespace torch::indexing; -torch::Tensor SSIM::eval(const torch::Tensor& rendered, const torch::Tensor& gt) { +torch::Tensor SSIM::eval(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& mask) { torch::Tensor img1 = gt.permute({2, 0, 1}).index({None, "..."}); torch::Tensor img2 = rendered.permute({2, 0, 1}).index({None, "..."}); + + if (mask.numel() > 0){ + torch::Tensor ssimMask = mask.permute({2, 0, 1}).index({None, "..."}); + ssimMask = ssimMask.expand_as(img1); + img1 = img1 * ssimMask; + img2 = img2 * ssimMask; + } + if (img1.device() != window.device()){ window = window.to(img1.device()); } @@ -22,7 +30,7 @@ torch::Tensor SSIM::eval(const torch::Tensor& rendered, const torch::Tensor& gt) torch::Tensor sigma1Sq = torch::nn::functional::conv2d(img1 * img1, window, torch::nn::functional::Conv2dFuncOptions().padding(windowSize / 2).groups(channel)) - mu1Sq; torch::Tensor sigma2Sq = torch::nn::functional::conv2d(img2 * img2, window, torch::nn::functional::Conv2dFuncOptions().padding(windowSize / 2).groups(channel)) - mu2Sq; torch::Tensor sigma12 = torch::nn::functional::conv2d(img1 * img2, window, torch::nn::functional::Conv2dFuncOptions().padding(windowSize / 2).groups(channel)) - mu1mu2; - + const float C1 = 0.01 * 0.01; const float C2 = 0.03 * 0.03; diff --git a/ssim.hpp b/ssim.hpp index b74a86df..abca91a7 100644 --- a/ssim.hpp +++ b/ssim.hpp @@ -12,7 +12,7 
@@ class SSIM{ window = createWindow(); }; - torch::Tensor eval(const torch::Tensor& rendered, const torch::Tensor& gt); + torch::Tensor eval(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& mask = torch::Tensor()); private: torch::Tensor createWindow(); torch::Tensor gaussian(float sigma);