Введение в TensorFlow JS

Запуск TF, создание и обучение модели.

На windows пока не запускается, чтобы запустить нужна WSL (Windows Subsistem Linux)
nodejs

const tf = require('@tensorflow/tfjs');
require('@tensorflow/tfjs-node');
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));

// Prepare the model for training: Specify the loss and the optimizer.
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

// Generate some synthetic data for training.
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

// Train the model using the data.
model.fit(xs, ys, {epochs: 10}).then(() => {
  // Use the model to do inference on a data point the model hasn't seen before:
  model.predict(tf.tensor2d([5], [1, 1])).print();
});

В VS Code можно указать для отладки «useWSL»: true
Тогда будет по F5 запускаться через WSL
.vscode/launch.json

{
    "version": "0.2.0",
    "configurations": [{
        "type": "node",
        "request": "launch",
        "name": "Launch Program",
        "program": "${workspaceFolder}\\index.js",
        "useWSL": true
    }]
}

Запуск в браузере

script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.14.1/dist/tf.min.js"

// Define a model for linear regression.
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));

// Prepare the model for training: Specify the loss and the optimizer.
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

// Generate some synthetic data for training.
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);

// Train the model using the data.
model.fit(xs, ys, {epochs: 10}).then(() => {
// Use the model to do inference on a data point the model hasn't seen before:
// Open the browser devtools to see the output
model.predict(tf.tensor2d([5], [1, 1])).print();
});

Сохранение модели

const tf = require('@tensorflow/tfjs');
require('@tensorflow/tfjs-node');

const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));

model.save('file://./model-1a');

Но при этом в зависимостях должно быть только tfjs-node
Если добавить tfjs то сохранить не дает, выходит
UnhandledPromiseRejectionWarning: Error: Cannot find any save handlers for URL ‘file:///./model’
./ — текущая директория

"dependencies": {
    "@tensorflow/tfjs-node": "^0.1.21"
}

Загрузка модели

model = await tf.loadModel(‘file://./tmp/model/model.json’);

Error: browserHTTPRequest is not supported outside the web browser.

Добавляем
global.fetch = require(‘node-fetch’);

Print Friendly, PDF & Email

Добавить комментарий