Home > Study > Linux > Day7 : HARIBO_Mini_Project

Day7 : HARIBO_Mini_Project
Study Language

πŸš€ Model Improvement Strategy


κΈ°μ‘΄ CNN λͺ¨λΈμ— λ‹€μŒ μ „λž΅μ„ ν†΅ν•©ν•˜μ—¬ μ„±λŠ₯을 ν–₯μƒμ‹œν‚¨ ꡬ쑰λ₯Ό κ΅¬ν˜„ν•¨:

  1. 데이터 증강 적용: λ‹€μ–‘ν•œ 이미지 λ³€ν˜•μ„ 톡해 ν•™μŠ΅ 데이터 λ‹€μ–‘μ„± 확보
  2. μ „μ΄ν•™μŠ΅ λ„μž…: MobileNetV2의 μ‚¬μ „ν•™μŠ΅λœ νŠΉμ§• μΆ”μΆœκΈ° μ‚¬μš©
  3. Dropout 및 Dense λ ˆμ΄μ–΄ μΆ”κ°€: μ˜€λ²„ν”ΌνŒ… λ°©μ§€ 및 λͺ¨λΈ ν‘œν˜„λ ₯ ν–₯상
  4. EarlyStopping, ModelCheckpoint 적용: 과적합 λ°©μ§€ 및 졜적 λͺ¨λΈ μ €μž₯
  5. 데이터 증강 κ°•ν™”: νšŒμ „, 이동, ν™•λŒ€/μΆ•μ†Œ, λ°˜μ „ λ“± 볡합적 증강 적용

🍬 HARIBO_Dataset Preparation

  1. 5κ°€μ§€ ν•˜λ¦¬λ³΄ 저리 μ’…λ₯˜(bear, cola, egg, heart, ring)λ₯Ό 직접 μ΄¬μ˜ν•˜μ—¬ 이미지 데이터셋 생성
  2. λ‹€μ–‘ν•œ 각도·쑰λͺ…Β·λ°°κ²½μ—μ„œ μˆ˜μ§‘λœ 이미지 총 500μž₯ (각 ν΄λž˜μŠ€λ‹Ή 100μž₯ λ‚΄μ™Έ)
  3. ꡬ글 λ“œλΌμ΄λΈŒμ— μ—…λ‘œλ“œ ν›„, Google Colab ν™˜κ²½μ—μ„œ μ‹€μŠ΅μš©μœΌλ‘œ 연동

alt text


πŸ’‘ Code : CNN with Transfer Learning & Augmentation

import matplotlib.pyplot as plt
import numpy as np
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras import models, layers
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# βœ… ꡬ글 λ“œλΌμ΄λΈŒ 마운트
from google.colab import drive
drive.mount('/content/drive')

# βœ… 경둜 μ„€μ •
dataset_path = '/content/drive/MyDrive/haribo_dataset'
model_save_path = '/content/drive/MyDrive/haribo_model.h5'

# βœ… 데이터 증강 μ„€μ •
datagen = ImageDataGenerator(
    rescale=1./255,              # ν”½μ…€ 값을 0~1 λ²”μœ„λ‘œ μ •κ·œν™”
    validation_split=0.2,        # 전체 데이터 쀑 20%λ₯Ό κ²€μ¦μš©μœΌλ‘œ μ‚¬μš©
    rotation_range=90,           # μ΅œλŒ€ Β±90도 λ²”μœ„ λ‚΄μ—μ„œ λ¬΄μž‘μœ„ νšŒμ „
    width_shift_range=0.1,       # 전체 λ„ˆλΉ„μ˜ 10%만큼 쒌우 이동
    height_shift_range=0.1,      # 전체 λ†’μ΄μ˜ 10%만큼 μƒν•˜ 이동
    shear_range=0.1,             # 전단 λ³€ν™˜ (이미지λ₯Ό κΈ°μšΈμ΄λŠ” 효과)
    zoom_range=0.1,              # 10% λ²”μœ„ λ‚΄ λ¬΄μž‘μœ„ ν™•λŒ€/μΆ•μ†Œ
    horizontal_flip=True,        # 이미지λ₯Ό 쒌우둜 λ¬΄μž‘μœ„ λ°˜μ „
    fill_mode='nearest'          # λ³€ν™˜ ν›„ 생긴 빈 μ˜μ—­μ„ κ°€μž₯ κ°€κΉŒμš΄ ν”½μ…€λ‘œ 채움
)

# βœ… 데이터 λ‘œλ”©
train_generator = datagen.flow_from_directory(
    dataset_path,
    target_size=(96, 96),
    batch_size=32,
    class_mode='categorical',
    subset='training',
    shuffle=True
)

val_generator = datagen.flow_from_directory(
    dataset_path,
    target_size=(96, 96),
    batch_size=32,
    class_mode='categorical',
    subset='validation',
    shuffle=True
)

# βœ… 클래슀 이름 μžλ™ μΆ”μΆœ
class_names = list(train_generator.class_indices.keys())
print("클래슀 인덱슀:", train_generator.class_indices)

# βœ… MobileNetV2 기반 λͺ¨λΈ ꡬ성
base_model = MobileNetV2(input_shape=(96, 96, 3), include_top=False, weights='imagenet')
base_model.trainable = False

model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(len(class_names), activation='softmax')  # 클래슀 수 μžλ™ 반영
])

model.compile(optimizer=Adam(learning_rate=1e-4),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# βœ… 콜백 μ„€μ •
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
checkpoint = ModelCheckpoint('best_model.h5', save_best_only=True)

# βœ… ν•™μŠ΅ μ‹€ν–‰
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=50,
    callbacks=[early_stop, checkpoint],
    verbose=2
)

# βœ… κ²°κ³Ό μ‹œκ°ν™”
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.show()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

# βœ… ν•™μŠ΅ 이미지 μ˜ˆμ‹œ
x_batch, y_batch = next(train_generator)
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([]); plt.yticks([]); plt.grid(False)
    plt.imshow(x_batch[i])
    label_idx = np.argmax(y_batch[i])
    plt.xlabel(class_names[label_idx])
plt.tight_layout()
plt.show()

# βœ… λͺ¨λΈ μ €μž₯ (.h5 파일)
model.save(model_save_path)
print(f"λͺ¨λΈμ΄ μ €μž₯λ˜μ—ˆμŠ΅λ‹ˆλ‹€: {model_save_path}")

βœ… Result : ν•™μŠ΅ κ²°κ³Ό μ‹œκ°ν™” 및 예츑 확인

alt text
alt text
alt text
alt text


πŸ” Summary

  • MobileNetV2λ₯Ό 기반으둜 ν•œ μ „μ΄ν•™μŠ΅ λͺ¨λΈμ΄ 적은 λ°μ΄ν„°μ…‹μ—μ„œλ„ 쒋은 μ„±λŠ₯을 λ³΄μž„
  • μ‹€μ‹œκ°„ 예츑 ν™˜κ²½μ—λ„ μ΅œμ ν™”λœ λͺ¨λΈ ꡬ쑰둜 μ „ν™˜ κ°€λŠ₯ (On-Device AI 적용 κ°€λŠ₯)

πŸ’» Real-Time Inference Setup on Terminal


πŸ“ 1. 디렉토리 ꡬ성

mkdir haribo_cam_classifier
cd haribo_cam_classifier

🐍 2. κ°€μƒν™˜κ²½ 생성 및 νŒ¨ν‚€μ§€ μ„€μΉ˜

python3 -m venv venv
source venv/bin/activate

pip install tensorflow opencv-python-headless numpy

πŸ“₯ 3. ν•™μŠ΅ν•œ λͺ¨λΈ(.h5)을 Google Driveμ—μ„œ λ‹€μš΄λ‘œλ“œν•˜μ—¬ 볡사

haribo_model.h5 νŒŒμΌμ„ Google Driveμ—μ„œ λ‹€μš΄λ°›μ•„ haribo_cam_classifier 디렉토리에 μœ„μΉ˜μ‹œν‚΄

alt text


πŸ–ΌοΈ 4. 클래슀 이름 파일 생성 (class_names.json)

["bear", "cola", "egg", "heart", "ring"]

πŸ’‘ 5. μ‹€μ‹œκ°„ λΆ„λ₯˜ μ½”λ“œ μž‘μ„± (predict_cam.py)

import cv2
import numpy as np
import tensorflow as tf
import json

# λͺ¨λΈκ³Ό 클래슀 이름 λ‘œλ“œ
model = tf.keras.models.load_model('haribo_model.h5')

with open('class_names.json', 'r') as f:
    class_names = json.load(f)

def preprocess(frame):
    img = cv2.resize(frame, (96, 96))
    img = img.astype('float32') / 255.0
    return np.expand_dims(img, axis=0)

cap = cv2.VideoCapture(2)
if not cap.isOpened():
    print("카메라λ₯Ό μ—΄ 수 μ—†μŠ΅λ‹ˆλ‹€.")
    exit()

print("저리 λΆ„λ₯˜ μ‹œμž‘! (Q ν‚€λ₯Ό λˆ„λ₯΄λ©΄ μ’…λ£Œ)")
while True:
    ret, frame = cap.read()
    if not ret:
        break

    input_img = preprocess(frame)
    pred = model.predict(input_img)
    label = class_names[np.argmax(pred)]

    # 예츑 κ²°κ³Ό 화면에 좜λ ₯
    cv2.putText(frame, f'Prediction: {label}', (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    cv2.imshow('Haribo Classifier', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

🧩 6. OpenCV μ„€μΉ˜ (GUI 지원 포함)

pip install opencv-python

▢️ 7. μ‹€μ‹œκ°„ 예츑 μ‹€ν–‰

python3 predict_cam.py

βœ… 8. κ²°κ³Ό 정리

🍬 λͺ¨λΈ μ˜ˆμΈ‘μ„ μœ„ν•œ 5개 클래슀 ν•˜λ¦¬λ³΄ μƒ˜ν”Œ 전체 이미지
alt text

πŸ§ͺ 예츑 μ˜ˆμ‹œ: heart
alt text
alt text

πŸ§ͺ 예츑 μ˜ˆμ‹œ: ring
alt text
alt text

πŸ§ͺ 예츑 μ˜ˆμ‹œ: cola
alt text
alt text

πŸ§ͺ 예츑 μ˜ˆμ‹œ: egg
alt text
alt text

πŸ§ͺ 예츑 μ˜ˆμ‹œ: bear
alt text
alt text