preamble
This three-part series collects my study notes from using the SAM model at work:
- SAM first look: a brief introduction to the model framework, without going into details or code
- SAM in detail: further analysis of each module in conjunction with the code
- SAM fine-tuning example: the original code involves private data, so this part uses the publicly available VOC2007 dataset and explains fine-tuning the mask decoder with Point and Box prompts
This is part 3: fine-tuning the SAM decoder on the VOC2007 dataset. The code has been uploaded to github; please give it a Star if it helps you, thanks.
As mentioned previously, the SAM checkpoint based on ViT_B is 375M, of which the prompt encoder accounts for only 32.8K and the mask decoder for 16.3M (4.35%); the rest is the image encoder. The image encoder is very large and is generally not fine-tuned, since its pre-training is already good enough. Only for unconventional data such as medical images, which the pre-training data does not cover and where results are therefore relatively poor, would the image encoder also be fine-tuned. So here only the decoder is fine-tuned.
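If you want to see for yourself how the parameters are distributed across the three sub-modules, a quick check with the official segment_anything package looks roughly like this (the checkpoint path is a placeholder; adjust it to your local file):
import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # placeholder path

for name, module in [("image_encoder", sam.image_encoder),
                     ("prompt_encoder", sam.prompt_encoder),
                     ("mask_decoder", sam.mask_decoder)]:
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {n_params / 1e6:.2f}M parameters")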
fine-tuning effect
Based on point prompt
This part is fine-tuned with only a point as the prompt. With the help of ISAT_with_segment_anything, an automatic labeling tool built on SAM, the effect is compared: before fine-tuning, you need to click several points to get a decent segmentation; after fine-tuning, a single click on an object of the corresponding category is enough.
before fine-tuning
after fine-tuning
Based on box prompt
This part adds the box as a prompt for fine-tuning.
before fine-tuning
after fine-tuning
code section
data retrieval
The VOC2007 segmentation dataset is used: 632 images in total (422 train_val, 210 test) covering 20 categories, plus the background class for 21 in total. Labels are in png format, where the pixel value represents the object category and the outer contour of every object mask has value 255, which is ignored during training. The original dataset has the directory structure below (data_example in the code on github is just an example with only a few images); the labels in SegmentationObject are used for training:
VOCdevkit/VOC2007
├── Annotations
├── ImageSets
│ ├── Layout
│ ├── Main
│ └── Segmentation
├── JPEGImages
├── SegmentationClass
└── SegmentationObject
The CustomDataset class reads the data according to the directory structure above: it takes the file names for training from the txt specified by txt_name in the ImageSets/Segmentation directory, and then reads the corresponding images and labels. A few points to note:
- The segmentation labels are read with PIL, where the pixel value is the corresponding category and 255 marks the outer contour, which is ignored; if you read the label with opencv instead, you get RGB values and have to look them up in the palette table to recover the category ids (see the short check below)
- Both image and gt are put into the batch as numpy arrays and converted to tensors later when fed to SAM; the images in VOC2007 have different sizes, so for now batch=1 is used
- gt has a single channel and needs to be converted to one-hot later
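A quick way to confirm the first point is to open one label with PIL and look at its unique pixel values (the file name below is just an example):
import numpy as np
from PIL import Image

gt = np.array(Image.open("VOCdevkit/VOC2007/SegmentationObject/000032.png"))
print(np.unique(gt))   # object ids plus 0 (background) and 255 (ignored contour)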
import os
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, VOCdevkit_path, txt_name="", transform=None):
        self.VOCdevkit_path = VOCdevkit_path
        with open(os.path.join(VOCdevkit_path, f"VOC2007/ImageSets/Segmentation/{txt_name}"), "r") as f:
            file_names = f.readlines()
        self.file_names = [name.strip() for name in file_names]
        self.image_dir = os.path.join(self.VOCdevkit_path, "VOC2007/JPEGImages")
        self.image_files = [f"{self.image_dir}/{name}.jpg" for name in self.file_names]
        self.gt_dir = os.path.join(self.VOCdevkit_path, "VOC2007/SegmentationObject")
        self.gt_files = [f"{self.gt_dir}/{name}.png" for name in self.file_names]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        image_name = image_path.split("/")[-1]
        gt_path = self.gt_files[idx]
        image = cv2.imread(image_path)       # BGR, uint8
        image = image[..., ::-1]             ## BGR to RGB
        image = np.ascontiguousarray(image)
        gt = Image.open(gt_path)             # read with PIL so pixel values are category ids
        gt = np.array(gt, dtype='uint8')
        return image, gt, image_name
    @staticmethod
    def custom_collate(batch):
        """collate_fn for the DataLoader:
        image and gt stay as numpy arrays and are converted to tensors later.
        """
        images = []
        seg_labels = []
        images_name = []
        for image, gt, image_name in batch:
            images.append(image)
            seg_labels.append(gt)
            images_name.append(image_name)
        images = np.array(images)
        seg_labels = np.array(seg_labels)
        return images, seg_labels, images_name
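A minimal usage sketch (the dataset path and txt name are placeholders); batch_size stays at 1 because the image sizes differ:
from torch.utils.data import DataLoader

dataset = CustomDataset("./VOCdevkit", txt_name="trainval.txt")
dataloader = DataLoader(dataset, batch_size=1, shuffle=True,
                        collate_fn=CustomDataset.custom_collate)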
Image Preprocessing
After getting the image, the preprocessing method in SamPredictor is used directly: it resizes the image to 1024x1024 according to the longest side and then computes the image_embedding. This step is very time-consuming, so it is computed only once per image; the result is cached and reused when needed. "with torch.no_grad()" ensures that the image encoder does not accumulate gradients, effectively freezing its weights.
model_transform = ResizeLongestSide(sam.image_encoder.img_size)
for epoch in range(num_epochs):
    epoch_loss = 0
    for idx, (images, gts, image_names) in enumerate(tqdm(dataloader)):
        valid_classes = []                     ## voc: 0 and 255 are ignored
        for i in range(images.shape[0]):
            image = images[i]                  # h,w,c np.uint8 rgb
            original_size = image.shape[:2]    ## h,w
            input_size = model_transform.get_preprocess_shape(image.shape[0], image.shape[1],
                                                              sam.image_encoder.img_size)  ## h,w
            gt = gts[i].copy()                 # h,w labels [0,1,2,..., classes-1]
            gt_classes = np.unique(gt)         ## mask classes, e.g. [0, 1, 2, 3, 4, 7]
            image_name = image_names[i]
            predictions = []
            ## freeze image encoder
            with torch.no_grad():
                # gt_channel = gt[:, :, cls]
                predictor.set_image(image, "RGB")
                image_embedding = predictor.get_image_embedding()
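The caching mentioned above is not shown in this snippet. One simple way to implement it, assuming a dict embedding_cache keyed by image name (a sketch, not necessarily the repo's exact code), is:
# created once, before the training loop
embedding_cache = {}

# inside the per-image loop, replacing the two lines under torch.no_grad():
with torch.no_grad():
    if image_name not in embedding_cache:
        predictor.set_image(image, "RGB")
        embedding_cache[image_name] = predictor.get_image_embedding()
    image_embedding = embedding_cache[image_name]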
Prompt Generation
A certain number of foreground and background points are randomly selected from the mask; the default here is 1 foreground point and 1 background point. If you use more points, keeping a foreground-to-background ratio of about 2:1 generally works better.
mask_value is the category id: find the coordinates of the pixels in the mask whose value equals the category id, then randomly sample points from them. The enclosing rectangle is also computed from the mask here (in fact, the corresponding xml annotation file of the image could be read directly), which will be used later for fine-tuning with the box prompt.
def get_random_prompts(mask, mask_value, foreground_nums=1, background_nums=1):
    # Find the indices (coordinates) of the foreground pixels
    foreground_indices = np.argwhere(mask == mask_value)
    ymin, xmin = foreground_indices.min(axis=0)
    ymax, xmax = foreground_indices.max(axis=0)
    bbox = np.array([xmin, ymin, xmax, ymax])
    if foreground_indices.shape[0] < foreground_nums:
        foreground_nums = foreground_indices.shape[0]
        background_nums = int(0.5 * foreground_indices.shape[0])
    background_indices = np.argwhere(mask != mask_value)
    ## randomly select points
    foreground_points = foreground_indices[
        np.random.choice(foreground_indices.shape[0], foreground_nums, replace=False)]
    background_points = background_indices[
        np.random.choice(background_indices.shape[0], background_nums, replace=False)]
    ## the coordinates are (y, x); the network expects (x, y), so flip the order
    foreground_points = foreground_points[:, ::-1]
    background_points = background_points[:, ::-1]
    return (foreground_points, background_points), bbox
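Inside the per-image loop, it could be called per object category roughly like this (skipping ids 0 and 255, matching the notes above; a sketch, not the repo's exact code):
for cls in gt_classes:
    if cls in (0, 255):    # background and ignored contour
        continue
    valid_classes.append(cls)
    (foreground_points, background_points), bbox = get_random_prompts(
        gt, cls, foreground_nums=1, background_nums=1)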
The prompt is a set of point coordinates; x and y are in the original image coordinate system, but the image fed into SAM is resized to 1024x1024, so the point coordinates also need to be resized, which corresponds to the following code.
all_points = np.concatenate((foreground_points, background_points), axis=0)
all_points = np.array(all_points)
point_labels = np.array([1] * foreground_points.shape[0] + [0] * background_points.shape[0], dtype=int)
## the image is resized to 1024, so the points are too
all_points = model_transform.apply_coords(all_points, original_size)
all_points = torch.as_tensor(all_points, dtype=torch.float, device=device)
point_labels = torch.as_tensor(point_labels, dtype=torch.int, device=device)
all_points, point_labels = all_points[None, :, :], point_labels[None, :]
points = (all_points, point_labels)

if not box_prompt:
    box_torch = None
else:
    ## preprocess bbox
    box = model_transform.apply_boxes(bbox, original_size)
    box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
    box_torch = box_torch[None, :]
The fine-tuning code lets you specify which kind of prompt to fine-tune with. If both point and box are enabled, the point or the box is dropped with a certain probability for better generalization (otherwise, inference with only a point or only a box as the prompt may not work well). Finally, the prompt_encoder produces sparse_embeddings and dense_embeddings.
## if both are used, randomly drop one for better generalization
if point_box and np.random.rand() < 0.5:
    if np.random.rand() < 0.25:
        points = None
    elif np.random.rand() > 0.75:
        box_torch = None
## freeze prompt encoder
with torch.no_grad():
    sparse_embeddings, dense_embeddings = sam.prompt_encoder(
        points=points,
        boxes=box_torch,
        # masks=mask_predictions,
        masks=None,
    )
Mask prediction
The mask decoder does not need to be frozen; just call mask_decoder directly. Mask prediction is performed twice: the first pass predicts 3 masks and the one with the highest score is selected; this mask is then used as a mask prompt and, together with the point prompt and box prompt, fed into prompt_encoder to get new sparse_embeddings and dense_embeddings for a second mask prediction, which outputs only a single mask. This is equivalent to getting a rough mask first and then refining it. Finally, post-processing produces predicted masks with the same size as the original image; each object corresponds to one mask, and the masks are stacked to form all the predictions for this image.
## first prediction: three mask levels
mask_predictions, scores = sam.mask_decoder(
    image_embeddings=image_embedding.to(device),
    image_pe=sam.prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
    multimask_output=True,
)
# choose the mask with the highest score as the mask prompt
mask_input = mask_predictions[:, torch.argmax(scores), ...].unsqueeze(1)
with torch.no_grad():
    sparse_embeddings, dense_embeddings = sam.prompt_encoder(
        points=points,
        boxes=box_torch,
        masks=mask_input,
    )
## second prediction: a refined mask, only one this time
mask_predictions, scores = sam.mask_decoder(
    image_embeddings=image_embedding.to(device),
    image_pe=sam.prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
    multimask_output=False,
)
best_mask = sam.postprocess_masks(mask_predictions, input_size, original_size)
predictions.append(best_mask)
Loss calculation
The loss is BCELoss plus DiceLoss, which requires gt and pred to have the same shape, both of the form BxCxHxW, with pred being the value after sigmoid.
Therefore gt must be converted to one-hot form, i.e. from (batch_size, 1, h, w) to (batch_size, c, h, w), where c is the number of categories in gt_classes, i.e. how many object instances are in the image.
def mask2one_hot(label, gt_classes):
    """
    label: label map, shape (batch_size, 1, h, w)
    gt_classes: the category ids present in the label
    """
    current_label = label.squeeze(1)   # (batch_size, 1, h, w) ---> (batch_size, h, w)
    batch_size, h, w = current_label.shape[0], current_label.shape[1], current_label.shape[2]
    one_hots = []
    for cls in gt_classes:
        if isinstance(cls, torch.Tensor):
            cls = cls.item()
        tmplate = torch.zeros(batch_size, h, w)       # (batch_size, h, w)
        tmplate[current_label == cls] = 1
        tmplate = tmplate.view(batch_size, 1, h, w)   # (batch_size, h, w) --> (batch_size, 1, h, w)
        one_hots.append(tmplate)
    onehot = torch.cat(one_hots, dim=1)
    return onehot
In addition, BCELoss expects probabilities rather than logits, so the predictions need to be passed through sigmoid first. The subsequent loss calculation corresponds to the following code.
gts = torch.from_numpy(gts).unsqueeze(1)      ## BxHxW ---> Bx1xHxW
gts_onehot = mask2one_hot(gts, valid_classes)
gts_onehot = gts_onehot.to(device)

predictions = torch.cat(predictions, dim=1)   # stack per-object masks along the channel dim
predictions = torch.sigmoid(predictions)
# loss = seg_loss(predictions, gts_onehot)
loss = BCEseg(predictions, gts_onehot)
loss_dice = soft_dice_loss(predictions, gts_onehot, smooth=1e-5, activation='none')
loss = loss + loss_dice
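BCEseg and soft_dice_loss are defined elsewhere in the repo; a minimal sketch consistent with how they are called here (not necessarily the exact implementation) could be:
import torch
import torch.nn as nn

BCEseg = nn.BCELoss()   # expects probabilities, hence the sigmoid above

def soft_dice_loss(pred, target, smooth=1e-5, activation='none'):
    """Soft Dice loss over BxCxHxW probability maps."""
    if activation == 'sigmoid':
        pred = torch.sigmoid(pred)
    intersection = (pred * target).sum(dim=(2, 3))
    union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    dice = (2 * intersection + smooth) / (union + smooth)
    return 1 - dice.mean()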
Weight saving
The optimizer defaults to AdamW and the scheduler to CosineAnnealingLR; both can be changed as you like. Only the weights with the smallest loss so far are saved, and only the decoder part of the weights is saved; modify this as needed.
if epoch_loss < best_loss:
    best_loss = epoch_loss
    mask_decoder_weighs = sam.mask_decoder.state_dict()
    mask_decoder_weighs = {f"mask_decoder.{k}": v for k, v in mask_decoder_weighs.items()}
    torch.save(mask_decoder_weighs,
               os.path.join(save_dir, f'sam_decoder_fintune_{str(epoch+1)}_pointbox_monai.pth'))
    print("Saving weights, epoch: ", epoch + 1)
Full series complete, thanks for reading...