LibtorchSegmentation - A C++ trainable semantic segmentation library based on LibTorch (PyTorch C++). Backbones: VGG, ResNet, ResNeXt. Architectures: FPN, U-Net, PAN, LinkNet, PSPNet, DeepLab-V3, DeepLab-V3+ so far.

Overview

C++ library with neural networks for image segmentation based on LibTorch.

Please give a star if this project helps you.

The main features of this library are:

  • High-level API (a single line creates a neural network)
  • 7 model architectures for binary and multi-class segmentation (including the legendary U-Net)
  • 15 available encoders
  • All encoders have pre-trained weights for faster and better convergence
  • Inference speed-up of 35% or more compared with PyTorch on CUDA, and the same speed on CPU (U-Net tested on an RTX 2070S)

📚 Libtorch Tutorials 📚

Visit the Libtorch Tutorials project if you want to learn more about the Libtorch Segment library.

📋 Table of contents

  1. Quick start
  2. Examples
  3. Train your own data
  4. Models
    1. Architectures
    2. Encoders
  5. Installation
  6. To do list
  7. Thanks
  8. Citing
  9. License
  10. Related repository

Quick start

1. Create your first segmentation model with Libtorch Segment

A ResNet-34 TorchScript file is provided here. A segmentation model is just a LibTorch torch::nn::Module, which can be created as easily as:

#include "Segmentor.h"
auto model = UNet(1, /*num of classes*/
                  "resnet34", /*encoder name, could be resnet50 or others*/
                  "path to resnet34.pt"/*weight path pretrained on ImageNet, it is produced by torchscript*/
                  );
  • See the table of available model architectures
  • See the table of available encoders and their corresponding weights

2. Generate your own pretrained weights

All encoders have pretrained weights. Preparing your data the same way as during weight pre-training may give you better results (a higher metric score and faster convergence). You can also train only the decoder and segmentation head while freezing the backbone.

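import torch
from torchvision import models

# resnet34 as an example; the other torchvision encoders are exported the same way
model = models.resnet34(pretrained=True)  # ImageNet weights (newer torchvision prefers weights=...)
model.eval()                              # inference mode, so BatchNorm uses its running statistics
var = torch.ones((1, 3, 224, 224))        # dummy input, used only to trace the graph
traced_script_module = torch.jit.trace(model, var)
traced_script_module.save("resnet34.pt")  # the weight path passed to UNet(...) / Initialize(...)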

Congratulations! You are done! Now you can train your model with your favorite backbone and segmentation framework.
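The advice above about preparing your data the same way as during pre-training usually means, for torchvision ImageNet weights, scaling pixels to [0, 1] and normalizing each channel with the ImageNet mean (0.485, 0.456, 0.406) and standard deviation (0.229, 0.224, 0.225). The following is a minimal plain-C++ sketch of that arithmetic on an 8-bit RGB buffer. `NormalizeImageNet` is a hypothetical helper, and whether Libtorch Segment applies exactly this preprocessing internally is an assumption to check against its source.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: converts an interleaved 8-bit RGB image (H*W*3 bytes)
// into planar CHW floats normalized with the ImageNet statistics used by
// torchvision's pretrained models. Assumes RGB channel order.
std::vector<float> NormalizeImageNet(const std::vector<std::uint8_t>& rgb,
                                     int height, int width) {
    constexpr std::array<float, 3> kMean{0.485f, 0.456f, 0.406f};
    constexpr std::array<float, 3> kStd{0.229f, 0.224f, 0.225f};
    assert(rgb.size() == static_cast<std::size_t>(height) * width * 3);

    const std::size_t plane = static_cast<std::size_t>(height) * width;
    std::vector<float> chw(plane * 3);
    for (std::size_t p = 0; p < plane; ++p) {
        for (std::size_t c = 0; c < 3; ++c) {
            const float value = rgb[p * 3 + c] / 255.0f;        // scale to [0, 1]
            chw[c * plane + p] = (value - kMean[c]) / kStd[c];  // per-channel normalization
        }
    }
    return chw;
}
```

In a real pipeline the input would typically come from `cv::imread`, which returns BGR, so the channels would need to be swapped first (for example with `cv::cvtColor(img, img, cv::COLOR_BGR2RGB)`).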

💡 Examples

  • Training a model for person segmentation using images from the PASCAL VOC dataset. The "voc_person_seg" directory contains 32 JSON labels with their corresponding JPEG images for training, and 8 JSON labels with their corresponding images for validation.
Segmentor<FPN> segmentor;
segmentor.Initialize(0/*gpu id, -1 for cpu*/,
                    512/*resize width*/,
                    512/*resize height*/,
                    {"background","person"}/*class name dict, background included*/,
                    "resnet34"/*backbone name*/,
                    "your path to resnet34.pt");
segmentor.Train(0.0003/*initial learning rate*/,
                300/*training epochs*/,
                4/*batch size*/,
                "your path to voc_person_seg",
                ".jpg"/*image type*/,
                "your path to save segmentor.pt");
  • Prediction test. A segmentor.pt file is provided in the project here. It was trained with an FPN on a ResNet-34 backbone for a few epochs. You can test the segmentation result directly with:
cv::Mat image = cv::imread("your path to voc_person_seg\\val\\2007_004000.jpg");
Segmentor<FPN> segmentor;
segmentor.Initialize(0,512,512,{"background","person"},
                      "resnet34","your path to resnet34.pt");
segmentor.LoadWeight("segmentor.pt"/*the saved .pt path*/);
segmentor.Predict(image,"person"/*class name for showing*/);

The predicted result is shown as follows:
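Conceptually, a semantic segmentation model outputs one score map per class (background included), each pixel is labeled with its highest-scoring class, and the mask for a class such as "person" is the set of pixels with that label. The sketch below illustrates that step in plain C++ on a CHW score buffer. It is an illustration under that assumption, not the code inside `Segmentor::Predict`, and `ArgmaxLabels` and `ClassMask` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: per-pixel argmax over CHW class scores
// (num_classes planes of num_pixels values each). Ties keep the lower class index.
std::vector<int> ArgmaxLabels(const std::vector<float>& scores,
                              std::size_t num_classes, std::size_t num_pixels) {
    assert(scores.size() == num_classes * num_pixels);
    std::vector<int> labels(num_pixels, 0);
    for (std::size_t p = 0; p < num_pixels; ++p) {
        float best = scores[p];  // score of class 0 (background)
        for (std::size_t c = 1; c < num_classes; ++c) {
            const float s = scores[c * num_pixels + p];
            if (s > best) { best = s; labels[p] = static_cast<int>(c); }
        }
    }
    return labels;
}

// Hypothetical helper: binary mask (255 = selected class) for one class name,
// looked up in the same class list passed to Initialize, e.g. {"background", "person"}.
std::vector<std::uint8_t> ClassMask(const std::vector<int>& labels,
                                    const std::vector<std::string>& class_names,
                                    const std::string& name) {
    int target = -1;
    for (std::size_t i = 0; i < class_names.size(); ++i)
        if (class_names[i] == name) target = static_cast<int>(i);
    assert(target >= 0 && "class name not in the list");
    std::vector<std::uint8_t> mask(labels.size(), 0);
    for (std::size_t p = 0; p < labels.size(); ++p)
        if (labels[p] == target) mask[p] = 255;
    return mask;
}
```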

🧑‍🚀 Train your own data

  • Create your own dataset. Install labelme with "pip install" and use it to label your images. Split the output JSON files and images into folders like this:
Dataset
├── train
│   ├── xxx.json
│   ├── xxx.jpg
│   └......
├── val
│   ├── xxxx.json
│   ├── xxxx.jpg
│   └......
  • Train or test. As in the "voc_person_seg" example, replace "voc_person_seg" with your own dataset path.
  • Refer to training tricks to improve your final training performance.
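The layout above pairs each labelme `.json` file with an image of the same stem inside each split folder. A minimal C++17 sketch of how such pairs could be collected follows; `PairByStem` and `CollectPairs` are hypothetical helpers for illustration, not the library's own loader.

```cpp
#include <cassert>
#include <filesystem>
#include <set>
#include <string>
#include <utility>
#include <vector>

namespace fs = std::filesystem;

// Hypothetical helper: given the file names in one split folder, pair every
// labelme "<stem>.json" with "<stem><image_ext>" (e.g. ".jpg").
// Unpaired files are skipped; results come out in file-name order.
std::vector<std::pair<std::string, std::string>> PairByStem(
        const std::vector<std::string>& filenames, const std::string& image_ext) {
    const std::set<std::string> names(filenames.begin(), filenames.end());
    std::vector<std::pair<std::string, std::string>> pairs;  // (image, json)
    for (const auto& name : names) {
        const fs::path p(name);
        if (p.extension() != ".json") continue;
        const std::string image = p.stem().string() + image_ext;
        if (names.count(image)) pairs.emplace_back(image, name);
    }
    return pairs;
}

// Hypothetical wrapper: list the regular files in e.g. "Dataset/train" and pair them.
std::vector<std::pair<std::string, std::string>> CollectPairs(
        const fs::path& split_dir, const std::string& image_ext) {
    assert(fs::is_directory(split_dir));
    std::vector<std::string> filenames;
    for (const auto& entry : fs::directory_iterator(split_dir))
        if (entry.is_regular_file()) filenames.push_back(entry.path().filename().string());
    return PairByStem(filenames, image_ext);
}
```

For the layout above, `CollectPairs("Dataset/train", ".jpg")` would return one (image, JSON) pair per labeled image.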

📦 Models

Architectures

Encoders

  • ResNet
  • ResNext
  • VGG

The following is a list of supported encoders in Libtorch Segment. All encoder weights can be generated through torchvision except ResNeSt. Select the appropriate encoder family, then a specific encoder and its pre-trained weights.

ResNet
| Encoder | Weights | Params (M) |
|---|---|---|
| resnet18 | imagenet | 11 |
| resnet34 | imagenet | 21 |
| resnet50 | imagenet | 23 |
| resnet101 | imagenet | 42 |
| resnet152 | imagenet | 58 |
ResNeXt
| Encoder | Weights | Params (M) |
|---|---|---|
| resnext50_32x4d | imagenet | 22 |
| resnext101_32x8d | imagenet | 86 |
ResNeSt
| Encoder | Weights | Params (M) |
|---|---|---|
| timm-resnest14d | imagenet | 8 |
| timm-resnest26d | imagenet | 15 |
| timm-resnest50d | imagenet | 25 |
| timm-resnest101e | imagenet | 46 |
| timm-resnest200e | imagenet | 68 |
| timm-resnest269e | imagenet | 108 |
| timm-resnest50d_4s2x40d | imagenet | 28 |
| timm-resnest50d_1s4x24d | imagenet | 23 |
SE-Net
| Encoder | Weights | Params (M) |
|---|---|---|
| senet154 | imagenet | 113 |
| se_resnet50 | imagenet | 26 |
| se_resnet101 | imagenet | 47 |
| se_resnet152 | imagenet | 64 |
| se_resnext50_32x4d | imagenet | 25 |
| se_resnext101_32x4d | imagenet | 46 |
VGG
| Encoder | Weights | Params (M) |
|---|---|---|
| vgg11 | imagenet | 9 |
| vgg11_bn | imagenet | 9 |
| vgg13 | imagenet | 9 |
| vgg13_bn | imagenet | 9 |
| vgg16 | imagenet | 14 |
| vgg16_bn | imagenet | 14 |
| vgg19 | imagenet | 20 |
| vgg19_bn | imagenet | 20 |

🛠 Installation

Dependencies:

Windows:

Configure the environment for LibTorch development. Visual Studio and Qt Creator have been verified with the LibTorch 1.7.x release.

Linux and macOS:

Install LibTorch and OpenCV.

For LibTorch, follow the official PyTorch C++ tutorials here.

For OpenCV, follow the official OpenCV installation steps here.

If you have already configured both, congratulations! Download the pretrained weights here and a demo .pt file here into the weights folder.

Set CMAKE_PREFIX_PATH in CMakeLists.txt to your own paths. Set the image path, pretrained weight path, and segmentor path in src/main.cpp to your own. Then open a terminal in the project root directory and run:

mkdir -p build
cd build
cmake ..
make
./LibtorchSegmentation

ToDo

  • More segmentation architectures and backbones
    • UNet++ [paper]
    • ResNest
    • Se-Net
    • ...
  • Data augmentations
    • Random horizontal flip
    • Random vertical flip
    • Random scale rotation
    • ...
  • Training tricks
    • Combined dice and cross entropy loss
    • Freeze backbone
    • Multi step learning rate schedule
    • ...
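The "multi step learning rate schedule" item above usually refers to decaying the learning rate by a fixed factor at chosen epochs, as PyTorch's `MultiStepLR` does. The arithmetic is small enough to sketch in plain C++; `MultiStepLr` and the milestones used below are hypothetical and not part of Libtorch Segment.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: learning rate for a 0-based `epoch` under a multi-step
// schedule. The rate starts at `base_lr` and is multiplied by `gamma` once for
// every milestone that `epoch` has reached.
double MultiStepLr(double base_lr, int epoch, const std::vector<int>& milestones,
                   double gamma = 0.1) {
    assert(base_lr > 0.0 && gamma > 0.0);
    double lr = base_lr;
    for (int milestone : milestones) {
        if (epoch >= milestone) lr *= gamma;
    }
    return lr;
}
```

With the training example's initial rate of 0.0003 over 300 epochs, milestones {100, 200} would give roughly 3e-4 for epochs 0 to 99, 3e-5 for epochs 100 to 199, and 3e-6 afterwards.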

🤝 Thanks

So far, the following projects have helped a lot.

📝 Citing

@misc{Chunyu:2021,
  Author = {Chunyu Dong},
  Title = {Libtorch Segment},
  Year = {2021},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/AllentDan/SegmentationCpp}}
}

🛡️ License

Project is distributed under MIT License.

Related repository

Based on LibTorch, I have released the following repositories:

Last but not least, don't forget to star the project!

Feel free to open issues or submit pull requests; contributors are welcome.


Comments
  • How to enable GPU inference and how to train for multiple objects

    Hi,

    Thank you so much for your support with my last issue. Now I am able to run the given code on the Linux platform as well.

    Please let me know how to run inference on the GPU. When I try to run inference on video input, it is very slow. I hope using the GPU will improve the efficiency.

    Is it possible to train multiple objects, such as people, animals, etc., in one model?

    Thanks and regards, Prabhakar M

    opened by mprabhakarece 10
  • Using DeepLabV3+?

    Hi, and thank you very much for making this code available. Very helpful!

    I have the examples running great, but now I would like to train and use DeepLabV3+ model.

    How do I generate weights and train using this model?

    Thanks!

    opened by antithing 4
  • A bug in saving the best checkpoint

    https://github.com/AllentDan/LibtorchSegmentation/blob/f872f93507c4e11ab02b1b7420d5030741775744/src/Segmentor.h#L116 It seems that the minimum loss should be declared before the training for loop.

    opened by bigbigxing823 3
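The point of this report is that the best (minimum) validation loss has to persist across epochs; if it is declared inside the epoch loop, it is reset every epoch and every epoch looks like a new best. A minimal plain-C++ sketch of the intended logic, with the per-epoch losses given up front; `EpochsToSave` is hypothetical and only records the epochs at which a checkpoint would be written.

```cpp
#include <cassert>
#include <limits>
#include <vector>

// Hypothetical helper: returns the epochs at which a checkpoint would be saved.
// best_loss is declared once, before the epoch loop, so it carries over between
// epochs; a checkpoint is saved only when the validation loss improves.
std::vector<int> EpochsToSave(const std::vector<float>& val_losses) {
    float best_loss = std::numeric_limits<float>::max();  // declared outside the loop
    std::vector<int> saved;
    for (int epoch = 0; epoch < static_cast<int>(val_losses.size()); ++epoch) {
        const float loss = val_losses[epoch];
        if (loss < best_loss) {
            best_loss = loss;
            saved.push_back(epoch);  // real code would save the model here, e.g. with torch::save
        }
    }
    return saved;
}
```

For validation losses {0.9, 0.7, 0.8, 0.5, 0.6}, checkpoints are written only at epochs 0, 1, and 3.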
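  • Some questions about loading the pretrained models

    Hello! I studied your semantic segmentation code written with LibTorch and learned a great deal from it. I still have some questions about the source code that I hope you can answer. During prediction, Initialize loads the resnet34.pt model, and then LoadWeight loads the segmentor.pt model. Are these two models the same? When segmentor.pt is loaded, are the parameters of the class's model member overwritten? If so, loading the resnet34 model beforehand seems unnecessary. Another question: training also loads the pretrained backbone model. Does the model saved after training already include the backbone, or only the segmentation head, with the backbone keeping the pretrained parameters?

    I hope you can find time to answer despite your busy schedule. Thank you very much!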

    opened by CGump 2
  • fix error use 'template' keyword to treat 'item'

    When compiling with CMAKE_CXX_STANDARD 17 and Apple clang version 12.0.0 (clang-1200.0.32.29), I get an error:

    ../../Segmentor.h:187:21: error: use 'template' keyword to treat 'item' as a dependent template name loss_sum += loss.item<float>();

    Fix it according to ISO C++03 14.2/4:

    When the name of a member template specialization appears after . or -> in a postfix-expression, or after nested-name-specifier in a qualified-id, and the postfix-expression or qualified-id explicitly depends on a template-parameter (14.6.2), the member template name must be prefixed by the keyword template. Otherwise the name is assumed to name a non-template.

    opened by 0x0000dead 2
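The rule quoted above applies whenever the object's type depends on a template parameter, as the compiler error shows `loss` does in `Segmentor.h`. A self-contained illustration, with a hypothetical `FakeTensor` standing in for `torch::Tensor` and a hypothetical `SumLosses` mirroring the `loss_sum += ...` line:

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for torch::Tensor; only the member template item<T>() matters.
struct FakeTensor {
    double value;
    template <typename T>
    T item() const { return static_cast<T>(value); }
};

// `loss` has a type that depends on the template parameter TensorT, so the compiler
// cannot tell that `item` names a member template. Written as `loss.item<float>()`,
// the `<` is parsed as less-than, which conforming compilers such as the Apple clang
// in this report reject; the `template` keyword tells the parser `item` is a template.
template <typename TensorT>
float SumLosses(const std::vector<TensorT>& losses) {
    float loss_sum = 0.0f;
    for (const auto& loss : losses) {
        loss_sum += loss.template item<float>();  // the fix proposed in this issue
    }
    return loss_sum;
}
```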
  • Instance Segmentation

    It was very helpful to implement segmentation using LibTorch with C++. Thank you. I'm trying to implement Mask R-CNN for instance segmentation, but I'm not fully familiar with LibTorch yet, so I'm having a hard time. I'd like to include it in your code; could you give me some hints?

    opened by Stellarto 2
  • UNet++ model

    Hi! Great repo!

    I was looking for exactly this: a C++ interface to qubvel's PyTorch models. Is there any chance you will provide the interface for the UNet++ model with an SE-ResNet-50 encoder?

    Thank you!

    enhancement question 
    opened by AlessandroSaviolo 2
  • How to train on custom dataset without json file?

    Hi, my dataset has images and corresponding masks, but it has no json files. The dataset folders just like below:

    Dataset
    ├── train
    │   ├── 0.png
    │   ├── 0_mask.png
    │   └......
    ├── val
    │   ├── 3.png
    │   ├── 3_mask.png
    │   └......
    

    May I ask how to train on this dataset without json file? Thanks!

    opened by panovr 1
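The built-in example expects labelme JSON files, so training on image/mask pairs would mean adapting the data loader. A first step is pairing each image with its mask; below is a minimal C++ sketch for the `N.png` / `N_mask.png` naming in this question. `PairImagesWithMasks` is hypothetical, and the `_mask` suffix convention is an assumption taken from the example layout.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: pairs "<stem><ext>" with "<stem><suffix><ext>" from the file
// names of one split folder. Images without a mask are skipped, and the mask files
// themselves are not treated as images. Results come out in file-name order.
std::vector<std::pair<std::string, std::string>> PairImagesWithMasks(
        const std::vector<std::string>& filenames,
        const std::string& ext = ".png", const std::string& suffix = "_mask") {
    assert(!ext.empty());
    const std::set<std::string> names(filenames.begin(), filenames.end());
    std::vector<std::pair<std::string, std::string>> pairs;  // (image, mask)
    for (const auto& name : names) {
        if (name.size() <= ext.size() ||
            name.compare(name.size() - ext.size(), ext.size(), ext) != 0) continue;
        const std::string stem = name.substr(0, name.size() - ext.size());
        if (stem.size() >= suffix.size() &&
            stem.compare(stem.size() - suffix.size(), suffix.size(), suffix) == 0) continue;
        const std::string mask = stem + suffix + ext;
        if (names.count(mask)) pairs.emplace_back(name, mask);
    }
    return pairs;
}
```

The mask pixels would then need to hold class indices in the same order as the class list passed to `Initialize` (0 for background); that encoding is also an assumption.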
  • Adding DANNet to the supported Archs

    Hi, Great work guys!

    This is great work indeed, and I would like to contribute to it by adding support for another network: https://github.com/W-zx-Y/DANNet. Could you let me know the process and where it is best to start? I'm sure you have done this before, so I could use your experience instead of going completely on my own.

    Thanks!

    opened by cassini-fly 1
  • build issue for linux

    Hi,

    Firstly, thank you for sharing this project. It would be very helpful to work with Pytorch in C++ using this library.

    I tried to build the project using the commands mentioned:

    cd build
    cmake ..
    make
    ./LibtorchSegmentation

    The "cmake .." command is executed successfully. However, the "make" command shows an error message "make: *** No targets specified and no makefile found. Stop." There is no makefile in the LibtorchSegmentation project directory. And therefore, there is no LibtorchSegmentation file generated in the build directory.

    Please would you let me know how I can work around this?

    Thanks!

    opened by surajkhochare 1
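  • About the model structure

    Hi, I looked at the implementation of the U-Net structure and have a few questions:

    • Why is an extra upsample layer attached at the end of the segmentation head? In principle, the features output by the conv are already (b, n_class, h, w), i.e. back at the original image size, so this upsampling layer seems unnecessary.
    • The dice_loss implementation: I found that dice_coef is computed as dice_coef = inter/(pred_area + label_area - inter), but isn't that the IoU formula? Shouldn't dice_coef be 2*inter/(pred_area + label_area)?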
    opened by ximitiejiang 3
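The second question can be checked numerically. With intersection I and predicted and label areas P and L, Dice = 2I/(P + L) and IoU (Jaccard) = I/(P + L - I), related by Dice = 2·IoU/(1 + IoU); the formula quoted in the issue is therefore IoU. A plain-C++ sketch on binary masks follows; `CountOverlap`, `Dice`, and `IoU` are hypothetical helpers, not the library's `dice_loss`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Overlap { double intersection, pred_area, label_area; };

// Counts the intersection and areas of two binary masks (nonzero = foreground).
Overlap CountOverlap(const std::vector<int>& pred, const std::vector<int>& label) {
    assert(pred.size() == label.size());
    Overlap o{0, 0, 0};
    for (std::size_t i = 0; i < pred.size(); ++i) {
        const bool p = pred[i] != 0, l = label[i] != 0;
        o.pred_area += p;
        o.label_area += l;
        o.intersection += (p && l);
    }
    return o;
}

// Dice coefficient: 2 * inter / (pred_area + label_area).
// Two empty masks give 0/0; practical losses add a small smoothing term.
double Dice(const Overlap& o) {
    return 2.0 * o.intersection / (o.pred_area + o.label_area);
}

// IoU (Jaccard): inter / (pred_area + label_area - inter), the formula quoted above.
double IoU(const Overlap& o) {
    return o.intersection / (o.pred_area + o.label_area - o.intersection);
}
```

For pred = {1, 1, 1, 0} and label = {0, 1, 1, 1}, I = 2 and P = L = 3, so IoU is 0.5 while Dice is about 0.667.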