
Segment Anything learning and fine-tuning series, part 3: fine-tuning the SAM decoder

2024-07-29 15:34:02

Preface

This three-part series collects my study notes from using the SAM model at work:

  1. SAM first look: a brief introduction to the model framework, without going into details or code
  2. SAM in detail: a closer analysis of each module alongside the code
  3. SAM fine-tuning example: the original work code involves private data, so this part uses the publicly available VOC2007 dataset and explains how to fine-tune the mask decoder with points and boxes as prompts

This is part 3: fine-tuning the SAM decoder on the VOC2007 dataset. The code has been uploaded to GitHub; if it helps you, please give it a star. Thanks.

As mentioned in the previous parts, the ViT-B SAM weights are 375 MB, of which the prompt encoder is only 32.8 KB and the mask decoder 16.3 MB (4.35%); the rest is the image encoder. The image encoder is very large and is generally not fine-tuned, because its pre-training is already good enough. It is fine-tuned as well only for unconventional data such as medical images, which the pre-training data does not cover and on which results are relatively poor. Here, therefore, only the mask decoder is fine-tuned.
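
As a quick sanity check on these proportions, the snippet below recomputes each component's share from the sizes quoted above. It only does arithmetic on the quoted figures (treating 1 MB as 1000 KB, which is an assumption) and does not load the checkpoint.

```python
# Sizes quoted in the text for the ViT-B SAM checkpoint, in MB.
TOTAL_MB = 375.0
PROMPT_ENCODER_MB = 32.8 / 1000  # 32.8 KB, assuming 1 MB = 1000 KB
MASK_DECODER_MB = 16.3


def component_shares(total, prompt_enc, mask_dec):
    """Fraction of the checkpoint taken by each component."""
    image_enc = total - prompt_enc - mask_dec  # everything else is the image encoder
    return {
        "image_encoder": image_enc / total,
        "prompt_encoder": prompt_enc / total,
        "mask_decoder": mask_dec / total,
    }


for name, frac in component_shares(TOTAL_MB, PROMPT_ENCODER_MB, MASK_DECODER_MB).items():
    print(f"{name:>15}: {frac:.2%}")
```

The mask decoder comes out at about 4.35% and the image encoder at over 95%, which is why training only the decoder is so much cheaper.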

Fine-tuning results

Point prompts

This part is fine-tuned with points as the only prompt. The results are compared in ISAT_with_segment_anything, an annotation tool that uses SAM for automatic labeling. Before fine-tuning, several points have to be clicked to get a good segmentation; after fine-tuning, a single click on an object is enough to segment it.

[Figure: before fine-tuning]

[Figure: after fine-tuning]

Box prompts

This part adds boxes as prompts for fine-tuning.

[Figure: before fine-tuning]

[Figure: after fine-tuning]

Code walkthrough

Data loading

The VOC2007 segmentation dataset is used: 632 images in total (422 trainval, 210 test) covering 20 object categories, or 21 classes including the background. The labels are PNG images whose pixel values index the objects (class indices in SegmentationClass, per-image instance indices in SegmentationObject); the boundary pixels around every object mask have the value 255 and are ignored during training. The original dataset has the directory structure below (data_example in the GitHub code is only a sample with a few images), and the labels in SegmentationObject are used for training:

## VOCdevkit/VOC2007
├── Annotations
├── ImageSets
│   ├── Layout
│   ├── Main
│   └── Segmentation
├── JPEGImages
├── SegmentationClass
└── SegmentationObject
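
Because 0 marks the background and 255 marks object boundaries, the ids worth prompting for in a label are everything else. A minimal NumPy sketch of that filtering step; `valid_object_ids` is a hypothetical helper for illustration, not a function from the repository:

```python
import numpy as np

IGNORED_IDS = (0, 255)  # 0 = background, 255 = boundary pixels around each object


def valid_object_ids(label):
    """Return the object ids present in a VOC label array, skipping background and boundary."""
    return [int(v) for v in np.unique(label) if int(v) not in IGNORED_IDS]


# Toy 4x5 label with two objects (ids 1 and 2) and a 255 boundary between them.
label = np.array(
    [
        [0, 0, 255, 2, 2],
        [0, 1, 255, 2, 2],
        [1, 1, 255, 0, 0],
        [1, 1, 0, 0, 0],
    ],
    dtype=np.uint8,
)
print(valid_object_ids(label))
```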

CustomDataset reads the data according to the directory structure above: it takes the list of file names from the txt_name file in ImageSets/Segmentation, then reads the corresponding images and labels. Points to note:

  • Segmentation labels are read with PIL, so each pixel value is the corresponding index directly, and 255 (the outer contour) is ignored; if the label is read with OpenCV instead, you get colors and have to look up the corresponding index in the palette table by RGB value
  • Both image and gt are put into the batch as numpy arrays and converted to tensors later when passed to SAM; images in VOC2007 differ in size, so for now batch=1
  • gt has a single channel and must be converted to one-hot later

import os

import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, VOCdevkit_path, txt_name="", transform=None):
        self.VOCdevkit_path = VOCdevkit_path
        # txt_name is a split file in ImageSets/Segmentation (e.g. train.txt): one image id per line
        with open(os.path.join(VOCdevkit_path, f"VOC2007/ImageSets/Segmentation/{txt_name}"), "r") as f:
            file_names = f.readlines()
        self.file_names = [name.strip() for name in file_names]
        self.image_dir = os.path.join(self.VOCdevkit_path, "VOC2007/JPEGImages")
        self.image_files = [f"{self.image_dir}/{name}.jpg" for name in self.file_names]
        self.gt_dir = os.path.join(self.VOCdevkit_path, "VOC2007/SegmentationObject")
        self.gt_files = [f"{self.gt_dir}/{name}.png" for name in self.file_names]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        image_name = image_path.split("/")[-1]
        gt_path = self.gt_files[idx]

        image = cv2.imread(image_path)        # h,w,c uint8 in BGR order
        image = image[..., ::-1]               ## BGR to RGB
        image = np.ascontiguousarray(image)    # copy so the array has positive strides
        gt = Image.open(gt_path)               # palette PNG: pixel value = object index, 255 = boundary
        gt = np.array(gt, dtype='uint8')       # h,w

        return image, gt, image_name

    @staticmethod
    def custom_collate(batch):
        """collate_fn for the DataLoader.
        Images and gts are kept as numpy arrays and converted to tensors later.
        """
        images = []
        seg_labels = []
        images_name = []
        for image, gt, image_name in batch:
            images.append(image)
            seg_labels.append(gt)
            images_name.append(image_name)
        # np.stack needs equal sizes; VOC images differ in size, hence batch_size=1
        images = np.stack(images)
        seg_labels = np.stack(seg_labels)
        return images, seg_labels, images_name
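
If the labels are read with OpenCV instead of PIL, as the first note above mentions, each pixel comes back as a color and has to be mapped back through the palette. A sketch of that lookup, assuming the standard PASCAL VOC color map (generated from the bits of each index); `voc_colormap` and `rgb_label_to_index` are hypothetical helpers, and OpenCV's BGR output has to be flipped to RGB before calling them:

```python
import numpy as np


def voc_colormap(n=256):
    """Standard PASCAL VOC palette: index i -> (R, G, B), built from the bits of i."""
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        r = g = b = 0
        c = i
        for j in range(8):
            # Spread bits 0/1/2 of c over R/G/B, filling from the most significant bit down.
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        cmap[i] = (r, g, b)
    return cmap


def rgb_label_to_index(label_rgb, cmap=None):
    """Map an HxWx3 RGB label image back to palette indices (HxW, uint8).

    Raises ValueError if a color is not in the palette.
    """
    if cmap is None:
        cmap = voc_colormap()

    def pack(a):
        # Pack each RGB triple into one integer so the lookup can be vectorised.
        a = np.asarray(a, dtype=np.int64)
        return (a[..., 0] << 16) | (a[..., 1] << 8) | a[..., 2]

    palette_codes = pack(cmap)                  # (256,)
    order = np.argsort(palette_codes)
    sorted_codes = palette_codes[order]
    codes = pack(label_rgb)                     # (H, W)
    pos = np.clip(np.searchsorted(sorted_codes, codes), 0, len(sorted_codes) - 1)
    if not np.array_equal(sorted_codes[pos], codes):
        raise ValueError("label contains colors that are not in the VOC palette")
    return order[pos].astype(np.uint8)
```

For an unmodified VOC label, `rgb_label_to_index(cv2.imread(gt_path)[..., ::-1])` should then match `np.array(Image.open(gt_path))`; reading with PIL, as CustomDataset does, avoids the lookup altogether.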

Image Preprocessing

After loading an image, the preprocessing in SamPredictor is used directly: it resizes the image so that its longest side is 1024 (and pads it to 1024x1024), then computes the image_embedding. This step is very time-consuming, so it is computed only once per image (not once per object), cached in the predictor, and reused whenever it is needed. Wrapping it in "with torch.no_grad()" means no gradients are computed for the image encoder, which freezes its weights.

    import numpy as np
    import torch
    from tqdm import tqdm
    from segment_anything.utils.transforms import ResizeLongestSide

    model_transform = ResizeLongestSide(sam.image_encoder.img_size)
    for epoch in range(num_epochs):
        epoch_loss = 0
        for idx, (images, gts, image_names) in enumerate(tqdm(dataloader)):
            valid_classes = []  ## voc 0,255 are ignored
            for i in range(images.shape[0]):
                image = images[i]  # h,w,c np.uint8 rgb
                original_size = image.shape[:2]  ## h,w
                input_size = model_transform.get_preprocess_shape(original_size[0], original_size[1],
                                                                  sam.image_encoder.img_size)  ## h,w after resizing
                gt = gts[i].copy()  # h,w object ids, plus 255 on boundaries
                gt_classes = np.unique(gt)  ## ids present, e.g. [0, 1, 2, 3, 4, 255]
                image_name = image_names[i]

                predictions = []
                ## freeze image encoder
                with torch.no_grad():
                    # set_image resizes/pads the image and runs the image encoder once
                    predictor.set_image(image, "RGB")
                    image_embedding = predictor.get_image_embedding()
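
The resizing that set_image performs can be reproduced in a few lines to see what `input_size` ends up being. This NumPy sketch mirrors the rounding of `ResizeLongestSide.get_preprocess_shape` and the bottom/right padding to 1024x1024. SAM pads after normalizing the pixels, so the zero padding here only illustrates the shapes; `pad_to_square` is a hypothetical helper.

```python
import numpy as np


def get_preprocess_shape(oldh, oldw, long_side_length=1024):
    """(h, w) after scaling so the longer side equals long_side_length, rounded to nearest."""
    scale = long_side_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    return int(newh + 0.5), int(neww + 0.5)


def pad_to_square(resized, size=1024):
    """Zero-pad an already resized HxWxC array on the bottom and right to size x size."""
    h, w = resized.shape[:2]
    return np.pad(resized, ((0, size - h), (0, size - w), (0, 0)))


input_size = get_preprocess_shape(375, 500)  # a 500-wide, 375-high VOC image
padded = pad_to_square(np.zeros(input_size + (3,), dtype=np.uint8))
print(input_size, padded.shape)
```

For a 500x375 (w x h) image, input_size is (768, 1024); postprocess_masks later crops the predicted mask back to this size before resizing it to the original (375, 500).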

Prompt Generation

Foreground and background points are sampled at random from the mask. The default is 1 foreground point and 1 background point; when sampling more points, a foreground:background ratio of about 2:1 generally works better.

mask_value is the id of the target object: the coordinates of all mask pixels equal to that id are collected, and points are then sampled from them at random. The bounding rectangle of the mask is also computed here (it could instead be read from the image's XML annotation file) for the later box-prompt fine-tuning.

import numpy as np

def get_random_prompts(mask, mask_value, foreground_nums=1, background_nums=1):
    # Find the indices (coordinates) of the foreground pixels, as (y, x) rows
    foreground_indices = np.argwhere(mask == mask_value)
    ymin, xmin = foreground_indices.min(axis=0)
    ymax, xmax = foreground_indices.max(axis=0)
    bbox = np.array([xmin, ymin, xmax, ymax])  # box prompt in xyxy order
    # too few foreground pixels: sample fewer points, keeping about 2:1 foreground:background
    if foreground_indices.shape[0] < foreground_nums:
        foreground_nums = foreground_indices.shape[0]
        background_nums = int(0.5 * foreground_indices.shape[0])
    # every other pixel (other objects, background, 255 boundary) counts as background
    background_indices = np.argwhere(mask != mask_value)

    ## random selection without replacement
    foreground_points = foreground_indices[
        np.random.choice(foreground_indices.shape[0], foreground_nums, replace=False)]
    background_points = background_indices[
        np.random.choice(background_indices.shape[0], background_nums, replace=False)]

    ## np.argwhere returns (y, x); the network expects (x, y), so flip the order
    foreground_points = foreground_points[:, ::-1]
    background_points = background_points[:, ::-1]

    return (foreground_points, background_points), bbox
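
The advice to keep roughly 2:1 foreground to background points when sampling more points can be folded into the sampling itself. A minimal NumPy sketch, assuming the object id is present in the mask; `sample_prompts_2to1` is a hypothetical variant of get_random_prompts (not part of the repository) that also takes a seeded `np.random.Generator` so the sampled points are reproducible:

```python
import numpy as np


def sample_prompts_2to1(mask, mask_value, num_points=3, rng=None):
    """Sample point prompts for one object with about 2 foreground : 1 background.

    Returns (points_xy, labels, bbox_xyxy); labels are 1 for foreground, 0 for background.
    """
    rng = np.random.default_rng() if rng is None else rng
    fg = np.argwhere(mask == mask_value)  # (N, 2) rows of (y, x)
    bg = np.argwhere(mask != mask_value)
    # about two thirds foreground, and always at least one foreground point
    n_fg = min(len(fg), max(1, round(num_points * 2 / 3)))
    n_bg = max(0, min(len(bg), num_points - n_fg))
    fg_pts = fg[rng.choice(len(fg), n_fg, replace=False)]
    bg_pts = bg[rng.choice(len(bg), n_bg, replace=False)]
    points_yx = np.concatenate([fg_pts, bg_pts], axis=0)
    labels = np.array([1] * n_fg + [0] * n_bg, dtype=int)
    (ymin, xmin), (ymax, xmax) = fg.min(axis=0), fg.max(axis=0)
    bbox = np.array([xmin, ymin, xmax, ymax])
    return points_yx[:, ::-1], labels, bbox  # flip (y, x) -> (x, y)


mask = np.zeros((6, 8), dtype=np.uint8)
mask[1:4, 2:6] = 5  # one object with id 5
points, labels, bbox = sample_prompts_2to1(mask, 5, num_points=3, rng=np.random.default_rng(0))
print(points, labels, bbox)
```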

The point prompts are (x, y) coordinates in the original image, but the image fed to SAM is resized so that its longest side is 1024, so the point coordinates have to be rescaled in the same way, as in the following code:

    all_points = np.concatenate((foreground_points, background_points), axis=0)
    point_labels = np.array([1] * foreground_points.shape[0] + [0] * background_points.shape[0], dtype=int)
    ## the image is resized to 1024 on its longest side, so the points are rescaled too
    all_points = model_transform.apply_coords(all_points, original_size)

    all_points = torch.as_tensor(all_points, dtype=torch.float, device=device)
    point_labels = torch.as_tensor(point_labels, dtype=torch.int, device=device)
    all_points, point_labels = all_points[None, :, :], point_labels[None, :]  # add batch dim: 1xNx2, 1xN
    points = (all_points, point_labels)

    if not box_prompt:
        box_torch = None
    else:
        ## rescale the bbox the same way
        box = model_transform.apply_boxes(bbox, original_size)
        box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
        box_torch = box_torch[None, :]
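
For reference, the rescaling that `apply_coords` and `apply_boxes` perform amounts to multiplying x by new_w / old_w and y by new_h / old_h, with each box treated as two corner points. The NumPy sketch below re-implements that logic for illustration only (these functions are stand-ins, not the library's); in the training code keep using `model_transform`.

```python
import numpy as np


def get_preprocess_shape(oldh, oldw, long_side_length=1024):
    scale = long_side_length * 1.0 / max(oldh, oldw)
    return int(oldh * scale + 0.5), int(oldw * scale + 0.5)


def apply_coords(coords, original_size, target_length=1024):
    """Rescale (..., 2) arrays of (x, y) points from original to resized image coordinates."""
    old_h, old_w = original_size
    new_h, new_w = get_preprocess_shape(old_h, old_w, target_length)
    coords = np.array(coords, dtype=float)  # copy, so the caller's array is untouched
    coords[..., 0] *= new_w / old_w
    coords[..., 1] *= new_h / old_h
    return coords


def apply_boxes(boxes, original_size, target_length=1024):
    """Rescale xyxy boxes by treating each one as two (x, y) corner points."""
    boxes = np.asarray(boxes, dtype=float)
    return apply_coords(boxes.reshape(-1, 2, 2), original_size, target_length).reshape(-1, 4)


print(apply_coords(np.array([[100, 50]]), (375, 500)))
print(apply_boxes(np.array([10, 20, 110, 220]), (375, 500)))
```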

The fine-tuning code lets you choose which kind of prompt to train with. If both point and box prompts are enabled, either the points or the box is dropped with some probability to improve generalization (otherwise, inference with only points or only a box as the prompt may work poorly). Finally, the prompt_encoder produces sparse_embeddings and dense_embeddings.

    ## if both are enabled, randomly drop one for better generalization
    if point_box and np.random.random() < 0.5:
        if np.random.random() < 0.25:
            points = None
        elif np.random.random() > 0.75:
            box_torch = None
    ## freeze prompt encoder
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(
            points=points,
            boxes=box_torch,
            masks=None,
        )
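
Because the inner `if` and `elif` each make their own random draw, the effective probabilities are not symmetric: points are dropped with probability 0.5 × 0.25 = 12.5%, the box with 0.5 × 0.75 × 0.25 ≈ 9.4%, and both prompts are kept about 78% of the time. The sketch below mirrors the rule with Python's `random` module (standing in for `np.random`) and checks those numbers by simulation; `choose_prompts` is a hypothetical helper.

```python
import random
from collections import Counter


def choose_prompts(rng):
    """Mirror the drop rule when both prompt types are enabled.

    Returns (keep_points, keep_box).
    """
    keep_points, keep_box = True, True
    if rng.random() < 0.5:
        if rng.random() < 0.25:      # second draw
            keep_points = False
        elif rng.random() > 0.75:    # third, independent draw
            keep_box = False
    return keep_points, keep_box


def analytic_probabilities():
    """Exact outcome probabilities implied by the three independent draws."""
    drop_points = 0.5 * 0.25
    drop_box = 0.5 * 0.75 * 0.25
    return {
        (False, True): drop_points,
        (True, False): drop_box,
        (True, True): 1.0 - drop_points - drop_box,
    }


N = 100_000
rng = random.Random(0)
counts = Counter(choose_prompts(rng) for _ in range(N))
for outcome, p in analytic_probabilities().items():
    print(outcome, "simulated:", counts[outcome] / N, "exact:", p)
```

If equal drop rates are preferred, a single draw can be reused for both branches.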

Mask Prediction

The mask decoder is not frozen, so mask_decoder is called directly. Mask prediction runs twice. The first pass predicts three candidate masks and keeps the one with the highest score; that mask is then used as a mask prompt and fed into prompt_encoder together with the point and box prompts to get new sparse_embeddings and dense_embeddings. The second pass predicts a single mask. In effect, a rough mask is obtained first and then refined. Finally, postprocess_masks upsamples the predicted mask, removes the padding and resizes it to the original image size. Each object corresponds to one mask, and the masks of all objects are stacked to form the predictions for the image.

    ## first pass: predict three candidate masks
    mask_predictions, scores = sam.mask_decoder(
        image_embeddings=image_embedding.to(device),
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=True,
    )
    # Choose the model's best mask (argmax over all scores, so this assumes one prompt per batch)
    mask_input = mask_predictions[:, torch.argmax(scores), ...].unsqueeze(1)  # 1x1x256x256 logits
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(
            points=points,
            boxes=box_torch,
            masks=mask_input,
        )
    ## second pass: predict a single, refined mask
    ## (outside no_grad: the decoder is being trained, so this output needs gradients)
    mask_predictions, scores = sam.mask_decoder(
        image_embeddings=image_embedding.to(device),
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )
    # upsample the low-res logits, remove the padding, resize to the original image size
    best_mask = sam.postprocess_masks(mask_predictions, input_size, original_size)
    predictions.append(best_mask)
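
The snippet above picks the best mask with a single argmax over `scores`. Taken over the flattened tensor (which is what `torch.argmax` does without a `dim` argument), that is only correct when the batch holds one prompt. A NumPy sketch of per-row selection for a batch of prompts; in PyTorch the same indexing works with `scores.argmax(dim=1)` and `torch.arange`. `select_best_masks` is a hypothetical helper.

```python
import numpy as np


def select_best_masks(masks, scores):
    """Pick the highest-scoring mask per prompt from a multimask output.

    masks:  (B, 3, H, W) low-resolution mask logits
    scores: (B, 3) predicted IoU scores
    returns (B, 1, H, W), the shape the prompt encoder expects for a mask prompt
    """
    best = scores.argmax(axis=1)                     # (B,) index of the best mask per row
    picked = masks[np.arange(masks.shape[0]), best]  # (B, H, W)
    return picked[:, None]                           # (B, 1, H, W)


rng = np.random.default_rng(0)
masks = rng.normal(size=(2, 3, 4, 4))
scores = np.array([[0.1, 0.9, 0.3], [0.8, 0.2, 0.5]])
best = select_best_masks(masks, scores)
print(best.shape)
```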

Loss calculation

The loss is BCELoss plus Dice loss. Both require gt and pred to have the same BxCxHxW shape, and pred must be the probability after the sigmoid.

gt therefore has to be converted to one-hot form, i.e., from (batch_size, 1, h, w) to (batch_size, c, h, w), where c is the number of ids passed in as gt_classes, i.e., how many object instances the image contains.

import torch

def mask2one_hot(label, gt_classes):
    """
    label: label image, shape (batch_size, 1, h, w)
    gt_classes: ids to encode, one output channel per id (in this order)
    """
    current_label = label.squeeze(1)  # (batch_size, 1, h, w) ---> (batch_size, h, w)
    batch_size, h, w = current_label.shape[0], current_label.shape[1], current_label.shape[2]
    one_hots = []
    for cls in gt_classes:
        cls = int(cls)  # accepts Python ints, NumPy scalars and 0-d tensors
        tmplate = torch.zeros(batch_size, h, w)  # (batch_size, h, w)
        tmplate[current_label == cls] = 1
        tmplate = tmplate.view(batch_size, 1, h, w)  # (batch_size, h, w) --> (batch_size, 1, h, w)
        one_hots.append(tmplate)
    onehot = torch.cat(one_hots, dim=1)  # (batch_size, len(gt_classes), h, w)
    return onehot
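
The per-class loop in mask2one_hot can also be written as a single broadcast comparison. A NumPy sketch (the same `==` broadcasting works on torch tensors); `mask_to_one_hot` is a hypothetical name, and channel k follows the order of `class_ids`, which has to match the order in which the predictions were stacked.

```python
import numpy as np


def mask_to_one_hot(label, class_ids):
    """Vectorised one-hot: (B, 1, H, W) integer labels -> (B, C, H, W) float32.

    Channel k is 1 where label == class_ids[k]; ids not listed (e.g. 0 and 255) give all-zero pixels.
    """
    ids = np.asarray(class_ids).reshape(1, -1, 1, 1)  # (1, C, 1, 1) for broadcasting
    return (label == ids).astype(np.float32)          # (B, 1, H, W) vs (1, C, 1, 1) -> (B, C, H, W)


label = np.array([[[[0, 1], [2, 255]]]], dtype=np.uint8)  # (1, 1, 2, 2)
print(mask_to_one_hot(label, [1, 2]))
```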

In addition, BCELoss expects probabilities rather than logits, so the predicted logits go through a sigmoid first. The loss computation is as follows:

    gts = torch.from_numpy(gts).unsqueeze(1)  ## BxHxW ---> Bx1xHxW
    gts_onehot = mask2one_hot(gts, valid_classes)  # Bx(len(valid_classes))xHxW
    gts_onehot = gts_onehot.to(device)

    # one Bx1xHxW logit map per object -> BxCxHxW probabilities
    predictions = torch.sigmoid(torch.cat(predictions, dim=1))
    # BCEseg: the BCELoss object defined in the training script (expects probabilities)
    loss = BCEseg(predictions, gts_onehot)
    loss_dice = soft_dice_loss(predictions, gts_onehot, smooth=1e-5, activation='none')
    loss = loss + loss_dice
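
To make the two terms concrete, here is a NumPy sketch of mean binary cross-entropy on probabilities and one common soft Dice formulation. It is not necessarily identical to the repository's `soft_dice_loss`, whose exact reduction is not shown, and `torch.nn.BCELoss` clamps its log terms at -100 rather than clipping the inputs as done here. `bce_loss` and `soft_dice_loss_np` are hypothetical names.

```python
import numpy as np


def bce_loss(pred, target, eps=1e-7):
    """Mean binary cross-entropy on probabilities (inputs clipped to avoid log(0))."""
    p = np.clip(pred, eps, 1 - eps)
    return float(np.mean(-(target * np.log(p) + (1 - target) * np.log(1 - p))))


def soft_dice_loss_np(pred, target, smooth=1e-5):
    """1 - soft Dice, computed per (batch, channel) over H and W, then averaged."""
    inter = np.sum(pred * target, axis=(2, 3))
    denom = np.sum(pred, axis=(2, 3)) + np.sum(target, axis=(2, 3))
    dice = (2 * inter + smooth) / (denom + smooth)
    return float(np.mean(1 - dice))


rng = np.random.default_rng(0)
gt = (rng.random((1, 3, 8, 8)) > 0.5).astype(np.float64)  # BxCxHxW one-hot-like target
pred = np.clip(gt + rng.normal(scale=0.1, size=gt.shape), 0, 1)
print("bce + dice:", bce_loss(pred, gt) + soft_dice_loss_np(pred, gt))
```

BCE scores every pixel independently, while Dice measures overlap per object channel, which keeps small objects from being swamped by the background pixels.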

Weight saving

The optimizer defaults to AdamW and the scheduler to CosineAnnealingLR; both can be changed. Only the weights with the lowest epoch loss so far are saved, and only those of the decoder; adjust this as needed.

if epoch_loss < best_loss:
    best_loss = epoch_loss
    mask_decoder_weights = sam.mask_decoder.state_dict()
    # add the "mask_decoder." prefix so the keys match those of the full SAM model's state_dict
    mask_decoder_weights = {f"mask_decoder.{k}": v for k, v in mask_decoder_weights.items()}
    torch.save(mask_decoder_weights, os.path.join(save_dir, f'sam_decoder_fintune_{str(epoch+1)}_pointbox_monai.pth'))
    print("Saving weights, epoch: ", epoch+1)
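
The "mask_decoder." prefix makes the saved keys match those of a full SAM state dict, so the fine-tuned file can be combined with the original checkpoint. A dictionary-only sketch of that merge; `merge_decoder_weights` is a hypothetical helper and the string values stand in for tensors. In PyTorch, an alternative is to build the model from the original checkpoint and then call `sam.load_state_dict(torch.load(path), strict=False)`; `strict=False` is needed because the file contains only decoder keys.

```python
PREFIX = "mask_decoder."


def merge_decoder_weights(full_state, decoder_state):
    """Return a copy of full_state with its mask-decoder entries replaced by decoder_state.

    Both are plain mappings of parameter name -> tensor/array; decoder_state keys must
    carry the "mask_decoder." prefix and already exist in full_state.
    """
    bad = [k for k in decoder_state if not k.startswith(PREFIX) or k not in full_state]
    if bad:
        raise KeyError(f"unexpected decoder keys: {bad[:3]}")
    merged = dict(full_state)  # shallow copy; full_state itself is not modified
    merged.update(decoder_state)
    return merged


full = {"image_encoder.patch_embed.proj.weight": "orig", "mask_decoder.iou_token.weight": "orig"}
tuned = {"mask_decoder.iou_token.weight": "tuned"}
print(merge_decoder_weights(full, tuned))
```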

That completes the series. Thanks for reading!