Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions cv_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ cv::Mat imreadRGB(const std::string &filename){
return cImg;
}

// Reads a mask image from disk as single-channel grayscale.
// Exits the process with an error message if the file cannot be read,
// mirroring the behavior of imreadRGB above.
cv::Mat imreadMask(const std::string &filename){
    cv::Mat m = cv::imread(filename, cv::IMREAD_GRAYSCALE);
    if (!m.empty()) return m;

    std::cerr << "Cannot read mask " << filename << std::endl;
    exit(1);
}

void imwriteRGB(const std::string &filename, const cv::Mat &image){
cv::Mat rgb;
cv::cvtColor(image, rgb, cv::COLOR_RGB2BGR);
Expand Down Expand Up @@ -48,3 +57,22 @@ torch::Tensor imageToTensor(const cv::Mat &image){
return (img.toType(torch::kFloat32) / 255.0f);
}

// Converts a single-channel 8-bit OpenCV mask into a binary float tensor.
// Output shape is [H, W, 1]; pixels >= 128 become 1.0, the rest 0.0.
torch::Tensor maskToTensor(const cv::Mat &mask){
    // Wrap the OpenCV buffer without copying; shape [rows, cols, 1], dtype uint8.
    torch::Tensor raw = torch::from_blob(mask.data, { mask.rows, mask.cols, 1 }, torch::kU8);

    // Normalize to [0, 1] and threshold at 0.5 (i.e. original value 128).
    torch::Tensor normalized = raw.toType(torch::kFloat32) / 255.0f;
    return normalized.ge(0.5f).toType(torch::kFloat32);
}

// Converts a [H, W, 1] (or [H, W]) float tensor with values in [0, 1] back
// into a single-channel 8-bit OpenCV mask (values scaled to 0..255).
cv::Mat tensorToMask(const torch::Tensor &t){
    const int h = static_cast<int>(t.size(0));
    const int w = static_cast<int>(t.size(1));

    cv::Mat mask(h, w, CV_8UC1);
    // Move to CPU and force a contiguous layout before touching the raw buffer:
    // copying from data_ptr() of a device-resident or strided tensor is undefined.
    torch::Tensor scaled = (t.squeeze() * 255.0f)
                               .toType(torch::kU8)
                               .cpu()
                               .contiguous();
    const uint8_t *src = scaled.data_ptr<uint8_t>(); // typed accessor checks dtype
    std::copy(src, src + (static_cast<size_t>(w) * h), mask.data);

    return mask;
}

3 changes: 3 additions & 0 deletions cv_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
#include <opencv2/imgproc.hpp>

cv::Mat imreadRGB(const std::string &filename);
cv::Mat imreadMask(const std::string &filename);
void imwriteRGB(const std::string &filename, const cv::Mat &image);
cv::Mat floatNxNtensorToMat(const torch::Tensor &t);
torch::Tensor floatNxNMatToTensor(const cv::Mat &m);
cv::Mat tensorToImage(const torch::Tensor &t);
torch::Tensor imageToTensor(const cv::Mat &image);
torch::Tensor maskToTensor(const cv::Mat &mask);
cv::Mat tensorToMask(const torch::Tensor &t);

#endif
36 changes: 36 additions & 0 deletions input_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ void Camera::loadImage(float downscaleFactor){
fy = K[1][1].item<float>();
cx = K[0][2].item<float>();
cy = K[1][2].item<float>();

// Load mask if path is set
if (!maskPath.empty()){

std::cout << "Loading mask " << maskPath << std::endl;

cv::Mat cMask = imreadMask(maskPath);

// Resize mask to match image dimensions
if (cMask.rows != height || cMask.cols != width){
cv::resize(cMask, cMask, cv::Size(width, height), 0.0, 0.0, cv::INTER_NEAREST);
}

mask = maskToTensor(cMask);
}
}

torch::Tensor Camera::getImage(int downscaleFactor){
Expand All @@ -116,6 +131,27 @@ torch::Tensor Camera::getImage(int downscaleFactor){
}
}

// Returns the camera mask downscaled by the given integer factor.
// Returns an undefined (empty) tensor when no mask was loaded.
// Downscaled results are cached in maskPyramids, one entry per factor.
torch::Tensor Camera::getMask(int downscaleFactor){
    if (!mask.numel()) return torch::Tensor(); // No mask available

    if (downscaleFactor <= 1) return mask;

    // Single lookup instead of find() followed by operator[].
    auto it = maskPyramids.find(downscaleFactor);
    if (it != maskPyramids.end()){
        return it->second;
    }

    // Rescale using nearest neighbor (preserve binary values)
    cv::Mat cMask = tensorToMask(mask);
    cv::resize(cMask, cMask, cv::Size(cMask.cols / downscaleFactor, cMask.rows / downscaleFactor), 0.0, 0.0, cv::INTER_NEAREST);
    torch::Tensor t = maskToTensor(cMask);
    maskPyramids[downscaleFactor] = t;
    return t;
}

// True when a mask tensor has been loaded for this camera.
bool Camera::hasMask() const {
    return mask.numel() != 0;
}

// True when any radial (k1..k3) or tangential (p1, p2) distortion
// coefficient is non-zero.
bool Camera::hasDistortionParameters(){
    const bool allZero = k1 == 0.0f && k2 == 0.0f && k3 == 0.0f &&
                         p1 == 0.0f && p2 == 0.0f;
    return !allZero;
}
Expand Down
5 changes: 5 additions & 0 deletions input_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct Camera{
float p2 = 0;
torch::Tensor camToWorld;
std::string filePath = "";
std::string maskPath = ""; // Optional path to mask file
CameraType cameraType = CameraType::Perspective;

Camera(){};
Expand All @@ -37,12 +38,16 @@ struct Camera{
bool hasDistortionParameters();
std::vector<float> undistortionParameters();
torch::Tensor getImage(int downscaleFactor);
torch::Tensor getMask(int downscaleFactor);
bool hasMask() const;

void loadImage(float downscaleFactor);
torch::Tensor K;
torch::Tensor image;
torch::Tensor mask; // Optional mask tensor [H,W,1], 0=exclude, 1=include

std::unordered_map<int, torch::Tensor> imagePyramids;
std::unordered_map<int, torch::Tensor> maskPyramids; // Cached downscaled masks
};

struct Points{
Expand Down
50 changes: 39 additions & 11 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,31 @@ torch::Tensor psnr(const torch::Tensor& rendered, const torch::Tensor& gt){
return (10.f * torch::log10(1.0 / mse));
}

torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt){
return torch::abs(gt - rendered).mean();
// NOTE(review): these belong in a config or header rather than loose globals,
// and identifiers starting with a double underscore are reserved to the
// implementation in C++ — consider renaming (e.g. maskOpacityPenaltyPower).
float __mask_opacity_penalty_power = 2.0,
      __mask_opacity_penalty_weight = 1.0;

// L1 loss between rendered and ground-truth images, plus an optional
// mask-driven opacity penalty: when a mask is supplied (non-empty, [H,W,1],
// 0 = exclude, 1 = include), background pixels contribute an extra penalty
// term that discourages the model from placing opacity there.
//
// TODO(review): `alpha` (the rendered opacity) is accepted but currently
// unused — the intent appears to be weighting the background penalty by the
// rendered opacity; confirm and wire it in.
torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& alpha, const torch::Tensor& mask){
    torch::Tensor diff = torch::abs(gt - rendered);
    torch::Tensor maskPenalty = torch::zeros_like(diff);

    if (mask.numel() > 0){
        // Broadcast the single-channel mask across diff's color channels.
        torch::Tensor expandedMask = mask.expand_as(diff);

        // 1 where the mask excludes a pixel (background), 0 where it keeps it.
        const torch::Tensor bg_mask = 1.0 - expandedMask;
        const torch::Tensor penalty_weights = bg_mask.pow(__mask_opacity_penalty_power);

        // Spread the penalty weight evenly over all pixels so its magnitude
        // is independent of image resolution.
        const float inv_pixels = __mask_opacity_penalty_weight / static_cast<float>(expandedMask.numel());
        maskPenalty = penalty_weights * inv_pixels;
    }
    return (diff + maskPenalty).mean();
}

void Model::setupOptimizers(){
Expand Down Expand Up @@ -80,7 +103,7 @@ void Model::releaseOptimizers(){
}


torch::Tensor Model::forward(Camera& cam, int step){
tensor_list Model::forward(Camera& cam, int step){

const float scaleFactor = getDownscaleFactor(step);
const float fx = cam.fx / scaleFactor;
Expand Down Expand Up @@ -118,7 +141,7 @@ torch::Tensor Model::forward(Camera& cam, int step){
torch::Tensor numTilesHit; // GPU-only
torch::Tensor cov2d; // CPU-only
torch::Tensor camDepths; // CPU-only
torch::Tensor rgb;
torch::Tensor rgb, alpha;

if (device == torch::kCPU){
auto p = ProjectGaussiansCPU::apply(means,
Expand Down Expand Up @@ -171,7 +194,7 @@ torch::Tensor Model::forward(Camera& cam, int step){
xys.retain_grad();

if (radii.sum().item<float>() == 0.0f)
return backgroundColor.repeat({height, width, 1});
return { backgroundColor.repeat({height, width, 1}), torch::zeros({height, width, 1}) };

torch::Tensor viewDirs = means.detach() - T.transpose(0, 1).to(device);
viewDirs = viewDirs / viewDirs.norm(2, {-1}, true);
Expand All @@ -192,7 +215,7 @@ torch::Tensor Model::forward(Camera& cam, int step){
rgbs = torch::clamp_min(rgbs + 0.5f, 0.0f);

if (device == torch::kCPU){
rgb = RasterizeGaussiansCPU::apply(
auto rgba = RasterizeGaussiansCPU::apply(
xys,
radii,
conics,
Expand All @@ -203,9 +226,11 @@ torch::Tensor Model::forward(Camera& cam, int step){
height,
width,
backgroundColor);
rgb = rgba[0];
alpha = rgba[1];
}else{
#if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS)
rgb = RasterizeGaussians::apply(
auto rgba = RasterizeGaussians::apply(
xys,
depths,
radii,
Expand All @@ -216,12 +241,15 @@ torch::Tensor Model::forward(Camera& cam, int step){
height,
width,
backgroundColor);
rgb = rgba[0];
alpha = rgba[1];
#endif
}

rgb = torch::clamp_max(rgb, 1.0f);
alpha = torch::clamp_max(alpha, 1.0f);

return rgb;
return { rgb, alpha };
}

void Model::optimizersZeroGrad(){
Expand Down Expand Up @@ -777,8 +805,8 @@ int Model::loadPly(const std::string &filename){
throw std::runtime_error("Invalid PLY file");
}

torch::Tensor Model::mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight){
torch::Tensor ssimLoss = 1.0f - ssim.eval(rgb, gt);
torch::Tensor l1Loss = l1(rgb, gt);
// Training loss: a blend of masked L1 (photometric) and 1 - SSIM
// (structural), mixed by ssimWeight.
//
// NOTE(review): the parameter order is (rgb, alpha, gt, ...). At least one
// call site appears to pass (rgb, gt, alpha, ...) instead — since all three
// are torch::Tensor this compiles silently; verify every caller.
torch::Tensor Model::mainLoss(torch::Tensor &rgb, const torch::Tensor& alpha, torch::Tensor &gt, float ssimWeight, const torch::Tensor &mask){
    const torch::Tensor structural = 1.0f - ssim.eval(rgb, gt, mask);
    const torch::Tensor photometric = l1(rgb, gt, alpha, mask);
    return ssimWeight * structural + (1.0f - ssimWeight) * photometric;
}
6 changes: 3 additions & 3 deletions model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using namespace torch::autograd;
torch::Tensor randomQuatTensor(long long n);
torch::Tensor projectionMatrix(float zNear, float zFar, float fovX, float fovY, const torch::Device &device);
torch::Tensor psnr(const torch::Tensor& rendered, const torch::Tensor& gt);
torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt);
torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt, const torch::Tensor& alpha, const torch::Tensor& mask = torch::Tensor());

struct Model{
Model(const InputData &inputData, int numCameras,
Expand Down Expand Up @@ -63,7 +63,7 @@ struct Model{
void setupOptimizers();
void releaseOptimizers();

torch::Tensor forward(Camera& cam, int step);
tensor_list forward(Camera& cam, int step);
void optimizersZeroGrad();
void optimizersStep();
void schedulersStep(int step);
Expand All @@ -74,7 +74,7 @@ struct Model{
void saveSplat(const std::string &filename);
void saveDebugPly(const std::string &filename, int step);
int loadPly(const std::string &filename);
torch::Tensor mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight);
torch::Tensor mainLoss(torch::Tensor &rgb, const torch::Tensor& alpha, torch::Tensor &gt, float ssimWeight, const torch::Tensor &mask = torch::Tensor());

void addToOptimizer(torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &idcs, int nSamples);
void removeFromOptimizer(torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &deletedMask);
Expand Down
64 changes: 56 additions & 8 deletions opensplat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ int main(int argc, char *argv[]){
("stop-screen-size-at", "Stop splitting gaussians that are larger than [split-screen-size] after these many steps", cxxopts::value<int>()->default_value("4000"))
("split-screen-size", "Split gaussians that are larger than this percentage of screen space", cxxopts::value<float>()->default_value("0.05"))
("colmap-image-path", "Override the default image path for COLMAP-based input", cxxopts::value<std::string>()->default_value(""))
("mask-dir", "Path to directory containing mask images (binary: 0=exclude, 1=include)", cxxopts::value<std::string>()->default_value(""))
#ifdef USE_VISUALIZATION
("has-visualization", "Show the visualization steps of training", cxxopts::value<bool>()->default_value("0"))
#endif
Expand Down Expand Up @@ -95,6 +96,7 @@ int main(int argc, char *argv[]){
const int stopScreenSizeAt = result["stop-screen-size-at"].as<int>();
const float splitScreenSize = result["split-screen-size"].as<float>();
const std::string colmapImageSourcePath = result["colmap-image-path"].as<std::string>();
const std::string maskDir = result["mask-dir"].as<std::string>();
#ifdef USE_VISUALIZATION
const bool hasVisualization = result["has-visualization"].as<bool>();
#endif
Expand All @@ -121,6 +123,35 @@ int main(int argc, char *argv[]){
try{
InputData inputData = inputDataFromX(projectRoot, colmapImageSourcePath);

// Set mask paths if mask directory is provided
if (!maskDir.empty()){
fs::path maskDirPath(maskDir);
if (!fs::exists(maskDirPath)){
std::cerr << "Mask directory does not exist: " << maskDir << std::endl;
exit(1);
}

for (Camera &cam : inputData.cameras){
fs::path imagePath(cam.filePath);
std::string imageName = imagePath.filename().string(),
imageStem = imagePath.stem().string();

// Try common mask extensions
for(const std::string & name : { imageName, imageStem})
for (const std::string &ext : {".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"}){
fs::path maskPath = maskDirPath / (name + ext);
if (fs::exists(maskPath)){
cam.maskPath = maskPath.string();
break;
}
}

if (cam.maskPath.empty()){
std::cerr << "Warning: No mask found for " << cam.filePath << std::endl;
}
}
}

parallel_for(inputData.cameras.begin(), inputData.cameras.end(), [&downScaleFactor](Camera &cam){
cam.loadImage(downScaleFactor);
});
Expand Down Expand Up @@ -153,11 +184,20 @@ int main(int argc, char *argv[]){

model.optimizersZeroGrad();

torch::Tensor rgb = model.forward(cam, step);
torch::Tensor gt = cam.getImage(model.getDownscaleFactor(step));
gt = gt.to(device);
tensor_list rgba = model.forward(cam, step);
torch::Tensor gt = cam.getImage(model.getDownscaleFactor(step)),
rgb = rgba[0],
alpha = rgba[1];

gt = gt.to(device);

torch::Tensor mask;
if (cam.hasMask()){
mask = cam.getMask(model.getDownscaleFactor(step));
mask = mask.to(device);
}

torch::Tensor mainLoss = model.mainLoss(rgb, gt, ssimWeight);
torch::Tensor mainLoss = model.mainLoss(rgb, gt, alpha, ssimWeight, mask);
mainLoss.backward();

if (step % displayStep == 0) {
Expand All @@ -175,7 +215,9 @@ int main(int argc, char *argv[]){
}

if (!valRender.empty() && step % 10 == 0){
torch::Tensor rgb = model.forward(*valCam, step);
tensor_list rgba = model.forward(*valCam, step);
torch::Tensor rgb = rgba[0],
alpha = rgba[1];
cv::Mat image = tensorToImage(rgb.detach().cpu());
cv::cvtColor(image, image, cv::COLOR_RGB2BGR);
cv::imwrite((fs::path(valRender) / (std::to_string(step) + ".png")).string(), image);
Expand All @@ -201,9 +243,15 @@ int main(int argc, char *argv[]){

// Validate
if (valCam != nullptr){
torch::Tensor rgb = model.forward(*valCam, numIters);
torch::Tensor gt = valCam->getImage(model.getDownscaleFactor(numIters)).to(device);
std::cout << valCam->filePath << " validation loss: " << model.mainLoss(rgb, gt, ssimWeight).item<float>() << std::endl;
tensor_list rgba = model.forward(*valCam, numIters);
torch::Tensor rgb = rgba[0],
alpha = rgba[1];
torch::Tensor gt = valCam->getImage(model.getDownscaleFactor(numIters)).to(device);
torch::Tensor valMask;
if (valCam->hasMask()){
valMask = valCam->getMask(model.getDownscaleFactor(numIters)).to(device);
}
std::cout << valCam->filePath << " validation loss: " << model.mainLoss(rgb, alpha, gt, ssimWeight, valMask).item<float>() << std::endl;
}
}catch(const std::exception &e){
std::cerr << e.what() << std::endl;
Expand Down
Loading