ML

Core ML框架详细解析(十四) —— 使用Keras和Cor

2018-10-16  本文已影响580人  刀客传奇

版本记录

版本号 时间
V1.0 2018.10.16 星期二

前言

目前世界上科技界的所有大佬一致认为人工智能是下一代科技革命,苹果作为科技界的巨头,当然也会紧跟新的科技革命的步伐,其中ios API 就新出了一个框架Core ML。ML是Machine Learning的缩写,也就是机器学习,这正是现在很火的一个技术,它也是人工智能最核心的内容。感兴趣的可以看我写的下面几篇。
1. Core ML框架详细解析(一) —— Core ML基本概览
2. Core ML框架详细解析(二) —— 获取模型并集成到APP中
3. Core ML框架详细解析(三) —— 利用Vision和Core ML对图像进行分类
4. Core ML框架详细解析(四) —— 将训练模型转化为Core ML
5. Core ML框架详细解析(五) —— 一个Core ML简单示例(一)
6. Core ML框架详细解析(六) —— 一个Core ML简单示例(二)
7. Core ML框架详细解析(七) —— 减少Core ML应用程序的大小(一)
8. Core ML框架详细解析(八) —— 在用户设备上下载和编译模型(一)
9. Core ML框架详细解析(九) —— 用一系列输入进行预测(一)
10. Core ML框架详细解析(十) —— 集成自定义图层(一)
11. Core ML框架详细解析(十一) —— 创建自定义图层(一)
12. Core ML框架详细解析(十二) —— 用scikit-learn开始机器学习(一)
13. Core ML框架详细解析(十三) —— 使用Keras和Core ML开始机器学习(一)

Train the Model - 训练模型

1. Define Callbacks List - 定义回调列表

callbacksfit函数的可选参数,因此首先定义callbacks_list

输入以下代码,然后运行它。

callbacks_list = [
    keras.callbacks.ModelCheckpoint(
        filepath='best_model.{epoch:02d}-{val_loss:.2f}.h5',
        monitor='val_loss', save_best_only=True),
    keras.callbacks.EarlyStopping(monitor='acc', patience=1)
]

一个epoch是完整传递数据集中的所有小批量。

ModelCheckpoint回调监视验证丢失值,使用文件编号和文件名中的验证丢失将文件中的最低值保存。

EarlyStopping回调监控训练准确性:如果连续两个epochs未能改善,则训练提前停止。在我的实验中,这种情况从未发生过:如果acc在一个epoch内逐渐消失,它总会在下一个时代恢复。

2. Compile & Fit Model - 编译和拟合模型

除非您可以访问GPU,否则我建议您使用Malireddimodel_m进行此步骤,因为它的运行速度比Chollet的model_c快得多:在我的MacBook Pro上,76-106s / epoch与246-309s / epoch相比,或者大约15分钟vs 。 45分钟。

注意:如果在第一个epoch完成后notebook中没有出现.h5文件,请单击stop button以中断内核,单击save button,然后注销。在终端中,按Control-C停止服务器,然后重新运行docker run命令。将URL或令牌粘贴到浏览器或登录页面,导航到notebook,然后单击Not Trusted button按钮。选择此单元格,然后从菜单中选择Cell \ Run All Above

输入以下代码,然后运行它。这将花费很长时间,所以在等待时阅读Explanations部分。但是几分钟后检查Finder,以确保notebook正在保存.h5文件。

注意:此单元格显示多行函数调用的两种缩进类型,具体取决于您编写第一个参数的位置。如果它甚至被一个空格输出,那么这是一个语法错误。

model_m.compile(loss='categorical_crossentropy',
                optimizer='adam', metrics=['accuracy'])

# Hyper-parameters
batch_size = 200
epochs = 10

# Enable validation to use ModelCheckpoint and EarlyStopping callbacks.
model_m.fit(
    x_train, y_train, batch_size=batch_size, epochs=epochs,
    callbacks=callbacks_list, validation_data=(x_val, y_val), verbose=1)

Convolutional Neural Network: Explanations - 卷积神经网络:解释

您可以使用几乎任何ML方法来创建MNIST分类器,但本教程使用卷积神经网络(CNN),因为这是TensorFlowKeras的关键优势。

卷积神经网络假设输入是图像,并在三个维度上排列神经元:宽度,高度,深度。 CNN由卷积层组成,每个卷层检测训练图像的更高级特征:第一层可以训练滤波器以检测各种角度的短线或弧线;第二层训练滤波器以检测这些线的重要组合;最后一层的过滤器构建在前面的图层上以对图像进行分类。

每个卷积层在输入上传递一个小方块的kernel权重 - 1×1,3×35×5 ,计算内核下输入单元的加权和。 这是卷积过程。

每个神经元仅连接到前一层中的1个,9个或25个神经元,因此存在co-adapting的危险 - 过多地依赖于少数输入 - 这可能导致过度拟合。 因此,CNN包括poolingdropout层,以抵消co-adapting和过度拟合。 我在下面解释这些。

Sample Model - 样本模型

这是Malireddi的模型:

model_m = Sequential()
model_m.add(Conv2D(32, (5, 5), input_shape=input_shape, activation='relu'))
model_m.add(MaxPooling2D(pool_size=(2, 2)))
model_m.add(Dropout(0.5))
model_m.add(Conv2D(64, (3, 3), activation='relu'))
model_m.add(MaxPooling2D(pool_size=(2, 2)))
model_m.add(Dropout(0.2))
model_m.add(Conv2D(128, (1, 1), activation='relu'))
model_m.add(MaxPooling2D(pool_size=(2, 2)))
model_m.add(Dropout(0.2))
model_m.add(Flatten())
model_m.add(Dense(128, activation='relu'))
model_m.add(Dense(num_classes, activation='softmax'))

1. Sequential

首先创建一个空的Sequential模型,然后添加一个线性的图层堆栈:这些图层按照它们添加到模型的顺序运行。 Keras文档有几个examples of Sequential models

注意:Keras还具有用于定义复杂模型的函数API,例如多输出模型,有向非循环图或具有共享层的模型。 Google的InceptionMicrosoft Research AsiaResidual Networks是具有非线性连接结构的复杂模型的示例。

第一层必须具有关于输入形状的信息,对于MNIST(28,28,1)。 其他层从前一层的输出形状推断出它们的输入形状。 这是模型摘要的输出形状部分:

Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_6 (Conv2D)            (None, 24, 24, 32)        832       
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 12, 12, 32)        0         
_________________________________________________________________
dropout_6 (Dropout)          (None, 12, 12, 32)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 10, 10, 64)        18496     
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
dropout_7 (Dropout)          (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 5, 5, 128)         8320      
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 2, 2, 128)         0         
_________________________________________________________________
dropout_8 (Dropout)          (None, 2, 2, 128)         0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 512)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 128)               65664     
_________________________________________________________________
dense_6 (Dense)              (None, 10)                1290      

2. Conv2D

该模型有三个Conv2D层:

Conv2D(32, (5, 5), input_shape=input_shape, activation='relu')
Conv2D(64, (3, 3), activation='relu')
Conv2D(128, (1, 1), activation='relu')

3. MaxPooling2D

MaxPooling2D(pool_size=(2, 2))

pooling层在前一层上通过m列过滤器滑动n行,将n x m值替换为其最大值。pooling滤器通常是方形的:n = m。 如下所示,最常用的2 x 2 pooling滤器将前一层的宽度和高度减半,从而减少了参数的数量,从而有助于控制过度拟合。

Malireddi的模型在每个卷积层之后都有一个pooling层,这大大减少了最终的模型大小和训练时间。

Chollet的模型在pooling之前有两个卷积层。这建议用于较大的网络,因为它允许卷积层在pooling之前开发更复杂的特征,丢弃75%的值。

Conv2DMaxPooling2D参数确定每个图层的输出形状和可训练参数的数量:

Output Shape = (input width – kernel width + 1, input height – kernel height + 1, number of filters)

您不能将3×3内核置于每行和每列的第一个和最后一个单元的中心,因此输出宽度和高度比输入小2个像素。 5×5内核可将输出宽度和高度减少4个像素。

Param # = number of filters x (kernel width x kernel height x input depth + 1 bias)

Challenge:计算Chollet架构model_c的输出形状和参数编号。

Output Shape = (input width – kernel width + 1, input height – kernel height + 1, number of filters)

  • Conv2D(32, (3, 3), input_shape=(28, 28, 1)): (28-2, 28-2, 32) = (26, 26, 32)
  • Conv2D(64, (3, 3)): (26-2, 26-2, 64) = (24, 24, 64)
  • MaxPooling2D halves the input width and height: (24/2, 24/2, 64) = (12, 12, 64)

Param # = number of filters x (kernel width x kernel height x input depth + 1 bias)

  • Conv2D(32, (3, 3), input_shape=(28, 28, 1)): 32 x (3x3x1 + 1) = 320
  • Conv2D(64, (3, 3)): 64 x (3x3x32 + 1) = 18,496

4. Dropout

Dropout(0.5)
Dropout(0.2)

dropout层通常与pooling层配对。 它将输入单位的一小部分随机设置为0。这是控制过度拟合的另一种方法:神经元不太可能受到相邻神经元的过多影响,因为它们中的任何一个都可能随机掉出网络。 这使得网络对输入中的微小变化不太敏感,因此更有可能推广到新输入。

Hands-on Machine Learning with Scikit-Learn & TensorFlowAurélienGéron将其与工作场所进行比较,在任何一天,某些人可能无法上班:每个人都必须能够完成关键任务, 并且必须与更多的同事合作。 这将使公司更具弹性,减少对任何单个工人的依赖。

5. Flatten

在将卷积层传递到完全连接的密集层之前,必须使卷积层的权重为1。

model_m.add(Dropout(0.2))
model_m.add(Flatten())
model_m.add(Dense(128, activation='relu'))

前一层的输出形状为(2,2,128),因此Flatten()的输出是一个包含512个元素的数组。

6. Dense

Dense(128, activation='relu')
Dense(num_classes, activation='softmax')

卷积层中的每个神经元使用前一层中仅少数神经元的值。 完全连接层中的每个神经元使用前一层中所有神经元的值。 此类图层的Keras名称为Dense

看看上面的模型摘要,Malireddi的第一个Dense层有512个神经元,而Chollet有9216个。两者都产生128个神经元输出层,但Chollet必须计算的参数比Malireddi的多18倍。 这是使用大部分额外训练时间的原因。

大多数CNN架构以一个或多个Dense层结束,然后是输出层。

第一个参数是图层的输出大小。 最终输出层的输出大小为10,对应于10个数字类。

softmax激活函数在10个输出类别上产生概率分布。 它是sigmoid函数的推广,它将其输入值缩放到[0,1]范围内。 对于您的MNIST分类器,softmax将10个值中的每一个都缩放为[0,1],这样它们总计为1。

您可以将sigmoid函数用于单个输出类:例如,这是一张好狗照片的概率是多少?

7. Compile

model_m.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

分类交叉熵(categorical crossentropy)损失函数测量由CNN计算的概率分布与标签的真实分布之间的距离。

优化器(optimizer)是随机梯度下降算法,它试图通过以恰当的速度跟随梯度来最小化损失函数。

准确度(Accuracy) - 正确分类的图像的分数 - 是在训练和测试期间监控的最常见度量。

8. Fit

batch_size = 256
epochs = 10
model_m.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, callbacks=callbacks_list,
            validation_data=(x_val, y_val), verbose=1)

批量大小(Batch size)是用于小批量随机梯度拟合的数据项的数量。选择批量大小是一个试验和错误的问题,一骰子。较小的值使得epoch需要更长的时间;较大的值可以更好地利用GPU并行性,并减少数据传输时间,但过大可能会导致内存不足。

epoch的数量也是掷骰子。每个epoch都应该改善损失和准确度测量。更多epoch应该产生更准确的模型,但训练需要更长时间。太多的epoch可能导致过度拟合。如果模型在完成所有epoch之前停止改进,则设置回调以提前停止。在notebook中,您可以重新运行fit的单元格以继续改进模型。

加载数据时,将10000个项目设置为验证数据。通过此参数可以在训练时进行验证,因此您可以监控验证损失和准确性。如果这些值比训练损失和准确度差,则表明该模型过度拟合。

9. Verbose

0 = silent, 1 = progress bar, 2 = one line per epoch.

Results - 结果

以下是我的一次训练结果:

Epoch 1/10
60000/60000 [==============================] - 106s - loss: 0.0284 - acc: 0.9909 - val_loss: 0.0216 - val_acc: 0.9940
Epoch 2/10
60000/60000 [==============================] - 100s - loss: 0.0271 - acc: 0.9911 - val_loss: 0.0199 - val_acc: 0.9942
Epoch 3/10
60000/60000 [==============================] - 102s - loss: 0.0260 - acc: 0.9914 - val_loss: 0.0228 - val_acc: 0.9931
Epoch 4/10
60000/60000 [==============================] - 101s - loss: 0.0257 - acc: 0.9913 - val_loss: 0.0211 - val_acc: 0.9935
Epoch 5/10
60000/60000 [==============================] - 101s - loss: 0.0256 - acc: 0.9916 - val_loss: 0.0222 - val_acc: 0.9928
Epoch 6/10
60000/60000 [==============================] - 100s - loss: 0.0263 - acc: 0.9913 - val_loss: 0.0178 - val_acc: 0.9950
Epoch 7/10
60000/60000 [==============================] - 87s - loss: 0.0231 - acc: 0.9920 - val_loss: 0.0212 - val_acc: 0.9932
Epoch 8/10
60000/60000 [==============================] - 76s - loss: 0.0240 - acc: 0.9922 - val_loss: 0.0212 - val_acc: 0.9935
Epoch 9/10
60000/60000 [==============================] - 76s - loss: 0.0261 - acc: 0.9916 - val_loss: 0.0220 - val_acc: 0.9934
Epoch 10/10
60000/60000 [==============================] - 76s - loss: 0.0231 - acc: 0.9925 - val_loss: 0.0203 - val_acc: 0.9935

在每个epoch,损失值应该减少,准确度值应该增加。 ModelCheckpoint回调保存了epoch1,2和6,因为epoch3,4和5中的验证损失值高于epoch2,并且在epoch6之后验证损失没有改善。训练不会提前停止,因为训练准确性从未在连续两个epoch内减少。

注意:实际上,这些结果来自20或30个epoch:我在不重置模型的情况下不止一次地运行fit单元格,因此即使在第1epoch中,损失和准确度值也已经非常好。但是您在测量中看到一些波动。例如,在epoch4,6和9中精度降低。

到目前为止,您的模型已经完成训练,所以回到编码!


Convert to Core ML Model - 转换为Core ML模型

训练步骤完成后,您应该在notebook中保存一些模型。 具有最高epoch数(和最低验证损失)的那个是最佳模型,因此在convert函数中使用该文件名。

输入以下代码,然后运行它。

output_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# For the first argument, use the filename of the newest .h5 file in the notebook folder.
coreml_mnist = coremltools.converters.keras.convert(
    'best_model.09-0.03.h5', input_names=['image'], output_names=['output'], 
    class_labels=output_labels, image_input_names='image')

在这里,您在数组中设置10个输出标签,并将其作为class_labels参数传递。 如果训练具有大量输出类的模型,请将标签放在文本文件中,每行一个标签,并将class_labels参数设置为文件名。

在参数列表中,您提供输入和输出名称,并设置image_input_names ='image',以便Core ML模型接受图像作为输入,而不是多数组。

1. Inspect Core ML model - 检查Core ML模型

输入此行,然后运行它以查看打印输出。

print(coreml_mnist)

只需检查输入类型是imageType,而不是多数组:

input {
  name: "image"
  shortDescription: "Digit image"
  type {
    imageType {
      width: 28
      height: 28
      colorSpace: GRAYSCALE
    }
  }
}

2. Add Metadata for Xcode - 为Xcode添加元数据

现在添加以下内容,替换前两个项目的自己的名称和许可证信息,然后运行它。

coreml_mnist.author = 'raywenderlich.com'
coreml_mnist.license = 'Razeware'
coreml_mnist.short_description = 'Image based digit recognition (MNIST)'
coreml_mnist.input_description['image'] = 'Digit image'
coreml_mnist.output_description['output'] = 'Probability of each digit'
coreml_mnist.output_description['classLabel'] = 'Labels of digits'

在Xcode的项目导航器中选择模型时会出现此信息。

3. Save the Core ML Model - 保存Core ML模型

最后,添加以下内容并运行它。

coreml_mnist.save('MNISTClassifier.mlmodel')

这会将mlmodel文件保存在notebook文件夹中。

恭喜,您现在拥有一个Core ML模型,可以对手写数字进行分类! 是时候在iOS应用程序中使用它了。


Use Model in iOS App - 在iOS App中使用Model

1. Step 1. Drag the model into the app - 步骤1.将模型拖到应用程序中:

在Xcode中打开入门应用程序,并将Finders中的MNISTClassifier.mlmodel拖到项目的Project导航器中。 选择它以查看您添加的元数据:

如果不是Automatically generated Swift model class,而是建立项目来生成模型类,请继续执行此操作。

2. Step 2. Import the CoreML and Vision frameworks: - 步骤2.导入CoreML和Vision框架:

打开ViewController.swift,导入两个框架,就在导入UIKit下面:

import CoreML
import Vision

3. Step 3. Create VNCoreMLModel and VNCoreMLRequest objects: - 步骤3.创建VNCoreMLModel和VNCoreMLRequest对象:

outlets下面添加以下代码:

lazy var classificationRequest: VNCoreMLRequest = {
  // Load the ML model through its generated class and create a Vision request for it.
  do {
    let model = try VNCoreMLModel(for: MNISTClassifier().model)
    return VNCoreMLRequest(model: model, completionHandler: handleClassification)
  } catch {
    fatalError("Can't load Vision ML model: \(error).")
  }
}()

func handleClassification(request: VNRequest, error: Error?) {
  guard let observations = request.results as? [VNClassificationObservation]
    else { fatalError("Unexpected result type from VNCoreMLRequest.") }
  guard let best = observations.first
    else { fatalError("Can't get best result.") }

  DispatchQueue.main.async {
    self.predictLabel.text = best.identifier
    self.predictLabel.isHidden = false
  }
}

请求对象适用于步骤4中的处理程序传递给它的任何图像,因此您只需将其定义一次,作为一个lazy var

请求对象的完成处理程序接收requesterror对象。 您检查request.results是一个VNClassificationObservation对象的数组,这是当Core ML模型是分类器而不是预测器或图像处理器时Vision框架返回的对象。

VNClassificationObservation对象有两个属性:identifier - 一个String - 和confidence - 一个介于0和1之间的数字 - 分类正确的概率。 您获取第一个结果,该结果具有最高置信度值,并调度回主队列以更新predictLabel。 分类工作发生在主队列之外,因为它可能很慢。

4. Step 4. Create and run a VNImageRequestHandler: - 步骤4.创建并运行VNImageRequestHandler:

找到predictTapped(),并使用以下代码替换print语句:

let ciImage = CIImage(cgImage: inputImage)
let handler = VNImageRequestHandler(ciImage: ciImage)
do {
  try handler.perform([classificationRequest])
} catch {
  print(error)
}

您可以从inputImage创建CIImage,然后为此ciImage创建VNImageRequestHandler对象,并在VNCoreMLRequest对象数组上运行处理程序 - 在本例中,只是您在步骤3中创建的一个请求对象。

建立并运行。 在绘图区域的中心绘制一个数字,然后点击Predict。 点按Clear再试一次。

较大的绘制往往效果更好,但模型常常遇到'7'和'4'的问题。 毫不奇怪,因为MNIST数据的PCA visualization显示7s和4s聚集在9s:

注意:Malireddi表示Vision框架使用了20%的CPU,因此his app包含一个将UIImage对象转换为CVPixelBuffer格式的扩展。

如果您不使用Vision,请在将Keras模型转换为Core ML时将image_scale = 1 / 255.0作为参数:Keras模型训练灰度值在[0,1]范围内的图像,CVPixelBuffer值为 在[0,255]范围内。

感谢 Sri Raghu M, Matthijs HollemansHon Weng Chong的有益讨论!

资源

进一步阅读


源码

1. Swift

看下工程文档结构

接着,看一下sb内容

1. ViewController.swift
import UIKit
import CoreML
import Vision

class ViewController: UIViewController {

  @IBOutlet weak var drawView: DrawView!
  @IBOutlet weak var predictLabel: UILabel!

  // DONE: Define lazy var classificationRequest
  lazy var classificationRequest: VNCoreMLRequest = {
    // Load the ML model through its generated class and create a Vision request for it.
    do {
      let model = try VNCoreMLModel(for: MNISTClassifier().model)
      return VNCoreMLRequest(model: model, completionHandler: self.handleClassification)
    } catch {
      fatalError("Can't load Vision ML model: \(error).")
    }
  }()

  func handleClassification(request: VNRequest, error: Error?) {
    guard let observations = request.results as? [VNClassificationObservation]
      else { fatalError("Unexpected result type from VNCoreMLRequest.") }
    guard let best = observations.first
      else { fatalError("Can't get best result.") }

    DispatchQueue.main.async {
      self.predictLabel.text = best.identifier
      self.predictLabel.isHidden = false
    }
  }

  override func viewDidLoad() {
    super.viewDidLoad()
    predictLabel.isHidden = true
  }

  @IBAction func clearTapped() {
    drawView.lines = []
    drawView.setNeedsDisplay()
    predictLabel.isHidden = true
  }

  @IBAction func predictTapped() {
    guard let context = drawView.getViewContext(),
      let inputImage = context.makeImage()
      else { fatalError("Get context or make image failed.") }
    // DONE: Perform request on model
    let ciImage = CIImage(cgImage: inputImage)
    let handler = VNImageRequestHandler(ciImage: ciImage)
    do {
      try handler.perform([classificationRequest])
    } catch {
      print(error)
    }
  }

}
2. DrawView.swift
// Code taken with inspiration from Apple's Metal-2 sample MPSCNNHelloWorld
import UIKit

/**
 This class is used to handle the drawing in the DigitView so we can get user input digit,
 This class doesn't really have an MPS or Metal going in it, it is just used to get user input
 */
class DrawView: UIView {
    
    // some parameters of how thick a line to draw 15 seems to work
    // and we have white drawings on black background just like MNIST needs its input
    var linewidth = CGFloat(15) { didSet { setNeedsDisplay() } }
    var color = UIColor.white { didSet { setNeedsDisplay() } }
    
    // we will keep touches made by user in view in these as a record so we can draw them.
    var lines: [Line] = []
    var lastPoint: CGPoint!
    
    override func touchesBegan(_ touches: Set<UITouch>, with event: UIEvent?) {
        lastPoint = touches.first!.location(in: self)
    }
    
    override func touchesMoved(_ touches: Set<UITouch>, with event: UIEvent?) {
        let newPoint = touches.first!.location(in: self)
        // keep all lines drawn by user as touch in record so we can draw them in view
        lines.append(Line(start: lastPoint, end: newPoint))
        lastPoint = newPoint
        // make a draw call
        setNeedsDisplay()
    }
    
    override func draw(_ rect: CGRect) {
        super.draw(rect)
        
        let drawPath = UIBezierPath()
        drawPath.lineCapStyle = .round
        
        for line in lines{
            drawPath.move(to: line.start)
            drawPath.addLine(to: line.end)
        }
        
        drawPath.lineWidth = linewidth
        color.set()
        drawPath.stroke()
    }
    
    
    /**
     This function gets the pixel data of the view so we can put it in MTLTexture
     
     - Returns:
     Void
     */
    func getViewContext() -> CGContext? {
        // our network takes in only grayscale images as input
        let colorSpace:CGColorSpace = CGColorSpaceCreateDeviceGray()
        
        // we have 3 channels no alpha value put in the network
        let bitmapInfo = CGImageAlphaInfo.none.rawValue
        
        // this is where our view pixel data will go in once we make the render call
        let context = CGContext(data: nil, width: 28, height: 28, bitsPerComponent: 8, bytesPerRow: 28, space: colorSpace, bitmapInfo: bitmapInfo)
        
        // scale and translate so we have the full digit and in MNIST standard size 28x28
        context!.translateBy(x: 0 , y: 28)
        context!.scaleBy(x: 28/self.frame.size.width, y: -28/self.frame.size.height)
        
        // put view pixel data in context
        self.layer.render(in: context!)
        
        return context
    }
}

/**
 2 points can give a line and this class is just for that purpose, it keeps a record of a line
 */
class Line{
    var start, end: CGPoint
    
    init(start: CGPoint, end: CGPoint) {
        self.start = start
        self.end   = end
    }
}

后记

本篇主要讲述了使用Keras和Core ML开始机器学习,感兴趣的给个赞或者关注~~~

上一篇下一篇

猜你喜欢

热点阅读