새소식

Project

모바일 로봇 프로젝트 2 - SAM2

  • -

https://github.com/khw11044/SAM2_streaming

 

GitHub - khw11044/SAM2_streaming

Contribute to khw11044/SAM2_streaming development by creating an account on GitHub.

github.com

 

위 깃헙주소를 그대로 따라하면 sam2를 통해 바운딩 박스를 마우스로 지정하고 segmentation tracking을 진행할 수 있다. 

 

이번에는 바운딩 박스를 마우스로 지정하는게 아니라 yolo object detection을 이용해서 바운딩 박스를 자동으로 지정하고 사람을 segmentation & tacking 하자 

 

먼저 위 깃헙 rep에서 demo가 실행 될 수 있는 최소한의 폴더와 파일 구성만 남겨두고 sam2가 실행 될 수 있는 가상환경을 잘 준비해 준다. sam2 폴다랑 sam2_configs 폴더 및 모델파일들 정도는 있어야겠다. 

 

그리고 아래 코드를 실행하면 

 

import torch
import numpy as np
import cv2
import requests
from ultralytics import YOLO
from sam2.build_sam import build_sam2_camera_predictor

# 모바일로봇 서버의 스트리밍 URL
url = "http://xx.xx.0.xx:5000/video_feed"  # Flask 서버의 /video_feed URL

# 스트리밍 연결
stream = requests.get(url, stream=True)
if stream.status_code != 200:
    print(f"Streaming 연결 실패: 상태 코드 {stream.status_code}")
    exit()

# YOLO 모델 로드
yolo_model = YOLO("./models/yolo11n.pt")  # YOLO 모델 경로

# SAM2 모델 로드
sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
model_cfg = "sam2_hiera_s.yaml"
predictor = build_sam2_camera_predictor(model_cfg, sam2_checkpoint)

# PyTorch 설정
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


# 초기화 변수
if_init = False
largest_bbox = None  # 가장 큰 바운딩 박스를 저장
enter_pressed = False  # SAM2 초기화 여부
byte_data = b""  # 스트리밍 데이터를 저장할 바이트 버퍼

frame_num = 0
for chunk in stream.iter_content(chunk_size=1024):  # 1KB 단위로 데이터 읽기
    byte_data += chunk
    a = byte_data.find(b'\xff\xd8')  # JPEG 시작 부분
    b = byte_data.find(b'\xff\xd9')  # JPEG 끝 부분
    if a != -1 and b != -1:  # JPEG 이미지의 시작과 끝이 존재할 때
        jpg = byte_data[a:b+2]  # JPEG 이미지 추출
        byte_data = byte_data[b+2:]  # 읽은 데이터 버퍼에서 제거

        # JPEG 데이터를 OpenCV 이미지로 디코딩
        frame = cv2.imdecode(np.frombuffer(jpg, dtype=np.uint8), cv2.IMREAD_COLOR)
        height, width = frame.shape[:2]

        if largest_bbox is None:
            # YOLO로 사람 탐지
            results = yolo_model.track(source=frame, conf=0.5, classes=[0], stream=True, show=False, verbose=False)

            largest_area = 0
            for result in results:
                for box in result.boxes:
                    x1, y1, x2, y2 = map(int, box.xyxy[0])  # 바운딩 박스 좌표
                    area = (x2 - x1) * (y2 - y1)  # 바운딩 박스 면적 계산
                    if area > largest_area:  # 가장 큰 영역 선택
                        largest_area = area
                        largest_bbox = (x1, y1, x2, y2)

            # 가장 큰 사람 바운딩 박스를 그리기
            if largest_bbox:
                cv2.rectangle(frame, (largest_bbox[0], largest_bbox[1]), 
                            (largest_bbox[2], largest_bbox[3]), (0, 255, 0), 2)

        # SAM2를 사용하여 세그먼테이션 및 트래킹
        if largest_bbox and not if_init:
            if_init = True
            predictor.load_first_frame(frame)
            ann_frame_idx = 0
            ann_obj_id = 1
            bbox = np.array([[largest_bbox[0], largest_bbox[1]], 
                             [largest_bbox[2], largest_bbox[3]]], dtype=np.float32)
            _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
                frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox
            )
            enter_pressed = True
        elif enter_pressed:
            # 트래킹 단계
            out_obj_ids, out_mask_logits = predictor.track(frame)
            all_mask = np.zeros((height, width, 1), dtype=np.uint8)

            for i in range(0, len(out_obj_ids)):
                out_mask = (out_mask_logits[i] > 0.0).permute(1, 2, 0).cpu().numpy().astype(np.uint8) * 255
                all_mask = cv2.bitwise_or(all_mask, out_mask)

            all_mask = cv2.cvtColor(all_mask, cv2.COLOR_GRAY2RGB)
            frame = cv2.addWeighted(frame, 1, all_mask, 0.5, 0)

        # OpenCV로 이미지 표시
        cv2.imshow("Camera", frame)
        cv2.imwrite('./01save/{:03d}.jpg'.format(frame_num), frame)
        frame_num += 1
        # 'q'를 누르면 종료
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

# 리소스 해제
cv2.destroyAllWindows()

 

라즈베리파이 또는 로봇에서 보낸 영상에서 

 

가장 중점이 되는 사람을 yolo로 지정하고 segmentation을 수행한다. 

 

다른사람은 안하고 지정된 사람만 쫒아 다닌다. 

 

아래는 그 결과이다. 

 

 

 

 

https://youtu.be/P7yQVQLEvog?si=FGM4Wq1NjqaO7wch

 

바운딩 박스도 그리고 싶다면 

 

import torch
import numpy as np
import cv2
import requests
from ultralytics import YOLO
from sam2.build_sam import build_sam2_camera_predictor

# 모바일로봇 서버의 스트리밍 URL
url = "http://xxx.xxx.0.xx:5000/video_feed"  # Flask 서버의 /video_feed URL

# 스트리밍 연결
stream = requests.get(url, stream=True)
if stream.status_code != 200:
    print(f"Streaming 연결 실패: 상태 코드 {stream.status_code}")
    exit()

# YOLO 모델 로드
yolo_model = YOLO("./models/yolo11n.pt")  # YOLO 모델 경로

# SAM2 모델 로드
sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
model_cfg = "sam2_hiera_s.yaml"
predictor = build_sam2_camera_predictor(model_cfg, sam2_checkpoint)

# PyTorch 설정
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


# 초기화 변수
if_init = False
largest_bbox = None  # 가장 큰 바운딩 박스를 저장
enter_pressed = False  # SAM2 초기화 여부
byte_data = b""  # 스트리밍 데이터를 저장할 바이트 버퍼

frame_num = 0
for chunk in stream.iter_content(chunk_size=1024):  # 1KB 단위로 데이터 읽기
    byte_data += chunk
    a = byte_data.find(b'\xff\xd8')  # JPEG 시작 부분
    b = byte_data.find(b'\xff\xd9')  # JPEG 끝 부분
    if a != -1 and b != -1:  # JPEG 이미지의 시작과 끝이 존재할 때
        jpg = byte_data[a:b+2]  # JPEG 이미지 추출
        byte_data = byte_data[b+2:]  # 읽은 데이터 버퍼에서 제거

        # JPEG 데이터를 OpenCV 이미지로 디코딩
        frame = cv2.imdecode(np.frombuffer(jpg, dtype=np.uint8), cv2.IMREAD_COLOR)
        height, width = frame.shape[:2]

        if largest_bbox is None:
            # YOLO로 사람 탐지
            results = yolo_model.track(source=frame, conf=0.5, classes=[0], stream=True, show=False, verbose=False)

            largest_area = 0
            for result in results:
                for box in result.boxes:
                    x1, y1, x2, y2 = map(int, box.xyxy[0])  # 바운딩 박스 좌표
                    area = (x2 - x1) * (y2 - y1)  # 바운딩 박스 면적 계산
                    if area > largest_area:  # 가장 큰 영역 선택
                        largest_area = area
                        largest_bbox = (x1, y1, x2, y2)

            # 가장 큰 사람 바운딩 박스를 그리기
            if largest_bbox:
                cv2.rectangle(frame, (largest_bbox[0], largest_bbox[1]), 
                            (largest_bbox[2], largest_bbox[3]), (0, 255, 0), 2)

        # SAM2를 사용하여 세그먼테이션 및 트래킹
        if largest_bbox and not if_init:
            if_init = True
            predictor.load_first_frame(frame)
            ann_frame_idx = 0
            ann_obj_id = 1
            bbox = np.array([[largest_bbox[0], largest_bbox[1]], 
                             [largest_bbox[2], largest_bbox[3]]], dtype=np.float32)
            _, out_obj_ids, out_mask_logits = predictor.add_new_prompt(
                frame_idx=ann_frame_idx, obj_id=ann_obj_id, bbox=bbox
            )
            enter_pressed = True
        elif enter_pressed:
            # 트래킹 단계
            out_obj_ids, out_mask_logits = predictor.track(frame)
            all_mask = np.zeros((height, width, 1), dtype=np.uint8)

            for i in range(0, len(out_obj_ids)):
                out_mask = (out_mask_logits[i] > 0.0).permute(1, 2, 0).cpu().numpy().astype(np.uint8) * 255
                all_mask = cv2.bitwise_or(all_mask, out_mask)

            
            # 바운딩 박스 계산
            contours, _ = cv2.findContours(all_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            for contour in contours:
                x, y, w, h = cv2.boundingRect(contour)  # 바운딩 박스 좌표 계산
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)  # 바운딩 박스 그리기
            
            
            all_mask = cv2.cvtColor(all_mask, cv2.COLOR_GRAY2RGB)
            frame = cv2.addWeighted(frame, 1, all_mask, 0.5, 0)

        # OpenCV로 이미지 표시
        cv2.imshow("Camera", frame)
        cv2.imwrite('./01save/{:03d}.jpg'.format(frame_num), frame)
        frame_num += 1
        # 'q'를 누르면 종료
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

# 리소스 해제
cv2.destroyAllWindows()

 

 

Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.