Model Training

This example trains a classification model on the MNIST dataset online, to help you quickly get familiar with the Graviti data platform.

1. Prepare the data

a. Go to Open Datasets, search for the MNIST dataset, and fork it into your own workspace.

2. Configure the secret

a. Click Developer Tools in the navigation bar, create a new AccessKey, and copy it.

b. Open the MNIST dataset you forked earlier.

c. Go to Settings -> Automation Configuration -> New Secret.

d. Create a new secret named accesskey, and set its value to the AccessKey copied in step a.
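
Before adding the secret, you can optionally verify the copied AccessKey with the TensorBay SDK. The following is a minimal sketch, assuming the SDK is installed via pip install tensorbay and that "<YOUR_ACCESSKEY>" is a placeholder for the key copied in step a:

# Minimal AccessKey sanity check; "<YOUR_ACCESSKEY>" is a placeholder.
from tensorbay import GAS

gas = GAS("<YOUR_ACCESSKEY>")
# Accessing the forked MNIST dataset raises an error if the key is invalid.
gas.get_dataset("MNIST")
print("AccessKey is valid")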

3. Create the training and evaluation workflow

a. Click Action -> New Workflow.

b. Create the workflow using the following YAML file.

tasks:
  #  Create a dataset to store the trained model
  createModelDataset:
    script:
      # Image this task runs on; public image registries are supported
      image: hub.graviti.cn/algorithm/mnist:0.7

      # Run the script under source below with python3
      command: [python3]

      source: |
        import logging
        import os
        from tensorbay import GAS
        logging.basicConfig(level=logging.INFO)
        dataset_name = "MNIST_MODEL"
        ACCESS_KEY = os.environ.get("secret.accesskey")
        gas = GAS(ACCESS_KEY)
        try:
            gas.create_dataset(dataset_name)
            logging.info(f"Created dataset {dataset_name} Successfully")
        except Exception:
            logging.info(f"{dataset_name} already exists.")
  #  Train a simple MNIST model and use the output model file for prediction
  training:
    dependencies:
      - createModelDataset
    script:
      image: hub.graviti.cn/algorithm/mnist:0.7
      command: [python3]
      source: |
        import logging
        import os

        import torch
        from PIL import Image
        from tensorbay import GAS
        from tensorbay.dataset import Dataset as TensorBayDataset
        from tensorbay.dataset.data import Data
        from torch import nn
        from torch.utils.data import DataLoader, Dataset
        from torchvision import transforms

        logging.basicConfig(level=logging.INFO)


        # Build the network architecture
        class NeuralNetwork(nn.Module):
            def __init__(self):
                super(NeuralNetwork, self).__init__()
                self.flatten = nn.Flatten()
                self.linear_relu_stack = nn.Sequential(
                    nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)
                )

            def forward(self, x):
                x = self.flatten(x)
                logits = self.linear_relu_stack(x)
                return logits


        # Read the dataset from the Graviti platform
        class MNISTSegment(Dataset):
            """class for wrapping a MNIST segment."""

            def __init__(self, dataset, segment_name, transform):
                super().__init__()
                self.dataset = dataset
                self.segment = self.dataset[segment_name]
                self.category_to_index = self.dataset.catalog.classification.get_category_to_index()
                self.transform = transform

            def __len__(self):
                return len(self.segment)

            def __getitem__(self, idx):
                data = self.segment[idx]
                with data.open() as fp:
                    image_tensor = self.transform(Image.open(fp))

                return image_tensor, self.category_to_index[data.label.classification.category]


        def train(dataloader, model, loss_fn, optimizer):
            size = len(dataloader.dataset)
            model.train()
            for batch, (X, y) in enumerate(dataloader):
                X, y = X.to(device), y.to(device)

                # Compute prediction error
                pred = model(X)
                loss = loss_fn(pred, y)

                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if batch % 100 == 0:
                    loss, current = loss.item(), batch * len(X)
                    logging.info(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


        def test(dataloader, model, loss_fn):
            size = len(dataloader.dataset)
            num_batches = len(dataloader)
            model.eval()
            test_loss, correct = 0, 0
            with torch.no_grad():
                for X, y in dataloader:
                    X, y = X.to(device), y.to(device)
                    pred = model(X)
                    test_loss += loss_fn(pred, y).item()
                    correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            test_loss /= num_batches
            correct /= size
            logging.info(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

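        # Main entry: read the forked MNIST dataset from TensorBay, train the model,
        # then save it and upload it to the MNIST_MODEL dataset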
        if __name__ == "__main__":
            BATCH_SIZE = 64
            EPOCHS = 3
            ACCESS_KEY = os.environ.get("secret.accesskey")
            gas = GAS(ACCESS_KEY)
            mnist_dataset = TensorBayDataset("MNIST", gas)
            mnist_dataset.enable_cache()
            to_tensor = transforms.ToTensor()
            normalization = transforms.Normalize(mean=[0.485], std=[0.229])
            my_transforms = transforms.Compose([to_tensor, normalization])

            train_segment = MNISTSegment(mnist_dataset, segment_name="train", transform=my_transforms)
            test_segment = MNISTSegment(mnist_dataset, segment_name="test", transform=my_transforms)
            train_dataloader = DataLoader(train_segment, batch_size=BATCH_SIZE, num_workers=10)
            test_dataloader = DataLoader(test_segment, batch_size=BATCH_SIZE, num_workers=10)

            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"Using {device} device")

            model = NeuralNetwork().to(device)
            logging.info(model)
            loss_fn = nn.CrossEntropyLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

            for epoch in range(EPOCHS):
                logging.info(f"Epoch {epoch+1}\n-------------------------------")
                train(train_dataloader, model, loss_fn, optimizer)
                test(test_dataloader, model, loss_fn)
            logging.info("Done!")

            torch.save(model.state_dict(), "model.pth")
            logging.info("Saved PyTorch Model State to model.pth")

            # Upload the model file
            model_dataset = TensorBayDataset("MNIST_MODEL")
            segment = model_dataset.create_segment("model")
            segment.append(Data("./model.pth"))
            dataset_client = gas.upload_dataset(model_dataset)
            dataset_client.commit("upload mnist model file")
            logging.info("Uploaded model!")
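  #  Load the trained model and write its predictions back to a new branch of the MNIST dataset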
  evaluate:
    dependencies:
      - training
    script:
      image: hub.graviti.cn/algorithm/mnist:0.7
      command: [python3]
      source: |
        import logging
        import os
        from concurrent.futures import ThreadPoolExecutor

        import torch
        from PIL import Image
        from tensorbay import GAS
        from tensorbay.dataset import Dataset as TensorBayDataset
        from tensorbay.dataset.data import Data
        from tensorbay.label import Classification
        from torch import nn
        from torchvision import transforms

        logging.basicConfig(level=logging.INFO)


        # Build the network architecture
        class NeuralNetwork(nn.Module):
            def __init__(self):
                super(NeuralNetwork, self).__init__()
                self.flatten = nn.Flatten()
                self.linear_relu_stack = nn.Sequential(
                    nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)
                )

            def forward(self, x):
                x = self.flatten(x)
                logits = self.linear_relu_stack(x)
                return logits



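        # Predict the label for a single image with the trained model and upload it
        # to the segment; my_transforms is defined in the __main__ block below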
        def upload_label(model, segment_client, data):
            with data.open() as fp:
                image_tensor = my_transforms(Image.open(fp))
            pred = model(image_tensor)
            pred_data = Data(data.path)
            pred_data.label.classification = Classification(str(int(pred[0].argmax(0))))
            segment_client.upload_label(pred_data)


        if __name__ == "__main__":
            ACCESS_KEY = os.environ.get("secret.accesskey")
            gas = GAS(ACCESS_KEY)
            to_tensor = transforms.ToTensor()
            normalization = transforms.Normalize(mean=[0.485], std=[0.229])
            my_transforms = transforms.Compose([to_tensor, normalization])
            model_dataset = TensorBayDataset("MNIST_MODEL", gas)
            data = model_dataset[0][0]
            with open("./model.pth", "wb") as fp:  # local path for storing the downloaded model file
                fp.write(data.open().read())
            model = NeuralNetwork()
            model.load_state_dict(torch.load("model.pth", map_location=torch.device("cpu")))
            logging.info(model)

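            # Create a "training" branch and an "update label" draft on the forked MNIST
            # dataset; the predicted labels are uploaded to this draft and committed below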
            mnist_dataset = TensorBayDataset("MNIST", gas)
            mnist_dataset.enable_cache()
            mnist_dataset_client = gas.get_dataset("MNIST")
            mnist_dataset_client.create_branch("training")
            mnist_dataset_client.create_draft("update label")

            for segment in mnist_dataset:
                segment_client = mnist_dataset_client.get_segment(segment.name)
                with ThreadPoolExecutor(10) as executor:
                    for data in segment:
                        executor.submit(upload_label, model, segment_client, data)
            mnist_dataset_client.commit("update label")

4. Start training

a. Go to Action, select the workflow you created, and click Run.

5. View the results

a. Check the workflow run logs.

b. Open the branch named training to inspect the model predictions, and view the diff visualization and label distribution changes.

c. Check the saved training model.
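
If you prefer to check the saved model from code rather than in the web UI, the following minimal sketch downloads model.pth from the MNIST_MODEL dataset created by the workflow. It reuses the same SDK calls as the evaluate task; "<YOUR_ACCESSKEY>" is a placeholder for your AccessKey:

# Minimal sketch: download the model file saved by the workflow.
from tensorbay import GAS
from tensorbay.dataset import Dataset

gas = GAS("<YOUR_ACCESSKEY>")  # placeholder for your AccessKey
model_dataset = Dataset("MNIST_MODEL", gas)
data = model_dataset["model"][0]  # first file in the "model" segment
with open("./model.pth", "wb") as fp:
    fp.write(data.open().read())
print("Saved the trained model to ./model.pth")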