模型训练

本示例将在线训练一个基于MNIST数据集的分类模型, 帮助您快速了解格物钛数据平台

1. 准备数据

a. 进入公开数据集,搜索并fork MNIST 数据集到自己的空间中

2. 配置密钥

a. 点击导航栏中的开发者工具,新建accesskey并复制

b. 进入之前fork的MNIST数据集

c. 进入设置->自动化配置->新建密钥

d. 新建密钥,密钥名为accesskey, 密钥值为a步骤复制的accesskey

3. 创建训练评估工作流

a. 点击自动化->新建工作流

b. 使用下列yaml文件新建工作流

tasks:
  #  新建一个数据集用于保存训练完的模型
  createModelDataset:
    script:
      # 本task运行所依赖的镜像名, 支持公开的镜像仓库
      image: hub.graviti.cn/algorithm/mnist:0.7

      # 使用python3运行下面source后对应的脚本
      command: [python3]

      source: |
        import logging
        import os
        from tensorbay import GAS
        logging.basicConfig(level=logging.INFO)
        # Dataset that will hold the model file produced by the "training" task.
        dataset_name = "MNIST_MODEL"
        # Access key injected by the workflow secret configuration.
        # NOTE(review): may be None if the secret is missing — verify upstream setup.
        ACCESS_KEY = os.environ.get("secret.accesskey")
        gas = GAS(ACCESS_KEY)
        try:
            gas.create_dataset(dataset_name)
            logging.info(f"Created dataset {dataset_name} Successfully")
        except Exception:
            # Creating a dataset that already exists raises; treat that as success.
            # (A bare `except:` would also swallow KeyboardInterrupt/SystemExit.)
            logging.info(f"{dataset_name} already exists.")
  #  训练一个简单的mnist模型, 并使用输出模型文件进行预测
  training:
    dependencies:
      - createModelDataset
    script:
      image: hub.graviti.cn/algorithm/mnist:0.7
      command: [python3]
      source: |
        import logging
        import os

        import torch
        from PIL import Image
        from tensorbay import GAS
        from tensorbay.dataset import Dataset as TensorBayDataset
        from tensorbay.dataset.data import Data
        from torch import nn
        from torch.utils.data import DataLoader, Dataset
        from torchvision import transforms

        logging.basicConfig(level=logging.INFO)


        # Network architecture.
        class NeuralNetwork(nn.Module):
            """MLP digit classifier: flatten 28x28 input, then 784 -> 512 -> 512 -> 10."""

            def __init__(self):
                super(NeuralNetwork, self).__init__()
                self.flatten = nn.Flatten()
                # Attribute names are kept so state_dict keys stay load-compatible.
                layers = [
                    nn.Linear(28 * 28, 512),
                    nn.ReLU(),
                    nn.Linear(512, 512),
                    nn.ReLU(),
                    nn.Linear(512, 10),
                ]
                self.linear_relu_stack = nn.Sequential(*layers)

            def forward(self, x):
                flat = self.flatten(x)
                return self.linear_relu_stack(flat)


        # Read the dataset from the Graviti platform.
        class MNISTSegment(Dataset):
            """PyTorch ``Dataset`` adapter over one TensorBay MNIST segment."""

            def __init__(self, dataset, segment_name, transform):
                super().__init__()
                self.transform = transform
                self.dataset = dataset
                self.segment = dataset[segment_name]
                # Maps category names to integer class indices for training.
                self.category_to_index = dataset.catalog.classification.get_category_to_index()

            def __len__(self):
                return len(self.segment)

            def __getitem__(self, idx):
                sample = self.segment[idx]
                with sample.open() as stream:
                    tensor = self.transform(Image.open(stream))
                label_index = self.category_to_index[sample.label.classification.category]
                return tensor, label_index


        def train(dataloader, model, loss_fn, optimizer, device=None):
            """Run one training epoch over ``dataloader``.

            Args:
                dataloader: Yields (inputs, labels) mini-batches.
                model: Network being optimized; left in train mode.
                loss_fn: Criterion comparing predictions against labels.
                optimizer: Updates ``model``'s parameters in place.
                device: Target device string/torch.device. Auto-detects CUDA when
                    ``None`` (previously this silently read a module-level global).
            """
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            size = len(dataloader.dataset)
            model.train()
            for batch, (X, y) in enumerate(dataloader):
                X, y = X.to(device), y.to(device)

                # Compute prediction error
                pred = model(X)
                loss = loss_fn(pred, y)

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Log progress every 100 batches.
                if batch % 100 == 0:
                    loss, current = loss.item(), batch * len(X)
                    logging.info(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


        def test(dataloader, model, loss_fn, device=None):
            """Evaluate ``model`` on ``dataloader`` and log accuracy / average loss.

            Args:
                dataloader: Yields (inputs, labels) mini-batches.
                model: Network to evaluate; switched to eval mode.
                loss_fn: Criterion used to accumulate the test loss.
                device: Target device string/torch.device. Auto-detects CUDA when
                    ``None`` (previously this silently read a module-level global).
            """
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            size = len(dataloader.dataset)
            num_batches = len(dataloader)
            model.eval()
            test_loss, correct = 0, 0
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                for X, y in dataloader:
                    X, y = X.to(device), y.to(device)
                    pred = model(X)
                    test_loss += loss_fn(pred, y).item()
                    correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            test_loss /= num_batches
            correct /= size
            logging.info(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

        if __name__ == "__main__":
            BATCH_SIZE = 64  # fixed typo: was BTACH_SIZE
            EPOCHS = 3
            # Access key injected by the workflow secret configuration.
            ACCESS_KEY = os.environ.get("secret.accesskey")
            gas = GAS(ACCESS_KEY)
            mnist_dataset = TensorBayDataset("MNIST", gas)
            # Cache remote files locally so epochs after the first read from disk.
            mnist_dataset.enable_cache()
            to_tensor = transforms.ToTensor()
            # Single-channel normalization; values follow the ImageNet convention.
            normalization = transforms.Normalize(mean=[0.485], std=[0.229])
            my_transforms = transforms.Compose([to_tensor, normalization])

            train_segment = MNISTSegment(mnist_dataset, segment_name="train", transform=my_transforms)
            test_segment = MNISTSegment(mnist_dataset, segment_name="test", transform=my_transforms)
            train_dataloader = DataLoader(train_segment, batch_size=BATCH_SIZE, num_workers=10)
            test_dataloader = DataLoader(test_segment, batch_size=BATCH_SIZE, num_workers=10)

            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"Using {device} device")

            model = NeuralNetwork().to(device)
            logging.info(model)
            loss_fn = nn.CrossEntropyLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

            for epoch in range(EPOCHS):
                logging.info(f"Epoch {epoch+1}\n-------------------------------")
                train(train_dataloader, model, loss_fn, optimizer)
                test(test_dataloader, model, loss_fn)
            logging.info("Done!")

            torch.save(model.state_dict(), "model.pth")
            logging.info("Saved PyTorch Model State to model.pth")

            # Upload the trained weights to the MNIST_MODEL dataset created by
            # the previous workflow task.
            model_dataset = TensorBayDataset("MNIST_MODEL")
            segment = model_dataset.create_segment("model")
            segment.append(Data("./model.pth"))
            dataset_client = gas.upload_dataset(model_dataset)
            dataset_client.commit("upload mnist model file")
            logging.info("Uploaded model!")
  evaluate:
    dependencies:
      - training
    script:
      image: hub.graviti.cn/algorithm/mnist:0.7
      command: [python3]
      source: |
        import logging
        import os
        from concurrent.futures import ThreadPoolExecutor

        import torch
        from PIL import Image
        from tensorbay import GAS
        from tensorbay.dataset import Dataset as TensorBayDataset
        from tensorbay.dataset.data import Data
        from tensorbay.label import Classification
        from torch import nn
        from torchvision import transforms

        logging.basicConfig(level=logging.INFO)


        # Network definition — must match the architecture used at training time,
        # otherwise the saved state_dict will not load.
        class NeuralNetwork(nn.Module):
            """MLP digit classifier: flatten 28x28 input, then 784 -> 512 -> 512 -> 10."""

            def __init__(self):
                super(NeuralNetwork, self).__init__()
                self.flatten = nn.Flatten()
                self.linear_relu_stack = nn.Sequential(
                    nn.Linear(28 * 28, 512),
                    nn.ReLU(),
                    nn.Linear(512, 512),
                    nn.ReLU(),
                    nn.Linear(512, 10),
                )

            def forward(self, x):
                flattened = self.flatten(x)
                logits = self.linear_relu_stack(flattened)
                return logits



        def upload_label(model, segment_client, data):
            """Predict the digit for one sample and upload it as a classification label.

            Args:
                model: Trained classifier, called on a single (1, 28, 28) image tensor.
                segment_client: TensorBay segment client the label is written to.
                data: TensorBay data object providing the image bytes and remote path.

            NOTE(review): depends on the module-level ``my_transforms`` created in
            the ``__main__`` section — confirm it is defined before this is called.
            """
            with data.open() as fp:
                image_tensor = my_transforms(Image.open(fp))
            # Inference only: disable autograd so no gradient graph is built.
            with torch.no_grad():
                pred = model(image_tensor)
            pred_data = Data(data.path)
            pred_data.label.classification = Classification(str(int(pred[0].argmax(0))))
            segment_client.upload_label(pred_data)


        if __name__ == "__main__":
            # Access key injected by the workflow secret configuration.
            ACCESS_KEY = os.environ.get("secret.accesskey")
            gas = GAS(ACCESS_KEY)
            to_tensor = transforms.ToTensor()
            # Must match the normalization used at training time.
            normalization = transforms.Normalize(mean=[0.485], std=[0.229])
            my_transforms = transforms.Compose([to_tensor, normalization])

            # Download the model file saved by the "training" task.
            model_dataset = TensorBayDataset("MNIST_MODEL", gas)
            data = model_dataset[0][0]
            with data.open() as remote_fp, open("./model.pth", "wb") as fp:  # 本地存储数据的路径
                fp.write(remote_fp.read())
            model = NeuralNetwork()
            model.load_state_dict(torch.load("model.pth", map_location=torch.device("cpu")))
            logging.info(model)

            mnist_dataset = TensorBayDataset("MNIST", gas)
            mnist_dataset.enable_cache()
            mnist_dataset_client = gas.get_dataset("MNIST")
            # NOTE(review): create_branch fails if "training" already exists —
            # rerunning this task may require deleting the branch first. Verify.
            mnist_dataset_client.create_branch("training")
            mnist_dataset_client.create_draft("update label")

            # Upload predicted labels concurrently, one thread pool per segment.
            for segment in mnist_dataset:
                segment_client = mnist_dataset_client.get_segment(segment.name)
                with ThreadPoolExecutor(10) as executor:
                    for data in segment:
                        executor.submit(upload_label, model, segment_client, data)
            mnist_dataset_client.commit("update label")

4. 开始训练

a. 进入自动化,选择创建好的工作流,点击运行

5. 查看结果

a. 查看工作流运行日志

b. 进入名为training的分支查看模型预测效果,查看diff可视化及分布变化

c. 查看保存的训练模型

最后更新于