atma-inc__blog

atma株式会社の公式ブログです。

DRFとNuxtを使って画像分類(機械学習)をする①

はじめに

インターンしている小林です.この記事では,DRFDjangoのいい感じのフレームワーク)を使って,APIを作るまで行います.記事は二編構成とし,一編はDRFによるAPI作成,二編はNuxtを用いてユーザが実際に入力することを想定してフロント作成します.具体的には,PyTorchのresnetを用いて,入力フォームから受け付けられた画像を推論して上位10位までの結果を表示させます.一編では,詳細な機械学習アルゴリズムは説明せずに,APIを作る工程に重きを向けます.読者の対象はDRFを初めたての人が対象であり,機械学習の画像処理をある程度把握している人が対象となります.

構築したAPIは以下のような感じになります.

f:id:atma_inc:20200327163520g:plain
今回作るAPI画面

結果で返しているのはresnet-18に入力した画像を推論させ,確率値が高い上位10個を表示させています.用いてるモデルはImageNetの学習済みモデルです.

DRFについて

Django」は Python で Webアプリケーションを作成するためフレームワークですが、「Django REST Framework」という Django のためのパッケージを使うことで、RESTful な API バックエンドを簡単に構築することができます。実際の現場では、SPA(シングルページアプリケーション)やスマホアプリのバックエンドとしてよく利用されています(引用:現場で使える Django REST Framework の教科書 (Django の教科書シリーズ))。とのことですが,基本的にはDjangoで足りない所を補ってやりたいというのが,これを使っている理由です.ただし,色々な機能があるため少し重たいファイルであることはデメリットですが,それを超える良い機能が複数あるので慣れると使いやすいものかと思われます.

目次

  1. 環境構築
  2. DRFのモジュール作成
  3. resnetの作成
  4. APIを使って推論
  5. まとめ
  6. 次回

環境構築

環境構築は以下を基本として構築しました.

https://qiita.com/michio-k/items/371881a6b8ecfa768606

ファイル構成は以下のようになります.

home
|- backend
| |- core(Djangoのプロジェクトが入る)
| |- app(APIを作成)
| |- Dockerfile
| |- requirements.txt
|- front
| |- nuxt (フロントのプロジェクト)
| |- Dockerfile
|- .gitignore
|- docker-compose.yml
|- README.md

今回操作するのは上記のbackendの方となります.二編目でfrontの方をいじっていきます.画像のAPIを作成する上で,いくつかインストールする必要があるモジュールがあるので,記していきます.

FROM python:3.7
ENV PYTHONUNBUFFERED 1
RUN mkdir /code
WORKDIR /code
RUN apt-get update && apt-get install -y \
    libblas-dev \
    liblapack-dev\
    libatlas-base-dev \
    libsm6 \
    libxext6 \
    libxrender-dev

ADD requirements.txt /code/
RUN pip install --upgrade pip
RUN pip install --no-cache-dir -r requirements.txt

インストールするpythonモジュールです.ここでは,django-jsonfieldを使って辞書のデータを受け付けるようにします.実務的になると,PostgresやMySQLを使う方がいいと思われるので,そちらをデータベースとして参照する方が望ましいです.今回は簡易的なものなので,これを使わず,sqlite(デフォルトの設定)でやっていきます.これらのことは後に後述します.

* 補足として,DjangoでPostgresやMySQLなどでデータベースを使用したいときは,以下のサイトを参考にしてください. qiita.com

*

画像分類の処理では今回はPyTorchを使っていきます.pillowはDRFDjango)が画像を読み込む時に必要となるのでインストールしておきます.

# backend/requirements.txt
#Django
django
djangorestframework
django-filter
django-cors-headers
django-jsonfield
#Extra
numpy
pillow
opencv-python
torch
torchvision

次にdocker-compose.ymlを記述します.以下の通りになります.

version: '3'

services:

  #front:
    #container_name: front
    #build: ./front
    #tty: true
    #ports:
     # - '3000:3000'
    #volumes:
      #- ./front/:/usr/src/app
    # command: [sh, -c, "cd nuxt/ && npm run dev"]

  backend:
    container_name: backend
    build: ./backend
    tty: true
    ports:
      - '8000:8000'
    volumes:
      - ./backend:/code
    # command: python manage.py runserver 0.0.0.0:8000

これにて環境構築が終わりです.補足ですが,Dockerを使わなくてもrequirements.txtに記述してあるモジュールをインストールしている環境であるならば,これからやることはできます.また,下のコードで先程構築した環境に入ることができます.

docker-compose build #Dockerfileの環境を立ち上げる
docker-compose up -d #起動
docker exec -it backend bash #containerの中に入る

DRFのモジュール作成

DRFのプロジェクトとアプリの作成

これからDjangoのプロジェクトを作成していきます.この方法は通常のDjangoのやり方と変わりません.

django-admin startproject core .
python manage.py startapp app

にてファイルを作成します.次にcore内のsettings.pyに今回作ったファイルとDRFを読み込ませます.また,各種必要なものを記述しておきます.MEDIAは画像の保存先を指定するために必要となりますので追記してください.

#python:core/settings.py
ALLOWED_HOSTS = ["localhost"]

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    "rest_framework", #add
    "app", #add
]

#add
MEDIA_URL = "/media/"
MEDIA_ROOT = os.path.join(BASE_DIR,"media")

resnetのファイル

この記事での推論はresnetを使用します.また,簡易的なものであるため,自前の学習済みモデルを使用せずネット上に公開されているFinetuned-modelを利用します(ImageNetです).そのため,以下にしているファイルを事前にダウンロードしてください.

以上で初期に用意するファイル一式の準備はできました.上記のファイルとresnet用に用意するファイルは以下のように作成してください.

backend
├── Dockerfile
├── app
│   ├── __init__.py
│   ├── admin.py
│   ├── apps.py
│   ├── models.py
│   ├── resnet #add
│   │   ├── config
│   │   │   ├── imagenet_class_index.json
│   │   │   └── resnet18-5c106cde.pth
│   │   ├── model.py
│   │   └── predict.py
│   ├── tests.py
│   └── views.py
├── core
│   ├── __init__.py
│   ├── asgi.py
│   ├── settings.py
│   ├── urls.py
│   └── wsgi.py
├── manage.py
└── requirements.txt

上記のようなファイル構成になっていると大丈夫です!

modelsの作成

今回APIとして必要になるのは以下の通りです.

  • 入力:入力した画像の名前と入力する画像
  • 出力:確率値が高い上位10までのラベル一覧と確率値

となります.そのため受け付けるフィールドは三つとなります.

#python:app/models.py
from django.db import models
import jsonfield
# from django.contrib.postgres.fields import JSONField  Jsonを受け付ける

class ImageModel(models.Model):
    name = models.CharField(max_length = 128,null=True,default="unknown")
    image = models.ImageField(upload_to="media")
    predict = jsonfield.JSONField()

ここではjsonfieldというモジュールを使って,出力するための値を受け付けます.出力する値はjson形式にしたいのですが,Django内で提供されているJSONFieldはデータベースがpostgresやMySQLなどに対応しており,設定を変更しなければ使用できません.これはDjangoのデフォルトのデータベースがsqliteであり,対応していないためエラーが起こります.この問題を解決するために,今回は設定を省略し,jsonfieldというもので簡単にsqliteが受け付けられるようにしました.

また,今回作成したデータベースを登録するためにadminの内容を変更します.以下のように記述してください.

#app/admin.py
from django.contrib import admin
from .models import ImageModel

@admin.register(ImageModel)
class ImageModel(admin.ModelAdmin):
    pass

serializerの作成

続いてserializerの作成です.serializerはDRF特有のものであり,通常のDjangoにはありません.詳細は記事は以下のものが参考になるかと思いますので,乗せて起きます.

やっていることは,入力されたデータの値がModelの中身で定義した型と一緒なのか?ということをやったり,Json形式で入力されたものをPythonで読み込めるようにしたりとそんなことをやっています.

appの直下にserializers.pyのファイルを作成し,以下のように記述します.

#app/serializers.py
from rest_framework import serializers
from .models import ImageModel

class ImageSerializer(serializers.ModelSerializer):
    class Meta:
        model = ImageModel
        fields = ("id","name","image","predict")
        read_only_fields = ('predict',"id")

上記のように今回は書きました.入力としてはnameとimageのみなので,入力に必要ないものは外しています.また,上記のMetaに関する情報はhttps://teratail.com/questions/87695 が参考になるかと思いますので適宜参考にしてみてください.

viewsの作成

今回はDRFのビューはクラスベースビューを用いて,ModelViewSetを使って見ました.また,postを受け付けるコードをactionで対応するようにしました(これはdef post()メソッドを用いてもらっても大丈夫です).

#app/views.py
from rest_framework import viewsets
from rest_framework.response import Response
from rest_framework.decorators import action
from rest_framework import status
from .models import ImageModel
from .serializers import ImageSerializer
from .resnet.predict import predict #resnetの予測

class ImageViewSet(viewsets.ModelViewSet):
    queryset = ImageModel.objects.all() 
    serializer_class = ImageSerializer #check ユーザのどんなクエリを受け付けるか

    @action(detail=False,methods=["post"])
    def classification(self,request):
        serializer = self.serializer_class(data = request.data)
        serializer.is_valid(raise_exception=True)
        img = request.data["image"]
        name = request.data["name"]
        res = predict(img)
        # 保存
        item = ImageModel(name=name, image=img,predict = res)
        item.save()
        return Response(res, status=status.HTTP_200_OK)
  

serializer_classで受け付ける入力フォームを決めています.

@action(detail=False,methods=["post"])
def classification(self,request):
    serializer = self.serializer_class(data = request.data)
    serializer.is_valid(raise_exception=True)
    img = request.data["image"]
    name = request.data["name"]
    res = predict(img)
    # 保存
    item = ImageModel(name=name, image=img,predict = res)
    item.save()
    return Response(res, status=status.HTTP_200_OK)
    

serializer = self.serializer_class(data = request.data)は入力されたデータが正しいかどうかを検証するためにいれています.極端な話,PDFのファイルが入力された時,エラーを出力してくれます.serializer.is_valid(raise_exception=True)を記述するとこの時点でエラーのデータがあるとエラーの文章で値が返されます.また,記述方法として,

if serializer.is_valid():
    img = request.data["image"]
    name = request.data["name"]
    return Response(serializer.data,status=status.HTTP_200_OK)
else:
    return Response(serializer.errors,status=status.HTTP_400_BAD_REQUEST)

があります.これは,if文章でTrueかFalseを処理して中身を実行するかどうかを判断しています.しかし,この書き方は若干情緒でもあるのでこれを省略しserializer.is_valid(raise_exception=True)だけを記述することで,上記のことと全く同じようにしてくれます.

ifの中身は,predict(img)でdictの結果を返してresで受け付けています.このresは先ほどの確率値の上位10個が入っている値の一覧が格納されています.これらのデータをItemModel()に入れ,保存しています.

urlの繋ぎこみ

繋ぎこみをします.app以下にurls.pyのファイルを作成して以下のように記述してください.

#app/urls.py
from rest_framework import routers
from .views import ImageViewSet

router = routers.DefaultRouter()
router.register(r"^image",ImageViewSet)
#core/urls.py
from django.contrib import admin
from django.urls import path,include
from django.conf import settings
from django.conf.urls.static import static
from app.urls import router as router
urlpatterns = [
    path('admin/', admin.site.urls),
    path("",include(router.urls))
]
if settings.DEBUG:
    urlpatterns += static(settings.MEDIA_URL,document_root = settings.MEDIA_ROOT)

繋ぎこみの際にファイル名を注意してください.これでlocalhost:8000/image/classificationとURLを入れた時に,APIをPOSTできるようになります.

これでDRFで記述すべきことは終わりました.次に,resnetの方を記述していきます.

resnetの作成

resnetのディレクトリの中のmodel.pyを作成します.以下のように記述してください.

#app/resnet/model.py
import torch
from torchvision import models

def resnet_model():
    MODEL_PATH = "./app/resnet/config/resnet18-5c106cde.pth"
    model = models.resnet18(pretrained=False)
    model.load_state_dict(torch.load(MODEL_PATH))
    model.eval()
    return model

今回は簡易的に作っているため,MODEL_PATHをこんな風にPathを書くことはナンセンスだと思うので注意してください(笑).また,model.eval()を忘れないでください(これを書くの忘れて何時間も悩んだのは裏の話).

次にpredict.pyを作成します.以下のように記述してください.

#app/resnet/predict.py
from PIL import Image
import json
import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from .model import resnet_model

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
softmax = torch.nn.Softmax(dim=1)
model = resnet_model()
def predict(img):
    # json
    with open("./app/resnet/config/imagenet_class_index.json", 'r') as f:
        image_dict = json.load(f)
    # img
    img = Image.open(img)
    img = img.convert('RGB')
    img = np.array(img)
    img = preprocess(transforms.ToPILImage()(img)).unsqueeze(0)

    # prediction
    predict = model(img).data
    prob = softmax(predict)[0].tolist()
    best_ten = np.argsort(prob)[::-1][:10]

    response = []
    for i,rank in enumerate(best_ten,1):
        label = image_dict[str(rank)][1]
        response.append({"rank":i,"prob":prob[rank],"label":label})

    return response

入力された画像はImage.open(img)にて読み込みます.cv.imread(img)で読み込むことはできないので注意してください.また,

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

この部分はImageNetの学習方法と同じようにしているので,これがなければ出力で欲しいラベルが返って来なくなります pytorch.org

returnで返しているのは配列であり,中身は辞書型になっています.probの確率値はtolist()でリスト型にしています.これはデータベースにデータが保存されるときに,データの型がnumpyであるとエラーの原因になるためです(これはちょっとはまり所でした).そこで,tolist()で通常の数値型に変更し,エラーの原因を未然に防ぎます.参照したのは以下のページです.

stackoverflow.com

APIを使って推論

ここまででAPIは作れたので,実際の画面で確認していきます.

python manage.py makemigrations
python manage.py migrate

をしてください.これはDjangoの定型文みたいなものなので,そうなんだーみたいな感じでやってください(これはデータベースを作ってくれたりしている).また,Modelsの中身を変更したり,追加する場合は上記のコードをもう一度入力してください.すると,更新されます.

python manage.py runserver 0.0.0.0:8000

を実行しlocalhost:8000/imageのurlを検索すると以下のような画面が出てくると思います.

f:id:atma_inc:20200327171142p:plain
APIの画面

ここに保存されたデータの一覧が出力されるようになります.localhost:8000/image/classificationのurlを検索するとAPIをPOSTできる画面が出力されるのでやって見てください.

まとめ

今回は画像を入力として受け付けて,上位10個の確率値とラベルを出力するAPIを作成しました.モデルはresnetを使用し,ImageNetの学習済みモデルで値を出力させました.これらは適宜自分のモデルの差し替えが可能なので他のものでも試していきたいと思います.

次回

次はNuxtを利用して,front側を作成していきたいと思います.入力フォームを作成し,実際にユーザが画像を投稿するようなイメージで作成します.その画像をAPIに投げ,値が返ってくるところまでを実装し,画面に表示させます.