// Agriculture-front-end/public/libs/tensorflow/3.9.0/tf.js

/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
(function (global, factory) {
typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) :
typeof define === 'function' && define.amd ? define(['exports'], factory) :
(global = global || self, factory(global.tf = global.tf || {}));
}(this, (function (exports) { 'use strict';
var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};
function unwrapExports (x) {
return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x;
}
function createCommonjsModule(fn, module) {
return module = { exports: {} }, fn(module, module.exports), module.exports;
}
function getCjsExportFromNamespace (n) {
return n && n['default'] || n;
}
function commonjsRequire () {
throw new Error('Dynamic requires are not currently supported by @rollup/plugin-commonjs');
}
var check = function check(it) {
return it && it.Math == Math && it;
}; // https://github.com/zloirock/core-js/issues/86#issuecomment-115759028
var global_1 =
/* global globalThis -- safe */
check(typeof globalThis == 'object' && globalThis) || check(typeof window == 'object' && window) || check(typeof self == 'object' && self) || check(typeof commonjsGlobal == 'object' && commonjsGlobal) || // eslint-disable-next-line no-new-func -- fallback
function () {
return this;
}() || Function('return this')();
var fails = function fails(exec) {
try {
return !!exec();
} catch (error) {
return true;
}
};
var descriptors = !fails(function () {
return Object.defineProperty({}, 1, {
get: function get() {
return 7;
}
})[1] != 7;
});
'use strict';
var nativePropertyIsEnumerable = {}.propertyIsEnumerable;
var getOwnPropertyDescriptor = Object.getOwnPropertyDescriptor; // Nashorn ~ JDK8 bug
var NASHORN_BUG = getOwnPropertyDescriptor && !nativePropertyIsEnumerable.call({
1: 2
}, 1); // `Object.prototype.propertyIsEnumerable` method implementation
// https://tc39.es/ecma262/#sec-object.prototype.propertyisenumerable
var f = NASHORN_BUG ? function propertyIsEnumerable(V) {
var descriptor = getOwnPropertyDescriptor(this, V);
return !!descriptor && descriptor.enumerable;
} : nativePropertyIsEnumerable;
var objectPropertyIsEnumerable = {
f: f
};
var createPropertyDescriptor = function createPropertyDescriptor(bitmap, value) {
return {
enumerable: !(bitmap & 1),
configurable: !(bitmap & 2),
writable: !(bitmap & 4),
value: value
};
};
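// Illustrative note (not part of the upstream source): the bitmap marks which attributes to
// suppress, so createPropertyDescriptor(1, v) yields { enumerable: false, configurable: true,
// writable: true, value: v }, while createPropertyDescriptor(0, v) leaves all three flags true.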
var toString = {}.toString;
var classofRaw = function classofRaw(it) {
return toString.call(it).slice(8, -1);
};
var split = ''.split; // fallback for non-array-like ES3 and non-enumerable old V8 strings
var indexedObject = fails(function () {
// throws an error in rhino, see https://github.com/mozilla/rhino/issues/346
// eslint-disable-next-line no-prototype-builtins -- safe
return !Object('z').propertyIsEnumerable(0);
}) ? function (it) {
return classofRaw(it) == 'String' ? split.call(it, '') : Object(it);
} : Object;
// `RequireObjectCoercible` abstract operation
// https://tc39.es/ecma262/#sec-requireobjectcoercible
var requireObjectCoercible = function requireObjectCoercible(it) {
if (it == undefined) throw TypeError("Can't call method on " + it);
return it;
};
var toIndexedObject = function toIndexedObject(it) {
return indexedObject(requireObjectCoercible(it));
};
var isObject = function isObject(it) {
return typeof it === 'object' ? it !== null : typeof it === 'function';
};
// https://tc39.es/ecma262/#sec-toprimitive
// unlike the ES6 spec version, this does not implement the @@toPrimitive case,
// and the second argument is a flag: when true, the preferred type is a string
var toPrimitive = function toPrimitive(input, PREFERRED_STRING) {
if (!isObject(input)) return input;
var fn, val;
if (PREFERRED_STRING && typeof (fn = input.toString) == 'function' && !isObject(val = fn.call(input))) return val;
if (typeof (fn = input.valueOf) == 'function' && !isObject(val = fn.call(input))) return val;
if (!PREFERRED_STRING && typeof (fn = input.toString) == 'function' && !isObject(val = fn.call(input))) return val;
throw TypeError("Can't convert object to primitive value");
};
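// Illustrative note: with PREFERRED_STRING set, toPrimitive tries toString() before valueOf();
// otherwise it tries valueOf() first and falls back to toString(), throwing only when neither
// returns a primitive value.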
var hasOwnProperty = {}.hasOwnProperty;
var has = function has(it, key) {
return hasOwnProperty.call(it, key);
};
var document$1 = global_1.document; // typeof document.createElement is 'object' in old IE
var EXISTS = isObject(document$1) && isObject(document$1.createElement);
var documentCreateElement = function documentCreateElement(it) {
return EXISTS ? document$1.createElement(it) : {};
};
var ie8DomDefine = !descriptors && !fails(function () {
return Object.defineProperty(documentCreateElement('div'), 'a', {
get: function get() {
return 7;
}
}).a != 7;
});
var nativeGetOwnPropertyDescriptor = Object.getOwnPropertyDescriptor; // `Object.getOwnPropertyDescriptor` method
// https://tc39.es/ecma262/#sec-object.getownpropertydescriptor
var f$1 = descriptors ? nativeGetOwnPropertyDescriptor : function getOwnPropertyDescriptor(O, P) {
O = toIndexedObject(O);
P = toPrimitive(P, true);
if (ie8DomDefine) try {
return nativeGetOwnPropertyDescriptor(O, P);
} catch (error) {
/* empty */
}
if (has(O, P)) return createPropertyDescriptor(!objectPropertyIsEnumerable.f.call(O, P), O[P]);
};
var objectGetOwnPropertyDescriptor = {
f: f$1
};
var anObject = function anObject(it) {
if (!isObject(it)) {
throw TypeError(String(it) + ' is not an object');
}
return it;
};
var nativeDefineProperty = Object.defineProperty; // `Object.defineProperty` method
// https://tc39.es/ecma262/#sec-object.defineproperty
var f$2 = descriptors ? nativeDefineProperty : function defineProperty(O, P, Attributes) {
anObject(O);
P = toPrimitive(P, true);
anObject(Attributes);
if (ie8DomDefine) try {
return nativeDefineProperty(O, P, Attributes);
} catch (error) {
/* empty */
}
if ('get' in Attributes || 'set' in Attributes) throw TypeError('Accessors not supported');
if ('value' in Attributes) O[P] = Attributes.value;
return O;
};
var objectDefineProperty = {
f: f$2
};
var createNonEnumerableProperty = descriptors ? function (object, key, value) {
return objectDefineProperty.f(object, key, createPropertyDescriptor(1, value));
} : function (object, key, value) {
object[key] = value;
return object;
};
var setGlobal = function setGlobal(key, value) {
try {
createNonEnumerableProperty(global_1, key, value);
} catch (error) {
global_1[key] = value;
}
return value;
};
var SHARED = '__core-js_shared__';
var store = global_1[SHARED] || setGlobal(SHARED, {});
var sharedStore = store;
var functionToString = Function.toString; // this helper was broken in core-js `3.4.1-3.4.4`, so we can't use the `shared` helper
if (typeof sharedStore.inspectSource != 'function') {
sharedStore.inspectSource = function (it) {
return functionToString.call(it);
};
}
var inspectSource = sharedStore.inspectSource;
var WeakMap$1 = global_1.WeakMap;
var nativeWeakMap = typeof WeakMap$1 === 'function' && /native code/.test(inspectSource(WeakMap$1));
var isPure = false;
var shared = createCommonjsModule(function (module) {
(module.exports = function (key, value) {
return sharedStore[key] || (sharedStore[key] = value !== undefined ? value : {});
})('versions', []).push({
version: '3.9.1',
mode: isPure ? 'pure' : 'global',
copyright: '© 2021 Denis Pushkarev (zloirock.ru)'
});
});
var id = 0;
var postfix = Math.random();
var uid = function uid(key) {
return 'Symbol(' + String(key === undefined ? '' : key) + ')_' + (++id + postfix).toString(36);
};
var keys = shared('keys');
var sharedKey = function sharedKey(key) {
return keys[key] || (keys[key] = uid(key));
};
var hiddenKeys = {};
var WeakMap$2 = global_1.WeakMap;
var set, get, has$1;
var enforce = function enforce(it) {
return has$1(it) ? get(it) : set(it, {});
};
var getterFor = function getterFor(TYPE) {
return function (it) {
var state;
if (!isObject(it) || (state = get(it)).type !== TYPE) {
throw TypeError('Incompatible receiver, ' + TYPE + ' required');
}
return state;
};
};
if (nativeWeakMap) {
var store$1 = sharedStore.state || (sharedStore.state = new WeakMap$2());
var wmget = store$1.get;
var wmhas = store$1.has;
var wmset = store$1.set;
set = function set(it, metadata) {
metadata.facade = it;
wmset.call(store$1, it, metadata);
return metadata;
};
get = function get(it) {
return wmget.call(store$1, it) || {};
};
has$1 = function has(it) {
return wmhas.call(store$1, it);
};
} else {
var STATE = sharedKey('state');
hiddenKeys[STATE] = true;
set = function set(it, metadata) {
metadata.facade = it;
createNonEnumerableProperty(it, STATE, metadata);
return metadata;
};
get = function get(it) {
return has(it, STATE) ? it[STATE] : {};
};
has$1 = function has$1(it) {
return has(it, STATE);
};
}
var internalState = {
set: set,
get: get,
has: has$1,
enforce: enforce,
getterFor: getterFor
};
var internalState_1 = internalState.set;
var internalState_2 = internalState.get;
var internalState_3 = internalState.has;
var internalState_4 = internalState.enforce;
var internalState_5 = internalState.getterFor;
var redefine = createCommonjsModule(function (module) {
var getInternalState = internalState.get;
var enforceInternalState = internalState.enforce;
var TEMPLATE = String(String).split('String');
(module.exports = function (O, key, value, options) {
var unsafe = options ? !!options.unsafe : false;
var simple = options ? !!options.enumerable : false;
var noTargetGet = options ? !!options.noTargetGet : false;
var state;
if (typeof value == 'function') {
if (typeof key == 'string' && !has(value, 'name')) {
createNonEnumerableProperty(value, 'name', key);
}
state = enforceInternalState(value);
if (!state.source) {
state.source = TEMPLATE.join(typeof key == 'string' ? key : '');
}
}
if (O === global_1) {
if (simple) O[key] = value; else setGlobal(key, value);
return;
} else if (!unsafe) {
delete O[key];
} else if (!noTargetGet && O[key]) {
simple = true;
}
if (simple) O[key] = value; else createNonEnumerableProperty(O, key, value); // add a fake Function#toString so that wrapped methods / constructors work correctly with tools like Lodash's isNative
})(Function.prototype, 'toString', function toString() {
return typeof this == 'function' && getInternalState(this).source || inspectSource(this);
});
});
var path = global_1;
var aFunction = function aFunction(variable) {
return typeof variable == 'function' ? variable : undefined;
};
var getBuiltIn = function getBuiltIn(namespace, method) {
return arguments.length < 2 ? aFunction(path[namespace]) || aFunction(global_1[namespace]) : path[namespace] && path[namespace][method] || global_1[namespace] && global_1[namespace][method];
};
var ceil = Math.ceil;
var floor = Math.floor; // `ToInteger` abstract operation
// https://tc39.es/ecma262/#sec-tointeger
var toInteger = function toInteger(argument) {
return isNaN(argument = +argument) ? 0 : (argument > 0 ? floor : ceil)(argument);
};
var min = Math.min; // `ToLength` abstract operation
// https://tc39.es/ecma262/#sec-tolength
var toLength = function toLength(argument) {
return argument > 0 ? min(toInteger(argument), 0x1FFFFFFFFFFFFF) : 0; // 2 ** 53 - 1 == 9007199254740991
};
var max = Math.max;
var min$1 = Math.min; // Helper for a popular repeating case of the spec:
// Let integer be ? ToInteger(index).
// If integer < 0, let result be max((length + integer), 0); else let result be min(integer, length).
var toAbsoluteIndex = function toAbsoluteIndex(index, length) {
var integer = toInteger(index);
return integer < 0 ? max(integer + length, 0) : min$1(integer, length);
};
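// Example (for illustration only): with length 5, toAbsoluteIndex(-2, 5) === 3 and
// toAbsoluteIndex(7, 5) === 5, matching how fromIndex/start arguments are clamped by the spec.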
var createMethod = function createMethod(IS_INCLUDES) {
return function ($this, el, fromIndex) {
var O = toIndexedObject($this);
var length = toLength(O.length);
var index = toAbsoluteIndex(fromIndex, length);
var value; // Array#includes uses SameValueZero equality algorithm
// eslint-disable-next-line no-self-compare -- NaN check
if (IS_INCLUDES && el != el) while (length > index) {
value = O[index++]; // eslint-disable-next-line no-self-compare -- NaN check
if (value != value) return true; // Array#indexOf ignores holes, Array#includes does not
} else for (; length > index; index++) {
if ((IS_INCLUDES || index in O) && O[index] === el) return IS_INCLUDES || index || 0;
}
return !IS_INCLUDES && -1;
};
};
var arrayIncludes = {
// `Array.prototype.includes` method
// https://tc39.es/ecma262/#sec-array.prototype.includes
includes: createMethod(true),
// `Array.prototype.indexOf` method
// https://tc39.es/ecma262/#sec-array.prototype.indexof
indexOf: createMethod(false)
};
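// Usage note (illustrative, not from the upstream source): arrayIncludes.includes([1, NaN], NaN)
// returns true because includes uses SameValueZero, whereas arrayIncludes.indexOf([1, NaN], NaN)
// returns -1, mirroring the native Array.prototype methods.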
var arrayIncludes_1 = arrayIncludes.includes;
var arrayIncludes_2 = arrayIncludes.indexOf;
var indexOf = arrayIncludes.indexOf;
var objectKeysInternal = function objectKeysInternal(object, names) {
var O = toIndexedObject(object);
var i = 0;
var result = [];
var key;
for (key in O) {
!has(hiddenKeys, key) && has(O, key) && result.push(key);
} // Don't enum bug & hidden keys
while (names.length > i) {
if (has(O, key = names[i++])) {
~indexOf(result, key) || result.push(key);
}
}
return result;
};
// IE8- don't enum bug keys
var enumBugKeys = ['constructor', 'hasOwnProperty', 'isPrototypeOf', 'propertyIsEnumerable', 'toLocaleString', 'toString', 'valueOf'];
var hiddenKeys$1 = enumBugKeys.concat('length', 'prototype'); // `Object.getOwnPropertyNames` method
// https://tc39.es/ecma262/#sec-object.getownpropertynames
var f$3 = Object.getOwnPropertyNames || function getOwnPropertyNames(O) {
return objectKeysInternal(O, hiddenKeys$1);
};
var objectGetOwnPropertyNames = {
f: f$3
};
var f$4 = Object.getOwnPropertySymbols;
var objectGetOwnPropertySymbols = {
f: f$4
};
var ownKeys = getBuiltIn('Reflect', 'ownKeys') || function ownKeys(it) {
var keys = objectGetOwnPropertyNames.f(anObject(it));
var getOwnPropertySymbols = objectGetOwnPropertySymbols.f;
return getOwnPropertySymbols ? keys.concat(getOwnPropertySymbols(it)) : keys;
};
var copyConstructorProperties = function copyConstructorProperties(target, source) {
var keys = ownKeys(source);
var defineProperty = objectDefineProperty.f;
var getOwnPropertyDescriptor = objectGetOwnPropertyDescriptor.f;
for (var i = 0; i < keys.length; i++) {
var key = keys[i];
if (!has(target, key)) defineProperty(target, key, getOwnPropertyDescriptor(source, key));
}
};
var replacement = /#|\.prototype\./;
var isForced = function isForced(feature, detection) {
var value = data[normalize(feature)];
return value == POLYFILL ? true : value == NATIVE ? false : typeof detection == 'function' ? fails(detection) : !!detection;
};
var normalize = isForced.normalize = function (string) {
return String(string).replace(replacement, '.').toLowerCase();
};
var data = isForced.data = {};
var NATIVE = isForced.NATIVE = 'N';
var POLYFILL = isForced.POLYFILL = 'P';
var isForced_1 = isForced;
var getOwnPropertyDescriptor$1 = objectGetOwnPropertyDescriptor.f;
/*
options.target - name of the target object
options.global - target is the global object
options.stat - export as static methods of target
options.proto - export as prototype methods of target
options.real - real prototype method for the `pure` version
options.forced - export even if the native feature is available
options.bind - bind methods to the target, required for the `pure` version
options.wrap - wrap constructors to prevent global pollution, required for the `pure` version
options.unsafe - use simple assignment of the property instead of delete + defineProperty
options.sham - add a flag to polyfills that are not fully spec-compliant
options.enumerable - export as enumerable property
options.noTargetGet - prevent calling a getter on target
*/
var _export = function _export(options, source) {
var TARGET = options.target;
var GLOBAL = options.global;
var STATIC = options.stat;
var FORCED, target, key, targetProperty, sourceProperty, descriptor;
if (GLOBAL) {
target = global_1;
} else if (STATIC) {
target = global_1[TARGET] || setGlobal(TARGET, {});
} else {
target = (global_1[TARGET] || {}).prototype;
}
if (target) for (key in source) {
sourceProperty = source[key];
if (options.noTargetGet) {
descriptor = getOwnPropertyDescriptor$1(target, key);
targetProperty = descriptor && descriptor.value;
} else targetProperty = target[key];
FORCED = isForced_1(GLOBAL ? key : TARGET + (STATIC ? '.' : '#') + key, options.forced); // contained in target
if (!FORCED && targetProperty !== undefined) {
if (typeof sourceProperty === typeof targetProperty) continue;
copyConstructorProperties(sourceProperty, targetProperty);
} // add a flag to polyfills that are not fully spec-compliant
if (options.sham || targetProperty && targetProperty.sham) {
createNonEnumerableProperty(sourceProperty, 'sham', true);
} // extend global
redefine(target, key, sourceProperty, options);
}
};
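// Usage shape (illustrative; concrete calls appear throughout this bundle), e.g.
// _export({ target: 'Array', proto: true, forced: !HAS_SPECIES_SUPPORT }, { filter: ... })
// installs a missing or forced method on Array.prototype using the options documented above.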
var engineIsNode = classofRaw(global_1.process) == 'process';
var engineUserAgent = getBuiltIn('navigator', 'userAgent') || '';
var process$1 = global_1.process;
var versions = process$1 && process$1.versions;
var v8 = versions && versions.v8;
var match, version;
if (v8) {
match = v8.split('.');
version = match[0] + match[1];
} else if (engineUserAgent) {
match = engineUserAgent.match(/Edge\/(\d+)/);
if (!match || match[1] >= 74) {
match = engineUserAgent.match(/Chrome\/(\d+)/);
if (match) version = match[1];
}
}
var engineV8Version = version && +version;
var nativeSymbol = !!Object.getOwnPropertySymbols && !fails(function () {
/* global Symbol -- required for testing */
return !Symbol.sham && ( // Chrome 38 Symbol has incorrect toString conversion
// Chrome 38-40 symbols are not inherited from DOM collections prototypes to instances
engineIsNode ? engineV8Version === 38 : engineV8Version > 37 && engineV8Version < 41);
});
var useSymbolAsUid = nativeSymbol
/* global Symbol -- safe */
&& !Symbol.sham && typeof Symbol.iterator == 'symbol';
// https://tc39.es/ecma262/#sec-isarray
var isArray = Array.isArray || function isArray(arg) {
return classofRaw(arg) == 'Array';
};
// https://tc39.es/ecma262/#sec-toobject
var toObject = function toObject(argument) {
return Object(requireObjectCoercible(argument));
};
// https://tc39.es/ecma262/#sec-object.keys
var objectKeys = Object.keys || function keys(O) {
return objectKeysInternal(O, enumBugKeys);
};
// https://tc39.es/ecma262/#sec-object.defineproperties
var objectDefineProperties = descriptors ? Object.defineProperties : function defineProperties(O, Properties) {
anObject(O);
var keys = objectKeys(Properties);
var length = keys.length;
var index = 0;
var key;
while (length > index) {
objectDefineProperty.f(O, key = keys[index++], Properties[key]);
}
return O;
};
var html = getBuiltIn('document', 'documentElement');
var GT = '>';
var LT = '<';
var PROTOTYPE = 'prototype';
var SCRIPT = 'script';
var IE_PROTO = sharedKey('IE_PROTO');
var EmptyConstructor = function EmptyConstructor() {
/* empty */
};
var scriptTag = function scriptTag(content) {
return LT + SCRIPT + GT + content + LT + '/' + SCRIPT + GT;
}; // Create object with fake `null` prototype: use ActiveX Object with cleared prototype
var NullProtoObjectViaActiveX = function NullProtoObjectViaActiveX(activeXDocument) {
activeXDocument.write(scriptTag(''));
activeXDocument.close();
var temp = activeXDocument.parentWindow.Object;
activeXDocument = null; // avoid memory leak
return temp;
}; // Create object with fake `null` prototype: use iframe Object with cleared prototype
var NullProtoObjectViaIFrame = function NullProtoObjectViaIFrame() {
// Thrash, waste and sodomy: IE GC bug
var iframe = documentCreateElement('iframe');
var JS = 'java' + SCRIPT + ':';
var iframeDocument;
iframe.style.display = 'none';
html.appendChild(iframe); // https://github.com/zloirock/core-js/issues/475
iframe.src = String(JS);
iframeDocument = iframe.contentWindow.document;
iframeDocument.open();
iframeDocument.write(scriptTag('document.F=Object'));
iframeDocument.close();
return iframeDocument.F;
}; // Check for document.domain and active x support
// No need to use active x approach when document.domain is not set
// see https://github.com/es-shims/es5-shim/issues/150
// variation of https://github.com/kitcambridge/es5-shim/commit/4f738ac066346
// avoid IE GC bug
var activeXDocument;
var _NullProtoObject = function NullProtoObject() {
try {
/* global ActiveXObject -- old IE */
activeXDocument = document.domain && new ActiveXObject('htmlfile');
} catch (error) {
/* ignore */
}
_NullProtoObject = activeXDocument ? NullProtoObjectViaActiveX(activeXDocument) : NullProtoObjectViaIFrame();
var length = enumBugKeys.length;
while (length--) {
delete _NullProtoObject[PROTOTYPE][enumBugKeys[length]];
}
return _NullProtoObject();
};
hiddenKeys[IE_PROTO] = true; // `Object.create` method
// https://tc39.es/ecma262/#sec-object.create
var objectCreate = Object.create || function create(O, Properties) {
var result;
if (O !== null) {
EmptyConstructor[PROTOTYPE] = anObject(O);
result = new EmptyConstructor();
EmptyConstructor[PROTOTYPE] = null; // add "__proto__" for Object.getPrototypeOf polyfill
result[IE_PROTO] = O;
} else result = _NullProtoObject();
return Properties === undefined ? result : objectDefineProperties(result, Properties);
};
var nativeGetOwnPropertyNames = objectGetOwnPropertyNames.f;
var toString$1 = {}.toString;
var windowNames = typeof window == 'object' && window && Object.getOwnPropertyNames ? Object.getOwnPropertyNames(window) : [];
var getWindowNames = function getWindowNames(it) {
try {
return nativeGetOwnPropertyNames(it);
} catch (error) {
return windowNames.slice();
}
}; // fallback for IE11 buggy Object.getOwnPropertyNames with iframe and window
var f$5 = function getOwnPropertyNames(it) {
return windowNames && toString$1.call(it) == '[object Window]' ? getWindowNames(it) : nativeGetOwnPropertyNames(toIndexedObject(it));
};
var objectGetOwnPropertyNamesExternal = {
f: f$5
};
var WellKnownSymbolsStore = shared('wks');
var Symbol$1 = global_1.Symbol;
var createWellKnownSymbol = useSymbolAsUid ? Symbol$1 : Symbol$1 && Symbol$1.withoutSetter || uid;
var wellKnownSymbol = function wellKnownSymbol(name) {
if (!has(WellKnownSymbolsStore, name) || !(nativeSymbol || typeof WellKnownSymbolsStore[name] == 'string')) {
if (nativeSymbol && has(Symbol$1, name)) {
WellKnownSymbolsStore[name] = Symbol$1[name];
} else {
WellKnownSymbolsStore[name] = createWellKnownSymbol('Symbol.' + name);
}
}
return WellKnownSymbolsStore[name];
};
var f$6 = wellKnownSymbol;
var wellKnownSymbolWrapped = {
f: f$6
};
var defineProperty = objectDefineProperty.f;
var defineWellKnownSymbol = function defineWellKnownSymbol(NAME) {
var Symbol = path.Symbol || (path.Symbol = {});
if (!has(Symbol, NAME)) defineProperty(Symbol, NAME, {
value: wellKnownSymbolWrapped.f(NAME)
});
};
var defineProperty$1 = objectDefineProperty.f;
var TO_STRING_TAG = wellKnownSymbol('toStringTag');
var setToStringTag = function setToStringTag(it, TAG, STATIC) {
if (it && !has(it = STATIC ? it : it.prototype, TO_STRING_TAG)) {
defineProperty$1(it, TO_STRING_TAG, {
configurable: true,
value: TAG
});
}
};
var aFunction$1 = function aFunction(it) {
if (typeof it != 'function') {
throw TypeError(String(it) + ' is not a function');
}
return it;
};
var functionBindContext = function functionBindContext(fn, that, length) {
aFunction$1(fn);
if (that === undefined) return fn;
switch (length) {
case 0:
return function () {
return fn.call(that);
};
case 1:
return function (a) {
return fn.call(that, a);
};
case 2:
return function (a, b) {
return fn.call(that, a, b);
};
case 3:
return function (a, b, c) {
return fn.call(that, a, b, c);
};
}
return function (/* ...args */) {
return fn.apply(that, arguments);
};
};
var SPECIES = wellKnownSymbol('species'); // `ArraySpeciesCreate` abstract operation
// https://tc39.es/ecma262/#sec-arrayspeciescreate
var arraySpeciesCreate = function arraySpeciesCreate(originalArray, length) {
var C;
if (isArray(originalArray)) {
C = originalArray.constructor; // cross-realm fallback
if (typeof C == 'function' && (C === Array || isArray(C.prototype))) C = undefined; else if (isObject(C)) {
C = C[SPECIES];
if (C === null) C = undefined;
}
}
return new (C === undefined ? Array : C)(length === 0 ? 0 : length);
};
var push = [].push; // `Array.prototype.{ forEach, map, filter, some, every, find, findIndex, filterOut }` methods implementation
var createMethod$1 = function createMethod(TYPE) {
var IS_MAP = TYPE == 1;
var IS_FILTER = TYPE == 2;
var IS_SOME = TYPE == 3;
var IS_EVERY = TYPE == 4;
var IS_FIND_INDEX = TYPE == 6;
var IS_FILTER_OUT = TYPE == 7;
var NO_HOLES = TYPE == 5 || IS_FIND_INDEX;
return function ($this, callbackfn, that, specificCreate) {
var O = toObject($this);
var self = indexedObject(O);
var boundFunction = functionBindContext(callbackfn, that, 3);
var length = toLength(self.length);
var index = 0;
var create = specificCreate || arraySpeciesCreate;
var target = IS_MAP ? create($this, length) : IS_FILTER || IS_FILTER_OUT ? create($this, 0) : undefined;
var value, result;
for (; length > index; index++) {
if (NO_HOLES || index in self) {
value = self[index];
result = boundFunction(value, index, O);
if (TYPE) {
if (IS_MAP) target[index] = result; // map
else if (result) switch (TYPE) {
case 3:
return true;
// some
case 5:
return value;
// find
case 6:
return index;
// findIndex
case 2:
push.call(target, value);
// filter
} else switch (TYPE) {
case 4:
return false;
// every
case 7:
push.call(target, value);
// filterOut
}
}
}
}
return IS_FIND_INDEX ? -1 : IS_SOME || IS_EVERY ? IS_EVERY : target;
};
};
var arrayIteration = {
// `Array.prototype.forEach` method
// https://tc39.es/ecma262/#sec-array.prototype.foreach
forEach: createMethod$1(0),
// `Array.prototype.map` method
// https://tc39.es/ecma262/#sec-array.prototype.map
map: createMethod$1(1),
// `Array.prototype.filter` method
// https://tc39.es/ecma262/#sec-array.prototype.filter
filter: createMethod$1(2),
// `Array.prototype.some` method
// https://tc39.es/ecma262/#sec-array.prototype.some
some: createMethod$1(3),
// `Array.prototype.every` method
// https://tc39.es/ecma262/#sec-array.prototype.every
every: createMethod$1(4),
// `Array.prototype.find` method
// https://tc39.es/ecma262/#sec-array.prototype.find
find: createMethod$1(5),
// `Array.prototype.findIndex` method
// https://tc39.es/ecma262/#sec-array.prototype.findIndex
findIndex: createMethod$1(6),
// `Array.prototype.filterOut` method
// https://github.com/tc39/proposal-array-filtering
filterOut: createMethod$1(7)
};
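// Note (added for readability): the numeric TYPE passed to createMethod$1 selects the behaviour:
// 0 forEach, 1 map, 2 filter, 3 some, 4 every, 5 find, 6 findIndex, 7 filterOut, as listed above.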
var arrayIteration_1 = arrayIteration.forEach;
var arrayIteration_2 = arrayIteration.map;
var arrayIteration_3 = arrayIteration.filter;
var arrayIteration_4 = arrayIteration.some;
var arrayIteration_5 = arrayIteration.every;
var arrayIteration_6 = arrayIteration.find;
var arrayIteration_7 = arrayIteration.findIndex;
var arrayIteration_8 = arrayIteration.filterOut;
'use strict';
var $forEach = arrayIteration.forEach;
var HIDDEN = sharedKey('hidden');
var SYMBOL = 'Symbol';
var PROTOTYPE$1 = 'prototype';
var TO_PRIMITIVE = wellKnownSymbol('toPrimitive');
var setInternalState = internalState.set;
var getInternalState = internalState.getterFor(SYMBOL);
var ObjectPrototype = Object[PROTOTYPE$1];
var $Symbol = global_1.Symbol;
var $stringify = getBuiltIn('JSON', 'stringify');
var nativeGetOwnPropertyDescriptor$1 = objectGetOwnPropertyDescriptor.f;
var nativeDefineProperty$1 = objectDefineProperty.f;
var nativeGetOwnPropertyNames$1 = objectGetOwnPropertyNamesExternal.f;
var nativePropertyIsEnumerable$1 = objectPropertyIsEnumerable.f;
var AllSymbols = shared('symbols');
var ObjectPrototypeSymbols = shared('op-symbols');
var StringToSymbolRegistry = shared('string-to-symbol-registry');
var SymbolToStringRegistry = shared('symbol-to-string-registry');
var WellKnownSymbolsStore$1 = shared('wks');
var QObject = global_1.QObject; // Don't use setters in Qt Script, https://github.com/zloirock/core-js/issues/173
var USE_SETTER = !QObject || !QObject[PROTOTYPE$1] || !QObject[PROTOTYPE$1].findChild; // fallback for old Android, https://code.google.com/p/v8/issues/detail?id=687
var setSymbolDescriptor = descriptors && fails(function () {
return objectCreate(nativeDefineProperty$1({}, 'a', {
get: function get() {
return nativeDefineProperty$1(this, 'a', {
value: 7
}).a;
}
})).a != 7;
}) ? function (O, P, Attributes) {
var ObjectPrototypeDescriptor = nativeGetOwnPropertyDescriptor$1(ObjectPrototype, P);
if (ObjectPrototypeDescriptor) delete ObjectPrototype[P];
nativeDefineProperty$1(O, P, Attributes);
if (ObjectPrototypeDescriptor && O !== ObjectPrototype) {
nativeDefineProperty$1(ObjectPrototype, P, ObjectPrototypeDescriptor);
}
} : nativeDefineProperty$1;
var wrap = function wrap(tag, description) {
var symbol = AllSymbols[tag] = objectCreate($Symbol[PROTOTYPE$1]);
setInternalState(symbol, {
type: SYMBOL,
tag: tag,
description: description
});
if (!descriptors) symbol.description = description;
return symbol;
};
var isSymbol = useSymbolAsUid ? function (it) {
return typeof it == 'symbol';
} : function (it) {
return Object(it) instanceof $Symbol;
};
var $defineProperty = function defineProperty(O, P, Attributes) {
if (O === ObjectPrototype) $defineProperty(ObjectPrototypeSymbols, P, Attributes);
anObject(O);
var key = toPrimitive(P, true);
anObject(Attributes);
if (has(AllSymbols, key)) {
if (!Attributes.enumerable) {
if (!has(O, HIDDEN)) nativeDefineProperty$1(O, HIDDEN, createPropertyDescriptor(1, {}));
O[HIDDEN][key] = true;
} else {
if (has(O, HIDDEN) && O[HIDDEN][key]) O[HIDDEN][key] = false;
Attributes = objectCreate(Attributes, {
enumerable: createPropertyDescriptor(0, false)
});
}
return setSymbolDescriptor(O, key, Attributes);
}
return nativeDefineProperty$1(O, key, Attributes);
};
var $defineProperties = function defineProperties(O, Properties) {
anObject(O);
var properties = toIndexedObject(Properties);
var keys = objectKeys(properties).concat($getOwnPropertySymbols(properties));
$forEach(keys, function (key) {
if (!descriptors || $propertyIsEnumerable.call(properties, key)) $defineProperty(O, key, properties[key]);
});
return O;
};
var $create = function create(O, Properties) {
return Properties === undefined ? objectCreate(O) : $defineProperties(objectCreate(O), Properties);
};
var $propertyIsEnumerable = function propertyIsEnumerable(V) {
var P = toPrimitive(V, true);
var enumerable = nativePropertyIsEnumerable$1.call(this, P);
if (this === ObjectPrototype && has(AllSymbols, P) && !has(ObjectPrototypeSymbols, P)) return false;
return enumerable || !has(this, P) || !has(AllSymbols, P) || has(this, HIDDEN) && this[HIDDEN][P] ? enumerable : true;
};
var $getOwnPropertyDescriptor = function getOwnPropertyDescriptor(O, P) {
var it = toIndexedObject(O);
var key = toPrimitive(P, true);
if (it === ObjectPrototype && has(AllSymbols, key) && !has(ObjectPrototypeSymbols, key)) return;
var descriptor = nativeGetOwnPropertyDescriptor$1(it, key);
if (descriptor && has(AllSymbols, key) && !(has(it, HIDDEN) && it[HIDDEN][key])) {
descriptor.enumerable = true;
}
return descriptor;
};
var $getOwnPropertyNames = function getOwnPropertyNames(O) {
var names = nativeGetOwnPropertyNames$1(toIndexedObject(O));
var result = [];
$forEach(names, function (key) {
if (!has(AllSymbols, key) && !has(hiddenKeys, key)) result.push(key);
});
return result;
};
var $getOwnPropertySymbols = function getOwnPropertySymbols(O) {
var IS_OBJECT_PROTOTYPE = O === ObjectPrototype;
var names = nativeGetOwnPropertyNames$1(IS_OBJECT_PROTOTYPE ? ObjectPrototypeSymbols : toIndexedObject(O));
var result = [];
$forEach(names, function (key) {
if (has(AllSymbols, key) && (!IS_OBJECT_PROTOTYPE || has(ObjectPrototype, key))) {
result.push(AllSymbols[key]);
}
});
return result;
}; // `Symbol` constructor
// https://tc39.es/ecma262/#sec-symbol-constructor
if (!nativeSymbol) {
$Symbol = function Symbol() {
if (this instanceof $Symbol) throw TypeError('Symbol is not a constructor');
var description = !arguments.length || arguments[0] === undefined ? undefined : String(arguments[0]);
var tag = uid(description);
var setter = function setter(value) {
if (this === ObjectPrototype) setter.call(ObjectPrototypeSymbols, value);
if (has(this, HIDDEN) && has(this[HIDDEN], tag)) this[HIDDEN][tag] = false;
setSymbolDescriptor(this, tag, createPropertyDescriptor(1, value));
};
if (descriptors && USE_SETTER) setSymbolDescriptor(ObjectPrototype, tag, {
configurable: true,
set: setter
});
return wrap(tag, description);
};
redefine($Symbol[PROTOTYPE$1], 'toString', function toString() {
return getInternalState(this).tag;
});
redefine($Symbol, 'withoutSetter', function (description) {
return wrap(uid(description), description);
});
objectPropertyIsEnumerable.f = $propertyIsEnumerable;
objectDefineProperty.f = $defineProperty;
objectGetOwnPropertyDescriptor.f = $getOwnPropertyDescriptor;
objectGetOwnPropertyNames.f = objectGetOwnPropertyNamesExternal.f = $getOwnPropertyNames;
objectGetOwnPropertySymbols.f = $getOwnPropertySymbols;
wellKnownSymbolWrapped.f = function (name) {
return wrap(wellKnownSymbol(name), name);
};
if (descriptors) {
// https://github.com/tc39/proposal-Symbol-description
nativeDefineProperty$1($Symbol[PROTOTYPE$1], 'description', {
configurable: true,
get: function description() {
return getInternalState(this).description;
}
});
if (!isPure) {
redefine(ObjectPrototype, 'propertyIsEnumerable', $propertyIsEnumerable, {
unsafe: true
});
}
}
}
_export({
global: true,
wrap: true,
forced: !nativeSymbol,
sham: !nativeSymbol
}, {
Symbol: $Symbol
});
$forEach(objectKeys(WellKnownSymbolsStore$1), function (name) {
defineWellKnownSymbol(name);
});
_export({
target: SYMBOL,
stat: true,
forced: !nativeSymbol
}, {
// `Symbol.for` method
// https://tc39.es/ecma262/#sec-symbol.for
'for': function _for(key) {
var string = String(key);
if (has(StringToSymbolRegistry, string)) return StringToSymbolRegistry[string];
var symbol = $Symbol(string);
StringToSymbolRegistry[string] = symbol;
SymbolToStringRegistry[symbol] = string;
return symbol;
},
// `Symbol.keyFor` method
// https://tc39.es/ecma262/#sec-symbol.keyfor
keyFor: function keyFor(sym) {
if (!isSymbol(sym)) throw TypeError(sym + ' is not a symbol');
if (has(SymbolToStringRegistry, sym)) return SymbolToStringRegistry[sym];
},
useSetter: function useSetter() {
USE_SETTER = true;
},
useSimple: function useSimple() {
USE_SETTER = false;
}
});
_export({
target: 'Object',
stat: true,
forced: !nativeSymbol,
sham: !descriptors
}, {
// `Object.create` method
// https://tc39.es/ecma262/#sec-object.create
create: $create,
// `Object.defineProperty` method
// https://tc39.es/ecma262/#sec-object.defineproperty
defineProperty: $defineProperty,
// `Object.defineProperties` method
// https://tc39.es/ecma262/#sec-object.defineproperties
defineProperties: $defineProperties,
// `Object.getOwnPropertyDescriptor` method
// https://tc39.es/ecma262/#sec-object.getownpropertydescriptors
getOwnPropertyDescriptor: $getOwnPropertyDescriptor
});
_export({
target: 'Object',
stat: true,
forced: !nativeSymbol
}, {
// `Object.getOwnPropertyNames` method
// https://tc39.es/ecma262/#sec-object.getownpropertynames
getOwnPropertyNames: $getOwnPropertyNames,
// `Object.getOwnPropertySymbols` method
// https://tc39.es/ecma262/#sec-object.getownpropertysymbols
getOwnPropertySymbols: $getOwnPropertySymbols
}); // Chrome 38 and 39 `Object.getOwnPropertySymbols` fails on primitives
// https://bugs.chromium.org/p/v8/issues/detail?id=3443
_export({
target: 'Object',
stat: true,
forced: fails(function () {
objectGetOwnPropertySymbols.f(1);
})
}, {
getOwnPropertySymbols: function getOwnPropertySymbols(it) {
return objectGetOwnPropertySymbols.f(toObject(it));
}
}); // `JSON.stringify` method behavior with symbols
// https://tc39.es/ecma262/#sec-json.stringify
if ($stringify) {
var FORCED_JSON_STRINGIFY = !nativeSymbol || fails(function () {
var symbol = $Symbol(); // MS Edge converts symbol values to JSON as {}
return $stringify([symbol]) != '[null]' // WebKit converts symbol values to JSON as null
|| $stringify({
a: symbol
}) != '{}' // V8 throws on boxed symbols
|| $stringify(Object(symbol)) != '{}';
});
_export({
target: 'JSON',
stat: true,
forced: FORCED_JSON_STRINGIFY
}, {
// eslint-disable-next-line no-unused-vars -- required for `.length`
stringify: function stringify(it, replacer, space) {
var args = [it];
var index = 1;
var $replacer;
while (arguments.length > index) {
args.push(arguments[index++]);
}
$replacer = replacer;
if (!isObject(replacer) && it === undefined || isSymbol(it)) return; // IE8 returns string on undefined
if (!isArray(replacer)) replacer = function replacer(key, value) {
if (typeof $replacer == 'function') value = $replacer.call(this, key, value);
if (!isSymbol(value)) return value;
};
args[1] = replacer;
return $stringify.apply(null, args);
}
});
} // `Symbol.prototype[@@toPrimitive]` method
// https://tc39.es/ecma262/#sec-symbol.prototype-@@toprimitive
if (!$Symbol[PROTOTYPE$1][TO_PRIMITIVE]) {
createNonEnumerableProperty($Symbol[PROTOTYPE$1], TO_PRIMITIVE, $Symbol[PROTOTYPE$1].valueOf);
} // `Symbol.prototype[@@toStringTag]` property
// https://tc39.es/ecma262/#sec-symbol.prototype-@@tostringtag
setToStringTag($Symbol, SYMBOL);
hiddenKeys[HIDDEN] = true;
var es_symbol = {};
// https://tc39.es/ecma262/#sec-symbol.asynciterator
defineWellKnownSymbol('asyncIterator');
var es_symbol_asyncIterator = {};
// https://tc39.es/ecma262/#sec-symbol.prototype.description
'use strict';
var defineProperty$2 = objectDefineProperty.f;
var NativeSymbol = global_1.Symbol;
if (descriptors && typeof NativeSymbol == 'function' && (!('description' in NativeSymbol.prototype) || // Safari 12 bug
NativeSymbol().description !== undefined)) {
var EmptyStringDescriptionStore = {}; // wrap Symbol constructor for correct work with undefined description
var SymbolWrapper = function Symbol() {
var description = arguments.length < 1 || arguments[0] === undefined ? undefined : String(arguments[0]);
var result = this instanceof SymbolWrapper ? new NativeSymbol(description) // in Edge 13, String(Symbol(undefined)) === 'Symbol(undefined)'
: description === undefined ? NativeSymbol() : NativeSymbol(description);
if (description === '') EmptyStringDescriptionStore[result] = true;
return result;
};
copyConstructorProperties(SymbolWrapper, NativeSymbol);
var symbolPrototype = SymbolWrapper.prototype = NativeSymbol.prototype;
symbolPrototype.constructor = SymbolWrapper;
var symbolToString = symbolPrototype.toString;
var native = String(NativeSymbol('test')) == 'Symbol(test)';
var regexp = /^Symbol\((.*)\)[^)]+$/;
defineProperty$2(symbolPrototype, 'description', {
configurable: true,
get: function description() {
var symbol = isObject(this) ? this.valueOf() : this;
var string = symbolToString.call(symbol);
if (has(EmptyStringDescriptionStore, symbol)) return '';
var desc = native ? string.slice(7, -1) : string.replace(regexp, '$1');
return desc === '' ? undefined : desc;
}
});
_export({
global: true,
forced: true
}, {
Symbol: SymbolWrapper
});
}
var es_symbol_description = {};
// https://tc39.es/ecma262/#sec-symbol.hasinstance
defineWellKnownSymbol('hasInstance');
var es_symbol_hasInstance = {};
// https://tc39.es/ecma262/#sec-symbol.isconcatspreadable
defineWellKnownSymbol('isConcatSpreadable');
var es_symbol_isConcatSpreadable = {};
// https://tc39.es/ecma262/#sec-symbol.iterator
defineWellKnownSymbol('iterator');
var es_symbol_iterator = {};
// https://tc39.es/ecma262/#sec-symbol.match
defineWellKnownSymbol('match');
var es_symbol_match = {};
// https://tc39.es/ecma262/#sec-symbol.matchall
defineWellKnownSymbol('matchAll');
var es_symbol_matchAll = {};
// https://tc39.es/ecma262/#sec-symbol.replace
defineWellKnownSymbol('replace');
var es_symbol_replace = {};
// https://tc39.es/ecma262/#sec-symbol.search
defineWellKnownSymbol('search');
var es_symbol_search = {};
// https://tc39.es/ecma262/#sec-symbol.species
defineWellKnownSymbol('species');
var es_symbol_species = {};
// https://tc39.es/ecma262/#sec-symbol.split
defineWellKnownSymbol('split');
var es_symbol_split = {};
// https://tc39.es/ecma262/#sec-symbol.toprimitive
defineWellKnownSymbol('toPrimitive');
var es_symbol_toPrimitive = {};
// https://tc39.es/ecma262/#sec-symbol.tostringtag
defineWellKnownSymbol('toStringTag');
var es_symbol_toStringTag = {};
// https://tc39.es/ecma262/#sec-symbol.unscopables
defineWellKnownSymbol('unscopables');
var es_symbol_unscopables = {};
var correctPrototypeGetter = !fails(function () {
function F() {
/* empty */
}
F.prototype.constructor = null;
return Object.getPrototypeOf(new F()) !== F.prototype;
});
var IE_PROTO$1 = sharedKey('IE_PROTO');
var ObjectPrototype$1 = Object.prototype; // `Object.getPrototypeOf` method
// https://tc39.es/ecma262/#sec-object.getprototypeof
var objectGetPrototypeOf = correctPrototypeGetter ? Object.getPrototypeOf : function (O) {
O = toObject(O);
if (has(O, IE_PROTO$1)) return O[IE_PROTO$1];
if (typeof O.constructor == 'function' && O instanceof O.constructor) {
return O.constructor.prototype;
}
return O instanceof Object ? ObjectPrototype$1 : null;
};
var aPossiblePrototype = function aPossiblePrototype(it) {
if (!isObject(it) && it !== null) {
throw TypeError("Can't set " + String(it) + ' as a prototype');
}
return it;
};
/* eslint-disable no-proto -- safe */
// `Object.setPrototypeOf` method
// https://tc39.es/ecma262/#sec-object.setprototypeof
// Works with __proto__ only. Old v8 can't work with null proto objects.
var objectSetPrototypeOf = Object.setPrototypeOf || ('__proto__' in {} ? function () {
var CORRECT_SETTER = false;
var test = {};
var setter;
try {
setter = Object.getOwnPropertyDescriptor(Object.prototype, '__proto__').set;
setter.call(test, []);
CORRECT_SETTER = test instanceof Array;
} catch (error) {
/* empty */
}
return function setPrototypeOf(O, proto) {
anObject(O);
aPossiblePrototype(proto);
if (CORRECT_SETTER) setter.call(O, proto); else O.__proto__ = proto;
return O;
};
}() : undefined);
var iterators = {};
var ITERATOR = wellKnownSymbol('iterator');
var ArrayPrototype = Array.prototype; // check on default Array iterator
var isArrayIteratorMethod = function isArrayIteratorMethod(it) {
return it !== undefined && (iterators.Array === it || ArrayPrototype[ITERATOR] === it);
};
var TO_STRING_TAG$1 = wellKnownSymbol('toStringTag');
var test = {};
test[TO_STRING_TAG$1] = 'z';
var toStringTagSupport = String(test) === '[object z]';
var TO_STRING_TAG$2 = wellKnownSymbol('toStringTag'); // ES3 wrong here
var CORRECT_ARGUMENTS = classofRaw(function () {
return arguments;
}()) == 'Arguments'; // fallback for IE11 Script Access Denied error
var tryGet = function tryGet(it, key) {
try {
return it[key];
} catch (error) {
/* empty */
}
}; // getting tag from ES6+ `Object.prototype.toString`
var classof = toStringTagSupport ? classofRaw : function (it) {
var O, tag, result;
return it === undefined ? 'Undefined' : it === null ? 'Null' // @@toStringTag case
: typeof (tag = tryGet(O = Object(it), TO_STRING_TAG$2)) == 'string' ? tag // builtinTag case
: CORRECT_ARGUMENTS ? classofRaw(O) // ES3 arguments fallback
: (result = classofRaw(O)) == 'Object' && typeof O.callee == 'function' ? 'Arguments' : result;
};
var ITERATOR$1 = wellKnownSymbol('iterator');
var getIteratorMethod = function getIteratorMethod(it) {
if (it != undefined) return it[ITERATOR$1] || it['@@iterator'] || iterators[classof(it)];
};
var iteratorClose = function iteratorClose(iterator) {
var returnMethod = iterator['return'];
if (returnMethod !== undefined) {
return anObject(returnMethod.call(iterator)).value;
}
};
var Result = function Result(stopped, result) {
this.stopped = stopped;
this.result = result;
};
var iterate = function iterate(iterable, unboundFunction, options) {
var that = options && options.that;
var AS_ENTRIES = !!(options && options.AS_ENTRIES);
var IS_ITERATOR = !!(options && options.IS_ITERATOR);
var INTERRUPTED = !!(options && options.INTERRUPTED);
var fn = functionBindContext(unboundFunction, that, 1 + AS_ENTRIES + INTERRUPTED);
var iterator, iterFn, index, length, result, next, step;
var stop = function stop(condition) {
if (iterator) iteratorClose(iterator);
return new Result(true, condition);
};
var callFn = function callFn(value) {
if (AS_ENTRIES) {
anObject(value);
return INTERRUPTED ? fn(value[0], value[1], stop) : fn(value[0], value[1]);
}
return INTERRUPTED ? fn(value, stop) : fn(value);
};
if (IS_ITERATOR) {
iterator = iterable;
} else {
iterFn = getIteratorMethod(iterable);
if (typeof iterFn != 'function') throw TypeError('Target is not iterable'); // optimisation for array iterators
if (isArrayIteratorMethod(iterFn)) {
for (index = 0, length = toLength(iterable.length); length > index; index++) {
result = callFn(iterable[index]);
if (result && result instanceof Result) return result;
}
return new Result(false);
}
iterator = iterFn.call(iterable);
}
next = iterator.next;
while (!(step = next.call(iterator)).done) {
try {
result = callFn(step.value);
} catch (error) {
iteratorClose(iterator);
throw error;
}
if (typeof result == 'object' && result && result instanceof Result) return result;
}
return new Result(false);
};
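// Note (added for readability): `iterate` walks an iterable (or a raw iterator when
// options.IS_ITERATOR is set); options.AS_ENTRIES unpacks [key, value] pairs, and
// options.INTERRUPTED passes a `stop` callback that closes the iterator and short-circuits.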
'use strict';
var $AggregateError = function AggregateError(errors, message) {
var that = this;
if (!(that instanceof $AggregateError)) return new $AggregateError(errors, message);
if (objectSetPrototypeOf) {
// eslint-disable-next-line unicorn/error-message -- expected
that = objectSetPrototypeOf(new Error(undefined), objectGetPrototypeOf(that));
}
if (message !== undefined) createNonEnumerableProperty(that, 'message', String(message));
var errorsArray = [];
iterate(errors, errorsArray.push, {
that: errorsArray
});
createNonEnumerableProperty(that, 'errors', errorsArray);
return that;
};
$AggregateError.prototype = objectCreate(Error.prototype, {
constructor: createPropertyDescriptor(5, $AggregateError),
message: createPropertyDescriptor(5, ''),
name: createPropertyDescriptor(5, 'AggregateError')
}); // `AggregateError` constructor
// https://tc39.es/ecma262/#sec-aggregate-error-constructor
_export({
global: true
}, {
AggregateError: $AggregateError
});
var es_aggregateError = {};
var callWithSafeIterationClosing = function callWithSafeIterationClosing(iterator, fn, value, ENTRIES) {
try {
return ENTRIES ? fn(anObject(value)[0], value[1]) : fn(value); // 7.4.6 IteratorClose(iterator, completion)
} catch (error) {
iteratorClose(iterator);
throw error;
}
};
'use strict';
var createProperty = function createProperty(object, key, value) {
var propertyKey = toPrimitive(key);
if (propertyKey in object) objectDefineProperty.f(object, propertyKey, createPropertyDescriptor(0, value));else object[propertyKey] = value;
};
'use strict'; // `Array.from` method implementation
// https://tc39.es/ecma262/#sec-array.from
var arrayFrom = function from(arrayLike /* , mapfn = undefined, thisArg = undefined */) {
var O = toObject(arrayLike);
var C = typeof this == 'function' ? this : Array;
var argumentsLength = arguments.length;
var mapfn = argumentsLength > 1 ? arguments[1] : undefined;
var mapping = mapfn !== undefined;
var iteratorMethod = getIteratorMethod(O);
var index = 0;
var length, result, step, iterator, next, value;
if (mapping) mapfn = functionBindContext(mapfn, argumentsLength > 2 ? arguments[2] : undefined, 2); // if the target is not iterable or it's an array with the default iterator - use a simple case
if (iteratorMethod != undefined && !(C == Array && isArrayIteratorMethod(iteratorMethod))) {
iterator = iteratorMethod.call(O);
next = iterator.next;
result = new C();
for (; !(step = next.call(iterator)).done; index++) {
value = mapping ? callWithSafeIterationClosing(iterator, mapfn, [step.value, index], true) : step.value;
createProperty(result, index, value);
}
} else {
length = toLength(O.length);
result = new C(length);
for (; length > index; index++) {
value = mapping ? mapfn(O[index], index) : O[index];
createProperty(result, index, value);
}
}
result.length = index;
return result;
};
var ITERATOR$2 = wellKnownSymbol('iterator');
var SAFE_CLOSING = false;
try {
var called = 0;
var iteratorWithReturn = {
next: function next() {
return {
done: !!called++
};
},
'return': function _return() {
SAFE_CLOSING = true;
}
};
iteratorWithReturn[ITERATOR$2] = function () {
return this;
}; // eslint-disable-next-line no-throw-literal -- required for testing
Array.from(iteratorWithReturn, function () {
throw 2;
});
} catch (error) {
/* empty */
}
var checkCorrectnessOfIteration = function checkCorrectnessOfIteration(exec, SKIP_CLOSING) {
if (!SKIP_CLOSING && !SAFE_CLOSING) return false;
var ITERATION_SUPPORT = false;
try {
var object = {};
object[ITERATOR$2] = function () {
return {
next: function next() {
return {
done: ITERATION_SUPPORT = true
};
}
};
};
exec(object);
} catch (error) {
/* empty */
}
return ITERATION_SUPPORT;
};
var INCORRECT_ITERATION = !checkCorrectnessOfIteration(function (iterable) {
Array.from(iterable);
}); // `Array.from` method
// https://tc39.es/ecma262/#sec-array.from
_export({
target: 'Array',
stat: true,
forced: INCORRECT_ITERATION
}, {
from: arrayFrom
});
var es_array_from = {};
// https://tc39.es/ecma262/#sec-array.isarray
_export({
target: 'Array',
stat: true
}, {
isArray: isArray
});
var es_array_isArray = {};
'use strict';
var ISNT_GENERIC = fails(function () {
function F() {
/* empty */
}
return !(Array.of.call(F) instanceof F);
}); // `Array.of` method
// https://tc39.es/ecma262/#sec-array.of
// WebKit Array.of isn't generic
_export({
target: 'Array',
stat: true,
forced: ISNT_GENERIC
}, {
of: function of(/* ...args */) {
var index = 0;
var argumentsLength = arguments.length;
var result = new (typeof this == 'function' ? this : Array)(argumentsLength);
while (argumentsLength > index) {
createProperty(result, index, arguments[index++]);
}
result.length = argumentsLength;
return result;
}
});
var es_array_of = {};
var SPECIES$1 = wellKnownSymbol('species');
var arrayMethodHasSpeciesSupport = function arrayMethodHasSpeciesSupport(METHOD_NAME) {
// We can't use this feature detection in V8 since it causes
// deoptimization and serious performance degradation
// https://github.com/zloirock/core-js/issues/677
return engineV8Version >= 51 || !fails(function () {
var array = [];
var constructor = array.constructor = {};
constructor[SPECIES$1] = function () {
return {
foo: 1
};
};
return array[METHOD_NAME](Boolean).foo !== 1;
});
};
'use strict';
var IS_CONCAT_SPREADABLE = wellKnownSymbol('isConcatSpreadable');
var MAX_SAFE_INTEGER = 0x1FFFFFFFFFFFFF;
var MAXIMUM_ALLOWED_INDEX_EXCEEDED = 'Maximum allowed index exceeded'; // We can't use this feature detection in V8 since it causes
// deoptimization and serious performance degradation
// https://github.com/zloirock/core-js/issues/679
var IS_CONCAT_SPREADABLE_SUPPORT = engineV8Version >= 51 || !fails(function () {
var array = [];
array[IS_CONCAT_SPREADABLE] = false;
return array.concat()[0] !== array;
});
var SPECIES_SUPPORT = arrayMethodHasSpeciesSupport('concat');
var isConcatSpreadable = function isConcatSpreadable(O) {
if (!isObject(O)) return false;
var spreadable = O[IS_CONCAT_SPREADABLE];
return spreadable !== undefined ? !!spreadable : isArray(O);
};
var FORCED = !IS_CONCAT_SPREADABLE_SUPPORT || !SPECIES_SUPPORT; // `Array.prototype.concat` method
// https://tc39.es/ecma262/#sec-array.prototype.concat
// with adding support of @@isConcatSpreadable and @@species
_export({
target: 'Array',
proto: true,
forced: FORCED
}, {
// eslint-disable-next-line no-unused-vars -- required for `.length`
concat: function concat(arg) {
var O = toObject(this);
var A = arraySpeciesCreate(O, 0);
var n = 0;
var i, k, length, len, E;
for (i = -1, length = arguments.length; i < length; i++) {
E = i === -1 ? O : arguments[i];
if (isConcatSpreadable(E)) {
len = toLength(E.length);
if (n + len > MAX_SAFE_INTEGER) throw TypeError(MAXIMUM_ALLOWED_INDEX_EXCEEDED);
for (k = 0; k < len; k++, n++) {
if (k in E) createProperty(A, n, E[k]);
}
} else {
if (n >= MAX_SAFE_INTEGER) throw TypeError(MAXIMUM_ALLOWED_INDEX_EXCEEDED);
createProperty(A, n++, E);
}
}
A.length = n;
return A;
}
});
var es_array_concat = {};
'use strict';
var min$2 = Math.min; // `Array.prototype.copyWithin` method implementation
// https://tc39.es/ecma262/#sec-array.prototype.copywithin
var arrayCopyWithin = [].copyWithin || function copyWithin(target /* = 0 */, start /* = 0, end = @length */) {
var O = toObject(this);
var len = toLength(O.length);
var to = toAbsoluteIndex(target, len);
var from = toAbsoluteIndex(start, len);
var end = arguments.length > 2 ? arguments[2] : undefined;
var count = min$2((end === undefined ? len : toAbsoluteIndex(end, len)) - from, len - to);
var inc = 1;
if (from < to && to < from + count) {
inc = -1;
from += count - 1;
to += count - 1;
}
while (count-- > 0) {
if (from in O) O[to] = O[from]; else delete O[to];
to += inc;
from += inc;
}
return O;
};
var UNSCOPABLES = wellKnownSymbol('unscopables');
var ArrayPrototype$1 = Array.prototype; // Array.prototype[@@unscopables]
// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
if (ArrayPrototype$1[UNSCOPABLES] == undefined) {
objectDefineProperty.f(ArrayPrototype$1, UNSCOPABLES, {
configurable: true,
value: objectCreate(null)
});
} // add a key to Array.prototype[@@unscopables]
var addToUnscopables = function addToUnscopables(key) {
ArrayPrototype$1[UNSCOPABLES][key] = true;
};
// https://tc39.es/ecma262/#sec-array.prototype.copywithin
_export({
target: 'Array',
proto: true
}, {
copyWithin: arrayCopyWithin
}); // https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables('copyWithin');
var es_array_copyWithin = {};
'use strict';
var arrayMethodIsStrict = function arrayMethodIsStrict(METHOD_NAME, argument) {
var method = [][METHOD_NAME];
return !!method && fails(function () {
// eslint-disable-next-line no-useless-call,no-throw-literal -- required for testing
method.call(null, argument || function () {
throw 1;
}, 1);
});
};
'use strict';
var $every = arrayIteration.every;
var STRICT_METHOD = arrayMethodIsStrict('every'); // `Array.prototype.every` method
// https://tc39.es/ecma262/#sec-array.prototype.every
_export({
target: 'Array',
proto: true,
forced: !STRICT_METHOD
}, {
every: function every(callbackfn /* , thisArg */) {
return $every(this, callbackfn, arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_array_every = {};
'use strict'; // `Array.prototype.fill` method implementation
// https://tc39.es/ecma262/#sec-array.prototype.fill
var arrayFill = function fill(value /* , start = 0, end = @length */) {
var O = toObject(this);
var length = toLength(O.length);
var argumentsLength = arguments.length;
var index = toAbsoluteIndex(argumentsLength > 1 ? arguments[1] : undefined, length);
var end = argumentsLength > 2 ? arguments[2] : undefined;
var endPos = end === undefined ? length : toAbsoluteIndex(end, length);
while (endPos > index) {
O[index++] = value;
}
return O;
};
// https://tc39.es/ecma262/#sec-array.prototype.fill
_export({
target: 'Array',
proto: true
}, {
fill: arrayFill
}); // https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables('fill');
var es_array_fill = {};
'use strict';
var $filter = arrayIteration.filter;
var HAS_SPECIES_SUPPORT = arrayMethodHasSpeciesSupport('filter'); // `Array.prototype.filter` method
// https://tc39.es/ecma262/#sec-array.prototype.filter
// with adding support of @@species
_export({
target: 'Array',
proto: true,
forced: !HAS_SPECIES_SUPPORT
}, {
filter: function filter(callbackfn /* , thisArg */) {
return $filter(this, callbackfn, arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_array_filter = {};
'use strict';
var $find = arrayIteration.find;
var FIND = 'find';
var SKIPS_HOLES = true; // Shouldn't skip holes
if (FIND in []) Array(1)[FIND](function () {
SKIPS_HOLES = false;
}); // `Array.prototype.find` method
// https://tc39.es/ecma262/#sec-array.prototype.find
_export({
target: 'Array',
proto: true,
forced: SKIPS_HOLES
}, {
find: function find(callbackfn /* , that = undefined */) {
return $find(this, callbackfn, arguments.length > 1 ? arguments[1] : undefined);
}
}); // https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables(FIND);
var es_array_find = {};
'use strict';
var $findIndex = arrayIteration.findIndex;
var FIND_INDEX = 'findIndex';
var SKIPS_HOLES$1 = true; // Shouldn't skip holes
if (FIND_INDEX in []) Array(1)[FIND_INDEX](function () {
SKIPS_HOLES$1 = false;
}); // `Array.prototype.findIndex` method
// https://tc39.es/ecma262/#sec-array.prototype.findindex
_export({
target: 'Array',
proto: true,
forced: SKIPS_HOLES$1
}, {
findIndex: function findIndex(callbackfn
/* , that = undefined */
) {
return $findIndex(this, callbackfn, arguments.length > 1 ? arguments[1] : undefined);
}
}); // https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables(FIND_INDEX);
var es_array_findIndex = {};
'use strict'; // `FlattenIntoArray` abstract operation
// https://tc39.github.io/proposal-flatMap/#sec-FlattenIntoArray
var flattenIntoArray = function flattenIntoArray(target, original, source, sourceLen, start, depth, mapper, thisArg) {
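// walks `source`, writing elements into `target` starting at `start`; recurses into nested arrays while `depth` > 0 (the optional `mapper` used by flatMap is applied only at the top level) and returns the next free target index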
var targetIndex = start;
var sourceIndex = 0;
var mapFn = mapper ? functionBindContext(mapper, thisArg, 3) : false;
var element;
while (sourceIndex < sourceLen) {
if (sourceIndex in source) {
element = mapFn ? mapFn(source[sourceIndex], sourceIndex, original) : source[sourceIndex];
if (depth > 0 && isArray(element)) {
targetIndex = flattenIntoArray(target, original, element, toLength(element.length), targetIndex, depth - 1) - 1;
} else {
if (targetIndex >= 0x1FFFFFFFFFFFFF) throw TypeError('Exceed the acceptable array length');
target[targetIndex] = element;
}
targetIndex++;
}
sourceIndex++;
}
return targetIndex;
};
var flattenIntoArray_1 = flattenIntoArray;
'use strict'; // `Array.prototype.flat` method
// https://tc39.es/ecma262/#sec-array.prototype.flat
_export({
target: 'Array',
proto: true
}, {
flat: function flat()
/* depthArg = 1 */
{
var depthArg = arguments.length ? arguments[0] : undefined;
var O = toObject(this);
var sourceLen = toLength(O.length);
var A = arraySpeciesCreate(O, 0);
A.length = flattenIntoArray_1(A, O, O, sourceLen, 0, depthArg === undefined ? 1 : toInteger(depthArg));
return A;
}
});
var es_array_flat = {};
'use strict'; // `Array.prototype.flatMap` method
// https://tc39.es/ecma262/#sec-array.prototype.flatmap
_export({
target: 'Array',
proto: true
}, {
flatMap: function flatMap(callbackfn
/* , thisArg */
) {
var O = toObject(this);
var sourceLen = toLength(O.length);
var A;
aFunction$1(callbackfn);
A = arraySpeciesCreate(O, 0);
A.length = flattenIntoArray_1(A, O, O, sourceLen, 0, 1, callbackfn, arguments.length > 1 ? arguments[1] : undefined);
return A;
}
});
var es_array_flatMap = {};
'use strict';
var $forEach$1 = arrayIteration.forEach;
var STRICT_METHOD$1 = arrayMethodIsStrict('forEach'); // `Array.prototype.forEach` method implementation
// https://tc39.es/ecma262/#sec-array.prototype.foreach
var arrayForEach = !STRICT_METHOD$1 ? function forEach(callbackfn
/* , thisArg */
) {
return $forEach$1(this, callbackfn, arguments.length > 1 ? arguments[1] : undefined);
} : [].forEach;
'use strict'; // `Array.prototype.forEach` method
// https://tc39.es/ecma262/#sec-array.prototype.foreach
_export({
target: 'Array',
proto: true,
forced: [].forEach != arrayForEach
}, {
forEach: arrayForEach
});
var es_array_forEach = {};
'use strict';
var $includes = arrayIncludes.includes; // `Array.prototype.includes` method
// https://tc39.es/ecma262/#sec-array.prototype.includes
_export({
target: 'Array',
proto: true
}, {
includes: function includes(el
/* , fromIndex = 0 */
) {
return $includes(this, el, arguments.length > 1 ? arguments[1] : undefined);
}
}); // https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables('includes');
var es_array_includes = {};
'use strict';
var $indexOf = arrayIncludes.indexOf;
var nativeIndexOf = [].indexOf;
var NEGATIVE_ZERO = !!nativeIndexOf && 1 / [1].indexOf(1, -0) < 0;
var STRICT_METHOD$2 = arrayMethodIsStrict('indexOf'); // `Array.prototype.indexOf` method
// https://tc39.es/ecma262/#sec-array.prototype.indexof
_export({
target: 'Array',
proto: true,
forced: NEGATIVE_ZERO || !STRICT_METHOD$2
}, {
indexOf: function indexOf(searchElement
/* , fromIndex = 0 */
) {
return NEGATIVE_ZERO // convert -0 to +0
? nativeIndexOf.apply(this, arguments) || 0 : $indexOf(this, searchElement, arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_array_indexOf = {};
'use strict';
var nativeJoin = [].join;
var ES3_STRINGS = indexedObject != Object;
var STRICT_METHOD$3 = arrayMethodIsStrict('join', ','); // `Array.prototype.join` method
// https://tc39.es/ecma262/#sec-array.prototype.join
_export({
target: 'Array',
proto: true,
forced: ES3_STRINGS || !STRICT_METHOD$3
}, {
join: function join(separator) {
return nativeJoin.call(toIndexedObject(this), separator === undefined ? ',' : separator);
}
});
var es_array_join = {};
'use strict';
var min$3 = Math.min;
var nativeLastIndexOf = [].lastIndexOf;
var NEGATIVE_ZERO$1 = !!nativeLastIndexOf && 1 / [1].lastIndexOf(1, -0) < 0;
var STRICT_METHOD$4 = arrayMethodIsStrict('lastIndexOf');
var FORCED$1 = NEGATIVE_ZERO$1 || !STRICT_METHOD$4; // `Array.prototype.lastIndexOf` method implementation
// https://tc39.es/ecma262/#sec-array.prototype.lastindexof
var arrayLastIndexOf = FORCED$1 ? function lastIndexOf(searchElement
/* , fromIndex = @[*-1] */
) {
// convert -0 to +0
if (NEGATIVE_ZERO$1) return nativeLastIndexOf.apply(this, arguments) || 0;
var O = toIndexedObject(this);
var length = toLength(O.length);
var index = length - 1;
if (arguments.length > 1) index = min$3(index, toInteger(arguments[1]));
if (index < 0) index = length + index;
for (; index >= 0; index--) {
if (index in O && O[index] === searchElement) return index || 0;
}
return -1;
} : nativeLastIndexOf;
// https://tc39.es/ecma262/#sec-array.prototype.lastindexof
_export({
target: 'Array',
proto: true,
forced: arrayLastIndexOf !== [].lastIndexOf
}, {
lastIndexOf: arrayLastIndexOf
});
var es_array_lastIndexOf = {};
'use strict';
var $map = arrayIteration.map;
var HAS_SPECIES_SUPPORT$1 = arrayMethodHasSpeciesSupport('map'); // `Array.prototype.map` method
// https://tc39.es/ecma262/#sec-array.prototype.map
// with added support of @@species
_export({
target: 'Array',
proto: true,
forced: !HAS_SPECIES_SUPPORT$1
}, {
map: function map(callbackfn
/* , thisArg */
) {
return $map(this, callbackfn, arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_array_map = {};
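// shared implementation for `Array.prototype.reduce` / `reduceRight`: IS_RIGHT flips the iteration direction; with no initial value, the first existing element seeds the accumulator, and an empty array throws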
var createMethod$2 = function createMethod(IS_RIGHT) {
return function (that, callbackfn, argumentsLength, memo) {
aFunction$1(callbackfn);
var O = toObject(that);
var self = indexedObject(O);
var length = toLength(O.length);
var index = IS_RIGHT ? length - 1 : 0;
var i = IS_RIGHT ? -1 : 1;
if (argumentsLength < 2) while (true) {
if (index in self) {
memo = self[index];
index += i;
break;
}
index += i;
if (IS_RIGHT ? index < 0 : length <= index) {
throw TypeError('Reduce of empty array with no initial value');
}
}
for (; IS_RIGHT ? index >= 0 : length > index; index += i) {
if (index in self) {
memo = callbackfn(memo, self[index], index, O);
}
}
return memo;
};
};
var arrayReduce = {
// `Array.prototype.reduce` method
// https://tc39.es/ecma262/#sec-array.prototype.reduce
left: createMethod$2(false),
// `Array.prototype.reduceRight` method
// https://tc39.es/ecma262/#sec-array.prototype.reduceright
right: createMethod$2(true)
};
var arrayReduce_1 = arrayReduce.left;
var arrayReduce_2 = arrayReduce.right;
'use strict';
var $reduce = arrayReduce.left;
var STRICT_METHOD$5 = arrayMethodIsStrict('reduce'); // Chrome 80-82 has a critical bug
// https://bugs.chromium.org/p/chromium/issues/detail?id=1049982
var CHROME_BUG = !engineIsNode && engineV8Version > 79 && engineV8Version < 83; // `Array.prototype.reduce` method
// https://tc39.es/ecma262/#sec-array.prototype.reduce
_export({
target: 'Array',
proto: true,
forced: !STRICT_METHOD$5 || CHROME_BUG
}, {
reduce: function reduce(callbackfn
/* , initialValue */
) {
return $reduce(this, callbackfn, arguments.length, arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_array_reduce = {};
'use strict';
var $reduceRight = arrayReduce.right;
var STRICT_METHOD$6 = arrayMethodIsStrict('reduceRight'); // Chrome 80-82 has a critical bug
// https://bugs.chromium.org/p/chromium/issues/detail?id=1049982
var CHROME_BUG$1 = !engineIsNode && engineV8Version > 79 && engineV8Version < 83; // `Array.prototype.reduceRight` method
// https://tc39.es/ecma262/#sec-array.prototype.reduceright
_export({
target: 'Array',
proto: true,
forced: !STRICT_METHOD$6 || CHROME_BUG$1
}, {
reduceRight: function reduceRight(callbackfn
/* , initialValue */
) {
return $reduceRight(this, callbackfn, arguments.length, arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_array_reduceRight = {};
'use strict';
var nativeReverse = [].reverse;
var test$1 = [1, 2]; // `Array.prototype.reverse` method
// https://tc39.es/ecma262/#sec-array.prototype.reverse
// fix for Safari 12.0 bug
// https://bugs.webkit.org/show_bug.cgi?id=188794
_export({
target: 'Array',
proto: true,
forced: String(test$1) === String(test$1.reverse())
}, {
reverse: function reverse() {
// eslint-disable-next-line no-self-assign -- dirty hack
if (isArray(this)) this.length = this.length;
return nativeReverse.call(this);
}
});
var es_array_reverse = {};
'use strict';
var HAS_SPECIES_SUPPORT$2 = arrayMethodHasSpeciesSupport('slice');
var SPECIES$2 = wellKnownSymbol('species');
var nativeSlice = [].slice;
var max$1 = Math.max; // `Array.prototype.slice` method
// https://tc39.es/ecma262/#sec-array.prototype.slice
// fallback for non-array-like ES3 strings and DOM objects
_export({
target: 'Array',
proto: true,
forced: !HAS_SPECIES_SUPPORT$2
}, {
slice: function slice(start, end) {
var O = toIndexedObject(this);
var length = toLength(O.length);
var k = toAbsoluteIndex(start, length);
var fin = toAbsoluteIndex(end === undefined ? length : end, length); // inline `ArraySpeciesCreate` so the native `Array#slice` can be used where possible
var Constructor, result, n;
if (isArray(O)) {
Constructor = O.constructor; // cross-realm fallback
if (typeof Constructor == 'function' && (Constructor === Array || isArray(Constructor.prototype))) {
Constructor = undefined;
} else if (isObject(Constructor)) {
Constructor = Constructor[SPECIES$2];
if (Constructor === null) Constructor = undefined;
}
if (Constructor === Array || Constructor === undefined) {
return nativeSlice.call(O, k, fin);
}
}
result = new (Constructor === undefined ? Array : Constructor)(max$1(fin - k, 0));
for (n = 0; k < fin; k++, n++) {
if (k in O) createProperty(result, n, O[k]);
}
result.length = n;
return result;
}
});
var es_array_slice = {};
'use strict';
var $some = arrayIteration.some;
var STRICT_METHOD$7 = arrayMethodIsStrict('some'); // `Array.prototype.some` method
// https://tc39.es/ecma262/#sec-array.prototype.some
_export({
target: 'Array',
proto: true,
forced: !STRICT_METHOD$7
}, {
some: function some(callbackfn
/* , thisArg */
) {
return $some(this, callbackfn, arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_array_some = {};
'use strict';
var test$2 = [];
var nativeSort = test$2.sort; // IE8-
var FAILS_ON_UNDEFINED = fails(function () {
test$2.sort(undefined);
}); // V8 bug
var FAILS_ON_NULL = fails(function () {
test$2.sort(null);
}); // Old WebKit
var STRICT_METHOD$8 = arrayMethodIsStrict('sort');
var FORCED$2 = FAILS_ON_UNDEFINED || !FAILS_ON_NULL || !STRICT_METHOD$8; // `Array.prototype.sort` method
// https://tc39.es/ecma262/#sec-array.prototype.sort
_export({
target: 'Array',
proto: true,
forced: FORCED$2
}, {
sort: function sort(comparefn) {
return comparefn === undefined ? nativeSort.call(toObject(this)) : nativeSort.call(toObject(this), aFunction$1(comparefn));
}
});
var es_array_sort = {};
'use strict';
var HAS_SPECIES_SUPPORT$3 = arrayMethodHasSpeciesSupport('splice');
var max$2 = Math.max;
var min$4 = Math.min;
var MAX_SAFE_INTEGER$1 = 0x1FFFFFFFFFFFFF;
var MAXIMUM_ALLOWED_LENGTH_EXCEEDED = 'Maximum allowed length exceeded'; // `Array.prototype.splice` method
// https://tc39.es/ecma262/#sec-array.prototype.splice
// with added support of @@species
_export({
target: 'Array',
proto: true,
forced: !HAS_SPECIES_SUPPORT$3
}, {
splice: function splice(start, deleteCount
/* , ...items */
) {
var O = toObject(this);
var len = toLength(O.length);
var actualStart = toAbsoluteIndex(start, len);
var argumentsLength = arguments.length;
var insertCount, actualDeleteCount, A, k, from, to;
if (argumentsLength === 0) {
insertCount = actualDeleteCount = 0;
} else if (argumentsLength === 1) {
insertCount = 0;
actualDeleteCount = len - actualStart;
} else {
insertCount = argumentsLength - 2;
actualDeleteCount = min$4(max$2(toInteger(deleteCount), 0), len - actualStart);
}
if (len + insertCount - actualDeleteCount > MAX_SAFE_INTEGER$1) {
throw TypeError(MAXIMUM_ALLOWED_LENGTH_EXCEEDED);
}
A = arraySpeciesCreate(O, actualDeleteCount);
for (k = 0; k < actualDeleteCount; k++) {
from = actualStart + k;
if (from in O) createProperty(A, k, O[from]);
}
A.length = actualDeleteCount;
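// shift the remaining elements left or right depending on whether fewer or more items are inserted than deleted, then write the inserted items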
if (insertCount < actualDeleteCount) {
for (k = actualStart; k < len - actualDeleteCount; k++) {
from = k + actualDeleteCount;
to = k + insertCount;
if (from in O) O[to] = O[from];else delete O[to];
}
for (k = len; k > len - actualDeleteCount + insertCount; k--) {
delete O[k - 1];
}
} else if (insertCount > actualDeleteCount) {
for (k = len - actualDeleteCount; k > actualStart; k--) {
from = k + actualDeleteCount - 1;
to = k + insertCount - 1;
if (from in O) O[to] = O[from];else delete O[to];
}
}
for (k = 0; k < insertCount; k++) {
O[k + actualStart] = arguments[k + 2];
}
O.length = len - actualDeleteCount + insertCount;
return A;
}
});
var es_array_splice = {};
'use strict';
var SPECIES$3 = wellKnownSymbol('species');
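// installs a `@@species` accessor that returns `this` on the named built-in constructor (when descriptors are supported and it is not already defined)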
var setSpecies = function setSpecies(CONSTRUCTOR_NAME) {
var Constructor = getBuiltIn(CONSTRUCTOR_NAME);
var defineProperty = objectDefineProperty.f;
if (descriptors && Constructor && !Constructor[SPECIES$3]) {
defineProperty(Constructor, SPECIES$3, {
configurable: true,
get: function get() {
return this;
}
});
}
};
// https://tc39.es/ecma262/#sec-get-array-@@species
setSpecies('Array');
var es_array_species = {};
// `flat` was added to @@unscopables after the method itself shipped in popular engines, so it's moved to a separate module
// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables('flat');
var es_array_unscopables_flat = {};
// `flatMap` was added to @@unscopables after the method itself shipped in popular engines, so it's moved to a separate module
// https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables('flatMap');
var es_array_unscopables_flatMap = {};
'use strict';
var ITERATOR$3 = wellKnownSymbol('iterator');
var BUGGY_SAFARI_ITERATORS = false;
var returnThis = function returnThis() {
return this;
}; // `%IteratorPrototype%` object
// https://tc39.es/ecma262/#sec-%iteratorprototype%-object
var IteratorPrototype, PrototypeOfArrayIteratorPrototype, arrayIterator;
if ([].keys) {
arrayIterator = [].keys(); // Safari 8 has buggy iterators w/o `next`
if (!('next' in arrayIterator)) BUGGY_SAFARI_ITERATORS = true;else {
PrototypeOfArrayIteratorPrototype = objectGetPrototypeOf(objectGetPrototypeOf(arrayIterator));
if (PrototypeOfArrayIteratorPrototype !== Object.prototype) IteratorPrototype = PrototypeOfArrayIteratorPrototype;
}
}
var NEW_ITERATOR_PROTOTYPE = IteratorPrototype == undefined || fails(function () {
var test = {}; // FF44- legacy iterators case
return IteratorPrototype[ITERATOR$3].call(test) !== test;
});
if (NEW_ITERATOR_PROTOTYPE) IteratorPrototype = {}; // 25.1.2.1.1 %IteratorPrototype%[@@iterator]()
if ((!isPure || NEW_ITERATOR_PROTOTYPE) && !has(IteratorPrototype, ITERATOR$3)) {
createNonEnumerableProperty(IteratorPrototype, ITERATOR$3, returnThis);
}
var iteratorsCore = {
IteratorPrototype: IteratorPrototype,
BUGGY_SAFARI_ITERATORS: BUGGY_SAFARI_ITERATORS
};
var iteratorsCore_1 = iteratorsCore.IteratorPrototype;
var iteratorsCore_2 = iteratorsCore.BUGGY_SAFARI_ITERATORS;
'use strict';
var IteratorPrototype$1 = iteratorsCore.IteratorPrototype;
var returnThis$1 = function returnThis() {
return this;
};
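// creates an iterator constructor whose prototype inherits from %IteratorPrototype%, exposes the given `next` method and carries a `<NAME> Iterator` @@toStringTag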
var createIteratorConstructor = function createIteratorConstructor(IteratorConstructor, NAME, next) {
var TO_STRING_TAG = NAME + ' Iterator';
IteratorConstructor.prototype = objectCreate(IteratorPrototype$1, {
next: createPropertyDescriptor(1, next)
});
setToStringTag(IteratorConstructor, TO_STRING_TAG, false, true);
iterators[TO_STRING_TAG] = returnThis$1;
return IteratorConstructor;
};
'use strict';
var IteratorPrototype$2 = iteratorsCore.IteratorPrototype;
var BUGGY_SAFARI_ITERATORS$1 = iteratorsCore.BUGGY_SAFARI_ITERATORS;
var ITERATOR$4 = wellKnownSymbol('iterator');
var KEYS = 'keys';
var VALUES = 'values';
var ENTRIES = 'entries';
var returnThis$2 = function returnThis() {
return this;
};
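// wires up `keys` / `values` / `entries` and `@@iterator` for the given Iterable: repairs an existing native iterator (prototype chain, @@toStringTag, `.name`) where possible and exports any missing or broken methods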
var defineIterator = function defineIterator(Iterable, NAME, IteratorConstructor, next, DEFAULT, IS_SET, FORCED) {
createIteratorConstructor(IteratorConstructor, NAME, next);
var getIterationMethod = function getIterationMethod(KIND) {
if (KIND === DEFAULT && defaultIterator) return defaultIterator;
if (!BUGGY_SAFARI_ITERATORS$1 && KIND in IterablePrototype) return IterablePrototype[KIND];
switch (KIND) {
case KEYS:
return function keys() {
return new IteratorConstructor(this, KIND);
};
case VALUES:
return function values() {
return new IteratorConstructor(this, KIND);
};
case ENTRIES:
return function entries() {
return new IteratorConstructor(this, KIND);
};
}
return function () {
return new IteratorConstructor(this);
};
};
var TO_STRING_TAG = NAME + ' Iterator';
var INCORRECT_VALUES_NAME = false;
var IterablePrototype = Iterable.prototype;
var nativeIterator = IterablePrototype[ITERATOR$4] || IterablePrototype['@@iterator'] || DEFAULT && IterablePrototype[DEFAULT];
var defaultIterator = !BUGGY_SAFARI_ITERATORS$1 && nativeIterator || getIterationMethod(DEFAULT);
var anyNativeIterator = NAME == 'Array' ? IterablePrototype.entries || nativeIterator : nativeIterator;
var CurrentIteratorPrototype, methods, KEY; // fix native
if (anyNativeIterator) {
CurrentIteratorPrototype = objectGetPrototypeOf(anyNativeIterator.call(new Iterable()));
if (IteratorPrototype$2 !== Object.prototype && CurrentIteratorPrototype.next) {
if (!isPure && objectGetPrototypeOf(CurrentIteratorPrototype) !== IteratorPrototype$2) {
if (objectSetPrototypeOf) {
objectSetPrototypeOf(CurrentIteratorPrototype, IteratorPrototype$2);
} else if (typeof CurrentIteratorPrototype[ITERATOR$4] != 'function') {
createNonEnumerableProperty(CurrentIteratorPrototype, ITERATOR$4, returnThis$2);
}
} // Set @@toStringTag to native iterators
setToStringTag(CurrentIteratorPrototype, TO_STRING_TAG, true, true);
if (isPure) iterators[TO_STRING_TAG] = returnThis$2;
}
} // fix Array#{values, @@iterator}.name in V8 / FF
if (DEFAULT == VALUES && nativeIterator && nativeIterator.name !== VALUES) {
INCORRECT_VALUES_NAME = true;
defaultIterator = function values() {
return nativeIterator.call(this);
};
} // define iterator
if ((!isPure || FORCED) && IterablePrototype[ITERATOR$4] !== defaultIterator) {
createNonEnumerableProperty(IterablePrototype, ITERATOR$4, defaultIterator);
}
iterators[NAME] = defaultIterator; // export additional methods
if (DEFAULT) {
methods = {
values: getIterationMethod(VALUES),
keys: IS_SET ? defaultIterator : getIterationMethod(KEYS),
entries: getIterationMethod(ENTRIES)
};
if (FORCED) for (KEY in methods) {
if (BUGGY_SAFARI_ITERATORS$1 || INCORRECT_VALUES_NAME || !(KEY in IterablePrototype)) {
redefine(IterablePrototype, KEY, methods[KEY]);
}
} else _export({
target: NAME,
proto: true,
forced: BUGGY_SAFARI_ITERATORS$1 || INCORRECT_VALUES_NAME
}, methods);
}
return methods;
};
'use strict';
var ARRAY_ITERATOR = 'Array Iterator';
var setInternalState$1 = internalState.set;
var getInternalState$1 = internalState.getterFor(ARRAY_ITERATOR); // `Array.prototype.entries` method
// https://tc39.es/ecma262/#sec-array.prototype.entries
// `Array.prototype.keys` method
// https://tc39.es/ecma262/#sec-array.prototype.keys
// `Array.prototype.values` method
// https://tc39.es/ecma262/#sec-array.prototype.values
// `Array.prototype[@@iterator]` method
// https://tc39.es/ecma262/#sec-array.prototype-@@iterator
// `CreateArrayIterator` internal method
// https://tc39.es/ecma262/#sec-createarrayiterator
var es_array_iterator = defineIterator(Array, 'Array', function (iterated, kind) {
setInternalState$1(this, {
type: ARRAY_ITERATOR,
target: toIndexedObject(iterated),
// target
index: 0,
// next index
kind: kind // kind
}); // `%ArrayIteratorPrototype%.next` method
// https://tc39.es/ecma262/#sec-%arrayiteratorprototype%.next
}, function () {
var state = getInternalState$1(this);
var target = state.target;
var kind = state.kind;
var index = state.index++;
if (!target || index >= target.length) {
state.target = undefined;
return {
value: undefined,
done: true
};
}
if (kind == 'keys') return {
value: index,
done: false
};
if (kind == 'values') return {
value: target[index],
done: false
};
return {
value: [index, target[index]],
done: false
};
}, 'values'); // argumentsList[@@iterator] is %ArrayProto_values%
// https://tc39.es/ecma262/#sec-createunmappedargumentsobject
// https://tc39.es/ecma262/#sec-createmappedargumentsobject
iterators.Arguments = iterators.Array; // https://tc39.es/ecma262/#sec-array.prototype-@@unscopables
addToUnscopables('keys');
addToUnscopables('values');
addToUnscopables('entries');
'use strict';
var slice = [].slice;
var factories = {};
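// caches one factory per arity that performs `new C(a[0], ..., a[n - 1])`; used when a bound function is invoked with `new`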
var construct = function construct(C, argsLength, args) {
if (!(argsLength in factories)) {
for (var list = [], i = 0; i < argsLength; i++) {
list[i] = 'a[' + i + ']';
} // eslint-disable-next-line no-new-func -- we have no proper alternatives, IE8- only
factories[argsLength] = Function('C,a', 'return new C(' + list.join(',') + ')');
}
return factories[argsLength](C, args);
}; // `Function.prototype.bind` method implementation
// https://tc39.es/ecma262/#sec-function.prototype.bind
var functionBind = Function.bind || function bind(that
/* , ...args */
) {
var fn = aFunction$1(this);
var partArgs = slice.call(arguments, 1);
var boundFunction = function bound()
/* args... */
{
var args = partArgs.concat(slice.call(arguments));
return this instanceof boundFunction ? construct(fn, args.length, args) : fn.apply(that, args);
};
if (isObject(fn.prototype)) boundFunction.prototype = fn.prototype;
return boundFunction;
};
// https://tc39.es/ecma262/#sec-function.prototype.bind
_export({
target: 'Function',
proto: true
}, {
bind: functionBind
});
var es_function_bind = {};
var defineProperty$3 = objectDefineProperty.f;
var FunctionPrototype = Function.prototype;
var FunctionPrototypeToString = FunctionPrototype.toString;
var nameRE = /^\s*function ([^ (]*)/;
var NAME = 'name'; // Function instances `.name` property
// https://tc39.es/ecma262/#sec-function-instances-name
if (descriptors && !(NAME in FunctionPrototype)) {
defineProperty$3(FunctionPrototype, NAME, {
configurable: true,
get: function get() {
try {
return FunctionPrototypeToString.call(this).match(nameRE)[1];
} catch (error) {
return '';
}
}
});
}
var es_function_name = {};
'use strict';
var HAS_INSTANCE = wellKnownSymbol('hasInstance');
var FunctionPrototype$1 = Function.prototype; // `Function.prototype[@@hasInstance]` method
// https://tc39.es/ecma262/#sec-function.prototype-@@hasinstance
if (!(HAS_INSTANCE in FunctionPrototype$1)) {
objectDefineProperty.f(FunctionPrototype$1, HAS_INSTANCE, {
value: function value(O) {
if (typeof this != 'function' || !isObject(O)) return false;
if (!isObject(this.prototype)) return O instanceof this; // in environments without native `@@hasInstance` logic, plain `instanceof` would be enough, but we still walk the prototype chain:
while (O = objectGetPrototypeOf(O)) {
if (this.prototype === O) return true;
}
return false;
}
});
}
var es_function_hasInstance = {};
// https://tc39.es/ecma262/#sec-globalthis
_export({
global: true
}, {
globalThis: global_1
});
var es_globalThis = {};
'use strict';
var nativeAssign = Object.assign;
var defineProperty$4 = Object.defineProperty; // `Object.assign` method
// https://tc39.es/ecma262/#sec-object.assign
var objectAssign = !nativeAssign || fails(function () {
// should have correct order of operations (Edge bug)
if (descriptors && nativeAssign({
b: 1
}, nativeAssign(defineProperty$4({}, 'a', {
enumerable: true,
get: function get() {
defineProperty$4(this, 'b', {
value: 3,
enumerable: false
});
}
}), {
b: 2
})).b !== 1) return true; // should work with symbols and should have deterministic property order (V8 bug)
var A = {};
var B = {};
/* global Symbol -- required for testing */
var symbol = Symbol();
var alphabet = 'abcdefghijklmnopqrst';
A[symbol] = 7;
alphabet.split('').forEach(function (chr) {
B[chr] = chr;
});
return nativeAssign({}, A)[symbol] != 7 || objectKeys(nativeAssign({}, B)).join('') != alphabet;
}) ? function assign(target, source) {
// eslint-disable-line no-unused-vars -- required for `.length`
var T = toObject(target);
var argumentsLength = arguments.length;
var index = 1;
var getOwnPropertySymbols = objectGetOwnPropertySymbols.f;
var propertyIsEnumerable = objectPropertyIsEnumerable.f;
while (argumentsLength > index) {
var S = indexedObject(arguments[index++]);
var keys = getOwnPropertySymbols ? objectKeys(S).concat(getOwnPropertySymbols(S)) : objectKeys(S);
var length = keys.length;
var j = 0;
var key;
while (length > j) {
key = keys[j++];
if (!descriptors || propertyIsEnumerable.call(S, key)) T[key] = S[key];
}
}
return T;
} : nativeAssign;
// https://tc39.es/ecma262/#sec-object.assign
_export({
target: 'Object',
stat: true,
forced: Object.assign !== objectAssign
}, {
assign: objectAssign
});
var es_object_assign = {};
// https://tc39.es/ecma262/#sec-object.create
_export({
target: 'Object',
stat: true,
sham: !descriptors
}, {
create: objectCreate
});
var es_object_create = {};
// https://tc39.es/ecma262/#sec-object.defineproperty
_export({
target: 'Object',
stat: true,
forced: !descriptors,
sham: !descriptors
}, {
defineProperty: objectDefineProperty.f
});
var es_object_defineProperty = {};
// https://tc39.es/ecma262/#sec-object.defineproperties
_export({
target: 'Object',
stat: true,
forced: !descriptors,
sham: !descriptors
}, {
defineProperties: objectDefineProperties
});
var es_object_defineProperties = {};
var propertyIsEnumerable = objectPropertyIsEnumerable.f; // `Object.{ entries, values }` methods implementation
var createMethod$3 = function createMethod(TO_ENTRIES) {
return function (it) {
var O = toIndexedObject(it);
var keys = objectKeys(O);
var length = keys.length;
var i = 0;
var result = [];
var key;
while (length > i) {
key = keys[i++];
if (!descriptors || propertyIsEnumerable.call(O, key)) {
result.push(TO_ENTRIES ? [key, O[key]] : O[key]);
}
}
return result;
};
};
var objectToArray = {
// `Object.entries` method
// https://tc39.es/ecma262/#sec-object.entries
entries: createMethod$3(true),
// `Object.values` method
// https://tc39.es/ecma262/#sec-object.values
values: createMethod$3(false)
};
var objectToArray_1 = objectToArray.entries;
var objectToArray_2 = objectToArray.values;
var $entries = objectToArray.entries; // `Object.entries` method
// https://tc39.es/ecma262/#sec-object.entries
_export({
target: 'Object',
stat: true
}, {
entries: function entries(O) {
return $entries(O);
}
});
var es_object_entries = {};
var freezing = !fails(function () {
return Object.isExtensible(Object.preventExtensions({}));
});
var internalMetadata = createCommonjsModule(function (module) {
var defineProperty = objectDefineProperty.f;
var METADATA = uid('meta');
var id = 0;
var isExtensible = Object.isExtensible || function () {
return true;
};
var setMetadata = function setMetadata(it) {
defineProperty(it, METADATA, {
value: {
objectID: 'O' + ++id,
// object ID
weakData: {} // weak collections IDs
}
});
};
var fastKey = function fastKey(it, create) {
// return a primitive with prefix
if (!isObject(it)) return typeof it == 'symbol' ? it : (typeof it == 'string' ? 'S' : 'P') + it;
if (!has(it, METADATA)) {
// can't set metadata on an object that was frozen before it was tracked here
if (!isExtensible(it)) return 'F'; // not necessary to add metadata
if (!create) return 'E'; // add missing metadata
setMetadata(it); // return object ID
}
return it[METADATA].objectID;
};
var getWeakData = function getWeakData(it, create) {
if (!has(it, METADATA)) {
// can't set metadata on an object that was frozen before it was tracked here
if (!isExtensible(it)) return true; // not necessary to add metadata
if (!create) return false; // add missing metadata
setMetadata(it); // return the store of weak collections IDs
}
return it[METADATA].weakData;
}; // add metadata on freeze-family methods calling
var onFreeze = function onFreeze(it) {
if (freezing && meta.REQUIRED && isExtensible(it) && !has(it, METADATA)) setMetadata(it);
return it;
};
var meta = module.exports = {
REQUIRED: false,
fastKey: fastKey,
getWeakData: getWeakData,
onFreeze: onFreeze
};
hiddenKeys[METADATA] = true;
});
var internalMetadata_1 = internalMetadata.REQUIRED;
var internalMetadata_2 = internalMetadata.fastKey;
var internalMetadata_3 = internalMetadata.getWeakData;
var internalMetadata_4 = internalMetadata.onFreeze;
var onFreeze = internalMetadata.onFreeze;
var nativeFreeze = Object.freeze;
var FAILS_ON_PRIMITIVES = fails(function () {
nativeFreeze(1);
}); // `Object.freeze` method
// https://tc39.es/ecma262/#sec-object.freeze
_export({
target: 'Object',
stat: true,
forced: FAILS_ON_PRIMITIVES,
sham: !freezing
}, {
freeze: function freeze(it) {
return nativeFreeze && isObject(it) ? nativeFreeze(onFreeze(it)) : it;
}
});
var es_object_freeze = {};
// https://github.com/tc39/proposal-object-from-entries
_export({
target: 'Object',
stat: true
}, {
fromEntries: function fromEntries(iterable) {
var obj = {};
iterate(iterable, function (k, v) {
createProperty(obj, k, v);
}, {
AS_ENTRIES: true
});
return obj;
}
});
var es_object_fromEntries = {};
var nativeGetOwnPropertyDescriptor$2 = objectGetOwnPropertyDescriptor.f;
var FAILS_ON_PRIMITIVES$1 = fails(function () {
nativeGetOwnPropertyDescriptor$2(1);
});
var FORCED$3 = !descriptors || FAILS_ON_PRIMITIVES$1; // `Object.getOwnPropertyDescriptor` method
// https://tc39.es/ecma262/#sec-object.getownpropertydescriptor
_export({
target: 'Object',
stat: true,
forced: FORCED$3,
sham: !descriptors
}, {
getOwnPropertyDescriptor: function getOwnPropertyDescriptor(it, key) {
return nativeGetOwnPropertyDescriptor$2(toIndexedObject(it), key);
}
});
var es_object_getOwnPropertyDescriptor = {};
// https://tc39.es/ecma262/#sec-object.getownpropertydescriptors
_export({
target: 'Object',
stat: true,
sham: !descriptors
}, {
getOwnPropertyDescriptors: function getOwnPropertyDescriptors(object) {
var O = toIndexedObject(object);
var getOwnPropertyDescriptor = objectGetOwnPropertyDescriptor.f;
var keys = ownKeys(O);
var result = {};
var index = 0;
var key, descriptor;
while (keys.length > index) {
descriptor = getOwnPropertyDescriptor(O, key = keys[index++]);
if (descriptor !== undefined) createProperty(result, key, descriptor);
}
return result;
}
});
var es_object_getOwnPropertyDescriptors = {};
var nativeGetOwnPropertyNames$2 = objectGetOwnPropertyNamesExternal.f;
var FAILS_ON_PRIMITIVES$2 = fails(function () {
return !Object.getOwnPropertyNames(1);
}); // `Object.getOwnPropertyNames` method
// https://tc39.es/ecma262/#sec-object.getownpropertynames
_export({
target: 'Object',
stat: true,
forced: FAILS_ON_PRIMITIVES$2
}, {
getOwnPropertyNames: nativeGetOwnPropertyNames$2
});
var es_object_getOwnPropertyNames = {};
var FAILS_ON_PRIMITIVES$3 = fails(function () {
objectGetPrototypeOf(1);
}); // `Object.getPrototypeOf` method
// https://tc39.es/ecma262/#sec-object.getprototypeof
_export({
target: 'Object',
stat: true,
forced: FAILS_ON_PRIMITIVES$3,
sham: !correctPrototypeGetter
}, {
getPrototypeOf: function getPrototypeOf(it) {
return objectGetPrototypeOf(toObject(it));
}
});
var es_object_getPrototypeOf = {};
// `SameValue` abstract operation
// https://tc39.es/ecma262/#sec-samevalue
var sameValue = Object.is || function is(x, y) {
// eslint-disable-next-line no-self-compare -- NaN check
return x === y ? x !== 0 || 1 / x === 1 / y : x != x && y != y;
};
// https://tc39.es/ecma262/#sec-object.is
_export({
target: 'Object',
stat: true
}, {
is: sameValue
});
var es_object_is = {};
var nativeIsExtensible = Object.isExtensible;
var FAILS_ON_PRIMITIVES$4 = fails(function () {
nativeIsExtensible(1);
}); // `Object.isExtensible` method
// https://tc39.es/ecma262/#sec-object.isextensible
_export({
target: 'Object',
stat: true,
forced: FAILS_ON_PRIMITIVES$4
}, {
isExtensible: function isExtensible(it) {
return isObject(it) ? nativeIsExtensible ? nativeIsExtensible(it) : true : false;
}
});
var es_object_isExtensible = {};
var nativeIsFrozen = Object.isFrozen;
var FAILS_ON_PRIMITIVES$5 = fails(function () {
nativeIsFrozen(1);
}); // `Object.isFrozen` method
// https://tc39.es/ecma262/#sec-object.isfrozen
_export({
target: 'Object',
stat: true,
forced: FAILS_ON_PRIMITIVES$5
}, {
isFrozen: function isFrozen(it) {
return isObject(it) ? nativeIsFrozen ? nativeIsFrozen(it) : false : true;
}
});
var es_object_isFrozen = {};
var nativeIsSealed = Object.isSealed;
var FAILS_ON_PRIMITIVES$6 = fails(function () {
nativeIsSealed(1);
}); // `Object.isSealed` method
// https://tc39.es/ecma262/#sec-object.issealed
_export({
target: 'Object',
stat: true,
forced: FAILS_ON_PRIMITIVES$6
}, {
isSealed: function isSealed(it) {
return isObject(it) ? nativeIsSealed ? nativeIsSealed(it) : false : true;
}
});
var es_object_isSealed = {};
var FAILS_ON_PRIMITIVES$7 = fails(function () {
objectKeys(1);
}); // `Object.keys` method
// https://tc39.es/ecma262/#sec-object.keys
_export({
target: 'Object',
stat: true,
forced: FAILS_ON_PRIMITIVES$7
}, {
keys: function keys(it) {
return objectKeys(toObject(it));
}
});
var es_object_keys = {};
var onFreeze$1 = internalMetadata.onFreeze;
var nativePreventExtensions = Object.preventExtensions;
var FAILS_ON_PRIMITIVES$8 = fails(function () {
nativePreventExtensions(1);
}); // `Object.preventExtensions` method
// https://tc39.es/ecma262/#sec-object.preventextensions
_export({
target: 'Object',
stat: true,
forced: FAILS_ON_PRIMITIVES$8,
sham: !freezing
}, {
preventExtensions: function preventExtensions(it) {
return nativePreventExtensions && isObject(it) ? nativePreventExtensions(onFreeze$1(it)) : it;
}
});
var es_object_preventExtensions = {};
var onFreeze$2 = internalMetadata.onFreeze;
var nativeSeal = Object.seal;
var FAILS_ON_PRIMITIVES$9 = fails(function () {
nativeSeal(1);
}); // `Object.seal` method
// https://tc39.es/ecma262/#sec-object.seal
_export({
target: 'Object',
stat: true,
forced: FAILS_ON_PRIMITIVES$9,
sham: !freezing
}, {
seal: function seal(it) {
return nativeSeal && isObject(it) ? nativeSeal(onFreeze$2(it)) : it;
}
});
var es_object_seal = {};
// https://tc39.es/ecma262/#sec-object.setprototypeof
_export({
target: 'Object',
stat: true
}, {
setPrototypeOf: objectSetPrototypeOf
});
var es_object_setPrototypeOf = {};
var $values = objectToArray.values; // `Object.values` method
// https://tc39.es/ecma262/#sec-object.values
_export({
target: 'Object',
stat: true
}, {
values: function values(O) {
return $values(O);
}
});
var es_object_values = {};
'use strict'; // `Object.prototype.toString` method implementation
// https://tc39.es/ecma262/#sec-object.prototype.tostring
var objectToString = toStringTagSupport ? {}.toString : function toString() {
return '[object ' + classof(this) + ']';
};
// https://tc39.es/ecma262/#sec-object.prototype.tostring
if (!toStringTagSupport) {
redefine(Object.prototype, 'toString', objectToString, {
unsafe: true
});
}
var es_object_toString = {};
'use strict'; // Forced replacement object prototype accessors methods
var objectPrototypeAccessorsForced = isPure || !fails(function () {
var key = Math.random(); // in Firefox, only the define methods throw
// eslint-disable-next-line no-undef, no-useless-call -- required for testing
__defineSetter__.call(null, key, function () {
/* empty */
});
delete global_1[key];
});
'use strict'; // `Object.prototype.__defineGetter__` method
// https://tc39.es/ecma262/#sec-object.prototype.__defineGetter__
if (descriptors) {
_export({
target: 'Object',
proto: true,
forced: objectPrototypeAccessorsForced
}, {
__defineGetter__: function __defineGetter__(P, getter) {
objectDefineProperty.f(toObject(this), P, {
get: aFunction$1(getter),
enumerable: true,
configurable: true
});
}
});
}
var es_object_defineGetter = {};
'use strict'; // `Object.prototype.__defineSetter__` method
// https://tc39.es/ecma262/#sec-object.prototype.__defineSetter__
if (descriptors) {
_export({
target: 'Object',
proto: true,
forced: objectPrototypeAccessorsForced
}, {
__defineSetter__: function __defineSetter__(P, setter) {
objectDefineProperty.f(toObject(this), P, {
set: aFunction$1(setter),
enumerable: true,
configurable: true
});
}
});
}
var es_object_defineSetter = {};
'use strict';
var getOwnPropertyDescriptor$2 = objectGetOwnPropertyDescriptor.f; // `Object.prototype.__lookupGetter__` method
// https://tc39.es/ecma262/#sec-object.prototype.__lookupGetter__
if (descriptors) {
_export({
target: 'Object',
proto: true,
forced: objectPrototypeAccessorsForced
}, {
__lookupGetter__: function __lookupGetter__(P) {
var O = toObject(this);
var key = toPrimitive(P, true);
var desc;
do {
if (desc = getOwnPropertyDescriptor$2(O, key)) return desc.get;
} while (O = objectGetPrototypeOf(O));
}
});
}
var es_object_lookupGetter = {};
'use strict';
var getOwnPropertyDescriptor$3 = objectGetOwnPropertyDescriptor.f; // `Object.prototype.__lookupSetter__` method
// https://tc39.es/ecma262/#sec-object.prototype.__lookupSetter__
if (descriptors) {
_export({
target: 'Object',
proto: true,
forced: objectPrototypeAccessorsForced
}, {
__lookupSetter__: function __lookupSetter__(P) {
var O = toObject(this);
var key = toPrimitive(P, true);
var desc;
do {
if (desc = getOwnPropertyDescriptor$3(O, key)) return desc.set;
} while (O = objectGetPrototypeOf(O));
}
});
}
var es_object_lookupSetter = {};
var fromCharCode = String.fromCharCode;
var nativeFromCodePoint = String.fromCodePoint; // length should be 1, old FF problem
var INCORRECT_LENGTH = !!nativeFromCodePoint && nativeFromCodePoint.length != 1; // `String.fromCodePoint` method
// https://tc39.es/ecma262/#sec-string.fromcodepoint
_export({
target: 'String',
stat: true,
forced: INCORRECT_LENGTH
}, {
// eslint-disable-next-line no-unused-vars -- required for `.length`
fromCodePoint: function fromCodePoint(x) {
var elements = [];
var length = arguments.length;
var i = 0;
var code;
while (length > i) {
code = +arguments[i++];
if (toAbsoluteIndex(code, 0x10FFFF) !== code) throw RangeError(code + ' is not a valid code point');
elements.push(code < 0x10000 ? fromCharCode(code) : fromCharCode(((code -= 0x10000) >> 10) + 0xD800, code % 0x400 + 0xDC00));
}
return elements.join('');
}
});
var es_string_fromCodePoint = {};
// https://tc39.es/ecma262/#sec-string.raw
_export({
target: 'String',
stat: true
}, {
raw: function raw(template) {
var rawTemplate = toIndexedObject(template.raw);
var literalSegments = toLength(rawTemplate.length);
var argumentsLength = arguments.length;
var elements = [];
var i = 0;
while (literalSegments > i) {
elements.push(String(rawTemplate[i++]));
if (i < argumentsLength) elements.push(String(arguments[i]));
}
return elements.join('');
}
});
var es_string_raw = {};
var createMethod$4 = function createMethod(CONVERT_TO_STRING) {
return function ($this, pos) {
var S = String(requireObjectCoercible($this));
var position = toInteger(pos);
var size = S.length;
var first, second;
if (position < 0 || position >= size) return CONVERT_TO_STRING ? '' : undefined;
first = S.charCodeAt(position);
return first < 0xD800 || first > 0xDBFF || position + 1 === size || (second = S.charCodeAt(position + 1)) < 0xDC00 || second > 0xDFFF ? CONVERT_TO_STRING ? S.charAt(position) : first : CONVERT_TO_STRING ? S.slice(position, position + 2) : (first - 0xD800 << 10) + (second - 0xDC00) + 0x10000;
};
};
var stringMultibyte = {
// `String.prototype.codePointAt` method
// https://tc39.es/ecma262/#sec-string.prototype.codepointat
codeAt: createMethod$4(false),
// `String.prototype.at` method
// https://github.com/mathiasbynens/String.prototype.at
charAt: createMethod$4(true)
};
var stringMultibyte_1 = stringMultibyte.codeAt;
var stringMultibyte_2 = stringMultibyte.charAt;
'use strict';
var codeAt = stringMultibyte.codeAt; // `String.prototype.codePointAt` method
// https://tc39.es/ecma262/#sec-string.prototype.codepointat
_export({
target: 'String',
proto: true
}, {
codePointAt: function codePointAt(pos) {
return codeAt(this, pos);
}
});
var es_string_codePointAt = {};
var MATCH = wellKnownSymbol('match'); // `IsRegExp` abstract operation
// https://tc39.es/ecma262/#sec-isregexp
var isRegexp = function isRegexp(it) {
var isRegExp;
return isObject(it) && ((isRegExp = it[MATCH]) !== undefined ? !!isRegExp : classofRaw(it) == 'RegExp');
};
var notARegexp = function notARegexp(it) {
if (isRegexp(it)) {
throw TypeError("The method doesn't accept regular expressions");
}
return it;
};
var MATCH$1 = wellKnownSymbol('match');
var correctIsRegexpLogic = function correctIsRegexpLogic(METHOD_NAME) {
var regexp = /./;
try {
'/./'[METHOD_NAME](regexp);
} catch (error1) {
try {
regexp[MATCH$1] = false;
return '/./'[METHOD_NAME](regexp);
} catch (error2) {
/* empty */
}
}
return false;
};
'use strict';
var getOwnPropertyDescriptor$4 = objectGetOwnPropertyDescriptor.f;
var nativeEndsWith = ''.endsWith;
var min$5 = Math.min;
var CORRECT_IS_REGEXP_LOGIC = correctIsRegexpLogic('endsWith'); // https://github.com/zloirock/core-js/pull/702
var MDN_POLYFILL_BUG = !isPure && !CORRECT_IS_REGEXP_LOGIC && !!function () {
var descriptor = getOwnPropertyDescriptor$4(String.prototype, 'endsWith');
return descriptor && !descriptor.writable;
}(); // `String.prototype.endsWith` method
// https://tc39.es/ecma262/#sec-string.prototype.endswith
_export({
target: 'String',
proto: true,
forced: !MDN_POLYFILL_BUG && !CORRECT_IS_REGEXP_LOGIC
}, {
endsWith: function endsWith(searchString
/* , endPosition = @length */
) {
var that = String(requireObjectCoercible(this));
notARegexp(searchString);
var endPosition = arguments.length > 1 ? arguments[1] : undefined;
var len = toLength(that.length);
var end = endPosition === undefined ? len : min$5(toLength(endPosition), len);
var search = String(searchString);
return nativeEndsWith ? nativeEndsWith.call(that, search, end) : that.slice(end - search.length, end) === search;
}
});
var es_string_endsWith = {};
'use strict'; // `String.prototype.includes` method
// https://tc39.es/ecma262/#sec-string.prototype.includes
_export({
target: 'String',
proto: true,
forced: !correctIsRegexpLogic('includes')
}, {
includes: function includes(searchString
/* , position = 0 */
) {
return !!~String(requireObjectCoercible(this)).indexOf(notARegexp(searchString), arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_string_includes = {};
'use strict'; // `RegExp.prototype.flags` getter implementation
// https://tc39.es/ecma262/#sec-get-regexp.prototype.flags
var regexpFlags = function regexpFlags() {
var that = anObject(this);
var result = '';
if (that.global) result += 'g';
if (that.ignoreCase) result += 'i';
if (that.multiline) result += 'm';
if (that.dotAll) result += 's';
if (that.unicode) result += 'u';
if (that.sticky) result += 'y';
return result;
};
'use strict'; // babel-minify transpiles RegExp('a', 'y') -> /a/y and it causes SyntaxError,
// so we use an intermediate function.
function RE(s, f) {
return RegExp(s, f);
}
var UNSUPPORTED_Y = fails(function () {
// babel-minify transpiles RegExp('a', 'y') -> /a/y and it causes SyntaxError
var re = RE('a', 'y');
re.lastIndex = 2;
return re.exec('abcd') != null;
});
var BROKEN_CARET = fails(function () {
// https://bugzilla.mozilla.org/show_bug.cgi?id=773687
var re = RE('^r', 'gy');
re.lastIndex = 2;
return re.exec('str') != null;
});
var regexpStickyHelpers = {
UNSUPPORTED_Y: UNSUPPORTED_Y,
BROKEN_CARET: BROKEN_CARET
};
'use strict';
var nativeExec = RegExp.prototype.exec; // This always refers to the native implementation, because the
// String#replace polyfill uses ./fix-regexp-well-known-symbol-logic.js,
// which loads this file before patching the method.
var nativeReplace = String.prototype.replace;
var patchedExec = nativeExec;
var UPDATES_LAST_INDEX_WRONG = function () {
var re1 = /a/;
var re2 = /b*/g;
nativeExec.call(re1, 'a');
nativeExec.call(re2, 'a');
return re1.lastIndex !== 0 || re2.lastIndex !== 0;
}();
var UNSUPPORTED_Y$1 = regexpStickyHelpers.UNSUPPORTED_Y || regexpStickyHelpers.BROKEN_CARET; // nonparticipating capturing group, copied from es5-shim's String#split patch.
// eslint-disable-next-line regexp/no-assertion-capturing-group, regexp/no-empty-group -- required for testing
var NPCG_INCLUDED = /()??/.exec('')[1] !== undefined;
var PATCH = UPDATES_LAST_INDEX_WRONG || NPCG_INCLUDED || UNSUPPORTED_Y$1;
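// re-implement `exec` when the native one mishandles `lastIndex`, lacks sticky (`y`) support, or returns an empty string instead of `undefined` for non-participating capture groups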
if (PATCH) {
patchedExec = function exec(str) {
var re = this;
var lastIndex, reCopy, match, i;
var sticky = UNSUPPORTED_Y$1 && re.sticky;
var flags = regexpFlags.call(re);
var source = re.source;
var charsAdded = 0;
var strCopy = str;
if (sticky) {
flags = flags.replace('y', '');
if (flags.indexOf('g') === -1) {
flags += 'g';
}
strCopy = String(str).slice(re.lastIndex); // Support anchored sticky behavior.
if (re.lastIndex > 0 && (!re.multiline || re.multiline && str[re.lastIndex - 1] !== '\n')) {
source = '(?: ' + source + ')';
strCopy = ' ' + strCopy;
charsAdded++;
} // wrapping the source as `^(?:` + source + `)` is needed, in combination with some str slicing, to
// simulate the 'y' flag.
reCopy = new RegExp('^(?:' + source + ')', flags);
}
if (NPCG_INCLUDED) {
reCopy = new RegExp('^' + source + '$(?!\\s)', flags);
}
if (UPDATES_LAST_INDEX_WRONG) lastIndex = re.lastIndex;
match = nativeExec.call(sticky ? reCopy : re, strCopy);
if (sticky) {
if (match) {
match.input = match.input.slice(charsAdded);
match[0] = match[0].slice(charsAdded);
match.index = re.lastIndex;
re.lastIndex += match[0].length;
} else re.lastIndex = 0;
} else if (UPDATES_LAST_INDEX_WRONG && match) {
re.lastIndex = re.global ? match.index + match[0].length : lastIndex;
}
if (NPCG_INCLUDED && match && match.length > 1) {
// Fix browsers whose `exec` methods don't consistently return `undefined`
// for NPCG, like IE8. NOTE: This doesn't work for /(.?)?/
nativeReplace.call(match[0], reCopy, function () {
for (i = 1; i < arguments.length - 2; i++) {
if (arguments[i] === undefined) match[i] = undefined;
}
});
}
return match;
};
}
var regexpExec = patchedExec;
'use strict'; // `RegExp.prototype.exec` method
// https://tc39.es/ecma262/#sec-regexp.prototype.exec
_export({
target: 'RegExp',
proto: true,
forced: /./.exec !== regexpExec
}, {
exec: regexpExec
});
var es_regexp_exec = {};
'use strict'; // TODO: Remove from `core-js@4` since it's moved to entry points
var SPECIES$4 = wellKnownSymbol('species');
var REPLACE_SUPPORTS_NAMED_GROUPS = !fails(function () {
// #replace needs built-in support for named groups.
// #match works fine because it just returns the exec results, even if it has
// a "groups" property.
var re = /./;
re.exec = function () {
var result = [];
result.groups = {
a: '7'
};
return result;
};
return ''.replace(re, '$<a>') !== '7';
}); // IE <= 11 replaces $0 with the whole match, as if it was $&
// https://stackoverflow.com/questions/6024666/getting-ie-to-replace-a-regex-with-the-literal-string-0
var REPLACE_KEEPS_$0 = function () {
return 'a'.replace(/./, '$0') === '$0';
}();
var REPLACE = wellKnownSymbol('replace'); // Safari <= 13.0.3(?) substitutes nth capture where n>m with an empty string
var REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE = function () {
if (/./[REPLACE]) {
return /./[REPLACE]('a', '$0') === '';
}
return false;
}(); // Chrome 51 has a buggy "split" implementation when RegExp#exec !== nativeExec
// Weex JS has frozen built-in prototypes, so use try / catch wrapper
var SPLIT_WORKS_WITH_OVERWRITTEN_EXEC = !fails(function () {
// eslint-disable-next-line regexp/no-empty-group -- required for testing
var re = /(?:)/;
var originalExec = re.exec;
re.exec = function () {
return originalExec.apply(this, arguments);
};
var result = 'ab'.split(re);
return result.length !== 2 || result[0] !== 'a' || result[1] !== 'b';
});
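// patches `String.prototype[KEY]` together with the matching `RegExp.prototype[@@symbol]` method when the native pair does not delegate correctly or exhibits the replace/split bugs detected above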
var fixRegexpWellKnownSymbolLogic = function fixRegexpWellKnownSymbolLogic(KEY, length, exec, sham) {
var SYMBOL = wellKnownSymbol(KEY);
var DELEGATES_TO_SYMBOL = !fails(function () {
// String methods call symbol-named RegExp methods
var O = {};
O[SYMBOL] = function () {
return 7;
};
return ''[KEY](O) != 7;
});
var DELEGATES_TO_EXEC = DELEGATES_TO_SYMBOL && !fails(function () {
// Symbol-named RegExp methods call .exec
var execCalled = false;
var re = /a/;
if (KEY === 'split') {
// We can't use real regex here since it causes deoptimization
// and serious performance degradation in V8
// https://github.com/zloirock/core-js/issues/306
re = {}; // RegExp[@@split] doesn't call the regex's exec method, but first creates
// a new one. We need to return the patched regex when creating the new one.
re.constructor = {};
re.constructor[SPECIES$4] = function () {
return re;
};
re.flags = '';
re[SYMBOL] = /./[SYMBOL];
}
re.exec = function () {
execCalled = true;
return null;
};
re[SYMBOL]('');
return !execCalled;
});
if (!DELEGATES_TO_SYMBOL || !DELEGATES_TO_EXEC || KEY === 'replace' && !(REPLACE_SUPPORTS_NAMED_GROUPS && REPLACE_KEEPS_$0 && !REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE) || KEY === 'split' && !SPLIT_WORKS_WITH_OVERWRITTEN_EXEC) {
var nativeRegExpMethod = /./[SYMBOL];
var methods = exec(SYMBOL, ''[KEY], function (nativeMethod, regexp, str, arg2, forceStringMethod) {
if (regexp.exec === regexpExec) {
if (DELEGATES_TO_SYMBOL && !forceStringMethod) {
// The native String method already delegates to @@method (this
// polyfilled function), leading to infinite recursion.
// We avoid it by directly calling the native @@method method.
return {
done: true,
value: nativeRegExpMethod.call(regexp, str, arg2)
};
}
return {
done: true,
value: nativeMethod.call(str, regexp, arg2)
};
}
return {
done: false
};
}, {
REPLACE_KEEPS_$0: REPLACE_KEEPS_$0,
REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE: REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE
});
var stringMethod = methods[0];
var regexMethod = methods[1];
redefine(String.prototype, KEY, stringMethod);
redefine(RegExp.prototype, SYMBOL, length == 2 // 21.2.5.8 RegExp.prototype[@@replace](string, replaceValue)
// 21.2.5.11 RegExp.prototype[@@split](string, limit)
? function (string, arg) {
return regexMethod.call(string, this, arg);
} // 21.2.5.6 RegExp.prototype[@@match](string)
// 21.2.5.9 RegExp.prototype[@@search](string)
: function (string) {
return regexMethod.call(string, this);
});
}
if (sham) createNonEnumerableProperty(RegExp.prototype[SYMBOL], 'sham', true);
};
'use strict';
var charAt = stringMultibyte.charAt; // `AdvanceStringIndex` abstract operation
// https://tc39.es/ecma262/#sec-advancestringindex
var advanceStringIndex = function advanceStringIndex(S, index, unicode) {
return index + (unicode ? charAt(S, index).length : 1);
};
// https://tc39.es/ecma262/#sec-regexpexec
var regexpExecAbstract = function regexpExecAbstract(R, S) {
var exec = R.exec;
if (typeof exec === 'function') {
var result = exec.call(R, S);
if (typeof result !== 'object') {
throw TypeError('RegExp exec method returned something other than an Object or null');
}
return result;
}
if (classofRaw(R) !== 'RegExp') {
throw TypeError('RegExp#exec called on incompatible receiver');
}
return regexpExec.call(R, S);
};
'use strict'; // @@match logic
fixRegexpWellKnownSymbolLogic('match', 1, function (MATCH, nativeMatch, maybeCallNative) {
return [// `String.prototype.match` method
// https://tc39.es/ecma262/#sec-string.prototype.match
function match(regexp) {
var O = requireObjectCoercible(this);
var matcher = regexp == undefined ? undefined : regexp[MATCH];
return matcher !== undefined ? matcher.call(regexp, O) : new RegExp(regexp)[MATCH](String(O));
}, // `RegExp.prototype[@@match]` method
// https://tc39.es/ecma262/#sec-regexp.prototype-@@match
function (regexp) {
var res = maybeCallNative(nativeMatch, regexp, this);
if (res.done) return res.value;
var rx = anObject(regexp);
var S = String(this);
if (!rx.global) return regexpExecAbstract(rx, S);
var fullUnicode = rx.unicode;
rx.lastIndex = 0;
var A = [];
var n = 0;
var result;
while ((result = regexpExecAbstract(rx, S)) !== null) {
var matchStr = String(result[0]);
A[n] = matchStr;
if (matchStr === '') rx.lastIndex = advanceStringIndex(S, toLength(rx.lastIndex), fullUnicode);
n++;
}
return n === 0 ? null : A;
}];
});
var es_string_match = {};
var SPECIES$5 = wellKnownSymbol('species'); // `SpeciesConstructor` abstract operation
// https://tc39.es/ecma262/#sec-speciesconstructor
var speciesConstructor = function speciesConstructor(O, defaultConstructor) {
var C = anObject(O).constructor;
var S;
return C === undefined || (S = anObject(C)[SPECIES$5]) == undefined ? defaultConstructor : aFunction$1(S);
};
'use strict';
var MATCH_ALL = wellKnownSymbol('matchAll');
var REGEXP_STRING = 'RegExp String';
var REGEXP_STRING_ITERATOR = REGEXP_STRING + ' Iterator';
var setInternalState$2 = internalState.set;
var getInternalState$2 = internalState.getterFor(REGEXP_STRING_ITERATOR);
var RegExpPrototype = RegExp.prototype;
var regExpBuiltinExec = RegExpPrototype.exec;
var nativeMatchAll = ''.matchAll;
var WORKS_WITH_NON_GLOBAL_REGEX = !!nativeMatchAll && !fails(function () {
'a'.matchAll(/./);
});
var regExpExec = function regExpExec(R, S) {
var exec = R.exec;
var result;
if (typeof exec == 'function') {
result = exec.call(R, S);
if (typeof result != 'object') throw TypeError('Incorrect exec result');
return result;
}
return regExpBuiltinExec.call(R, S);
}; // eslint-disable-next-line max-len -- ignore
var $RegExpStringIterator = createIteratorConstructor(function RegExpStringIterator(regexp, string, global, fullUnicode) {
setInternalState$2(this, {
type: REGEXP_STRING_ITERATOR,
regexp: regexp,
string: string,
global: global,
unicode: fullUnicode,
done: false
});
}, REGEXP_STRING, function next() {
var state = getInternalState$2(this);
if (state.done) return {
value: undefined,
done: true
};
var R = state.regexp;
var S = state.string;
var match = regExpExec(R, S);
if (match === null) return {
value: undefined,
done: state.done = true
};
if (state.global) {
if (String(match[0]) == '') R.lastIndex = advanceStringIndex(S, toLength(R.lastIndex), state.unicode);
return {
value: match,
done: false
};
}
state.done = true;
return {
value: match,
done: false
};
});
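// builds a species-constructed copy of the regexp (preserving flags and lastIndex) and returns a RegExp String Iterator over all of its matches in the string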
var $matchAll = function $matchAll(string) {
var R = anObject(this);
var S = String(string);
var C, flagsValue, flags, matcher, global, fullUnicode;
C = speciesConstructor(R, RegExp);
flagsValue = R.flags;
if (flagsValue === undefined && R instanceof RegExp && !('flags' in RegExpPrototype)) {
flagsValue = regexpFlags.call(R);
}
flags = flagsValue === undefined ? '' : String(flagsValue);
matcher = new C(C === RegExp ? R.source : R, flags);
global = !!~flags.indexOf('g');
fullUnicode = !!~flags.indexOf('u');
matcher.lastIndex = toLength(R.lastIndex);
return new $RegExpStringIterator(matcher, S, global, fullUnicode);
}; // `String.prototype.matchAll` method
// https://tc39.es/ecma262/#sec-string.prototype.matchall
_export({
target: 'String',
proto: true,
forced: WORKS_WITH_NON_GLOBAL_REGEX
}, {
matchAll: function matchAll(regexp) {
var O = requireObjectCoercible(this);
var flags, S, matcher, rx;
if (regexp != null) {
if (isRegexp(regexp)) {
flags = String(requireObjectCoercible('flags' in RegExpPrototype ? regexp.flags : regexpFlags.call(regexp)));
if (!~flags.indexOf('g')) throw TypeError('`.matchAll` does not allow non-global regexes');
}
if (WORKS_WITH_NON_GLOBAL_REGEX) return nativeMatchAll.apply(O, arguments);
matcher = regexp[MATCH_ALL];
if (matcher === undefined && isPure && classofRaw(regexp) == 'RegExp') matcher = $matchAll;
if (matcher != null) return aFunction$1(matcher).call(regexp, O);
} else if (WORKS_WITH_NON_GLOBAL_REGEX) return nativeMatchAll.apply(O, arguments);
S = String(O);
rx = new RegExp(regexp, 'g');
return isPure ? $matchAll.call(rx, S) : rx[MATCH_ALL](S);
}
});
isPure || MATCH_ALL in RegExpPrototype || createNonEnumerableProperty(RegExpPrototype, MATCH_ALL, $matchAll);
var es_string_matchAll = {};
'use strict'; // `String.prototype.repeat` method implementation
// https://tc39.es/ecma262/#sec-string.prototype.repeat
var stringRepeat = ''.repeat || function repeat(count) {
var str = String(requireObjectCoercible(this));
var result = '';
var n = toInteger(count);
if (n < 0 || n == Infinity) throw RangeError('Wrong number of repetitions');
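// binary decomposition of the count: repeatedly double `str` and append it whenever the current low bit of `n` is set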
for (; n > 0; (n >>>= 1) && (str += str)) {
if (n & 1) result += str;
}
return result;
};
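// The loop above is exponentiation-by-squaring over string concatenation: `n` is
// halved and `str` doubled on each pass, so e.g. 'ab'.repeat(5) is assembled from
// the 1-copy and 4-copy chunks ('ab' + 'abababab') in O(log n) concatenations.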
var ceil$1 = Math.ceil; // `String.prototype.{ padStart, padEnd }` methods implementation
var createMethod$5 = function createMethod(IS_END) {
return function ($this, maxLength, fillString) {
var S = String(requireObjectCoercible($this));
var stringLength = S.length;
var fillStr = fillString === undefined ? ' ' : String(fillString);
var intMaxLength = toLength(maxLength);
var fillLen, stringFiller;
if (intMaxLength <= stringLength || fillStr == '') return S;
fillLen = intMaxLength - stringLength;
stringFiller = stringRepeat.call(fillStr, ceil$1(fillLen / fillStr.length));
if (stringFiller.length > fillLen) stringFiller = stringFiller.slice(0, fillLen);
return IS_END ? S + stringFiller : stringFiller + S;
};
};
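// The padding builder above repeats `fillString` ceil(fillLen / fillString.length)
// times and truncates the filler to the exact missing length, so e.g.
// '7'.padStart(3, '0') === '007' and 'ab'.padEnd(5, '-*') === 'ab-*-'.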
var stringPad = {
// `String.prototype.padStart` method
// https://tc39.es/ecma262/#sec-string.prototype.padstart
start: createMethod$5(false),
// `String.prototype.padEnd` method
// https://tc39.es/ecma262/#sec-string.prototype.padend
end: createMethod$5(true)
};
var stringPad_1 = stringPad.start;
var stringPad_2 = stringPad.end;
// eslint-disable-next-line unicorn/no-unsafe-regex -- safe
var stringPadWebkitBug = /Version\/10\.\d+(\.\d+)?( Mobile\/\w+)? Safari\//.test(engineUserAgent);
'use strict';
var $padEnd = stringPad.end; // `String.prototype.padEnd` method
// https://tc39.es/ecma262/#sec-string.prototype.padend
_export({
target: 'String',
proto: true,
forced: stringPadWebkitBug
}, {
padEnd: function padEnd(maxLength
/* , fillString = ' ' */
) {
return $padEnd(this, maxLength, arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_string_padEnd = {};
'use strict';
var $padStart = stringPad.start; // `String.prototype.padStart` method
// https://tc39.es/ecma262/#sec-string.prototype.padstart
_export({
target: 'String',
proto: true,
forced: stringPadWebkitBug
}, {
padStart: function padStart(maxLength
/* , fillString = ' ' */
) {
return $padStart(this, maxLength, arguments.length > 1 ? arguments[1] : undefined);
}
});
var es_string_padStart = {};
// https://tc39.es/ecma262/#sec-string.prototype.repeat
_export({
target: 'String',
proto: true
}, {
repeat: stringRepeat
});
var es_string_repeat = {};
var floor$1 = Math.floor;
var replace = ''.replace;
var SUBSTITUTION_SYMBOLS = /\$([$&'`]|\d{1,2}|<[^>]*>)/g;
var SUBSTITUTION_SYMBOLS_NO_NAMED = /\$([$&'`]|\d{1,2})/g; // https://tc39.es/ecma262/#sec-getsubstitution
var getSubstitution = function getSubstitution(matched, str, position, captures, namedCaptures, replacement) {
var tailPos = position + matched.length;
var m = captures.length;
var symbols = SUBSTITUTION_SYMBOLS_NO_NAMED;
if (namedCaptures !== undefined) {
namedCaptures = toObject(namedCaptures);
symbols = SUBSTITUTION_SYMBOLS;
}
return replace.call(replacement, symbols, function (match, ch) {
var capture;
switch (ch.charAt(0)) {
case '$':
return '$';
case '&':
return matched;
case '`':
return str.slice(0, position);
case "'":
return str.slice(tailPos);
case '<':
capture = namedCaptures[ch.slice(1, -1)];
break;
default:
// \d\d?
var n = +ch;
if (n === 0) return match;
if (n > m) {
var f = floor$1(n / 10);
if (f === 0) return match;
if (f <= m) return captures[f - 1] === undefined ? ch.charAt(1) : captures[f - 1] + ch.charAt(1);
return match;
}
capture = captures[n - 1];
}
return capture === undefined ? '' : capture;
});
};
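// GetSubstitution expands the replacement template: $$ -> '$', $& -> the matched
// substring, $` / $' -> the text before / after it, $n -> the n-th capture and
// $<name> -> a named capture (only when `namedCaptures` is provided). For example,
// with matched 'bc' at position 1 in 'abcd' and captures ['b'], the template
// "[$&|$1|$`|$']" yields '[bc|b|a|d]'.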
'use strict';
var max$3 = Math.max;
var min$6 = Math.min;
var maybeToString = function maybeToString(it) {
return it === undefined ? it : String(it);
}; // @@replace logic
fixRegexpWellKnownSymbolLogic('replace', 2, function (REPLACE, nativeReplace, maybeCallNative, reason) {
var REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE = reason.REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE;
var REPLACE_KEEPS_$0 = reason.REPLACE_KEEPS_$0;
var UNSAFE_SUBSTITUTE = REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE ? '$' : '$0';
return [// `String.prototype.replace` method
// https://tc39.es/ecma262/#sec-string.prototype.replace
function replace(searchValue, replaceValue) {
var O = requireObjectCoercible(this);
var replacer = searchValue == undefined ? undefined : searchValue[REPLACE];
return replacer !== undefined ? replacer.call(searchValue, O, replaceValue) : nativeReplace.call(String(O), searchValue, replaceValue);
}, // `RegExp.prototype[@@replace]` method
// https://tc39.es/ecma262/#sec-regexp.prototype-@@replace
function (regexp, replaceValue) {
if (!REGEXP_REPLACE_SUBSTITUTES_UNDEFINED_CAPTURE && REPLACE_KEEPS_$0 || typeof replaceValue === 'string' && replaceValue.indexOf(UNSAFE_SUBSTITUTE) === -1) {
var res = maybeCallNative(nativeReplace, regexp, this, replaceValue);
if (res.done) return res.value;
}
var rx = anObject(regexp);
var S = String(this);
var functionalReplace = typeof replaceValue === 'function';
if (!functionalReplace) replaceValue = String(replaceValue);
var global = rx.global;
if (global) {
var fullUnicode = rx.unicode;
rx.lastIndex = 0;
}
var results = [];
while (true) {
var result = regexpExecAbstract(rx, S);
if (result === null) break;
results.push(result);
if (!global) break;
var matchStr = String(result[0]);
if (matchStr === '') rx.lastIndex = advanceStringIndex(S, toLength(rx.lastIndex), fullUnicode);
}
var accumulatedResult = '';
var nextSourcePosition = 0;
for (var i = 0; i < results.length; i++) {
result = results[i];
var matched = String(result[0]);
var position = max$3(min$6(toInteger(result.index), S.length), 0);
var captures = []; // NOTE: This is equivalent to
// captures = result.slice(1).map(maybeToString)
// but for some reason `nativeSlice.call(result, 1, result.length)` (called in
// the slice polyfill when slicing native arrays) "doesn't work" in safari 9 and
// causes a crash (https://pastebin.com/N21QzeQA) when trying to debug it.
for (var j = 1; j < result.length; j++) {
captures.push(maybeToString(result[j]));
}
var namedCaptures = result.groups;
if (functionalReplace) {
var replacerArgs = [matched].concat(captures, position, S);
if (namedCaptures !== undefined) replacerArgs.push(namedCaptures);
var replacement = String(replaceValue.apply(undefined, replacerArgs));
} else {
replacement = getSubstitution(matched, S, position, captures, namedCaptures, replaceValue);
}
if (position >= nextSourcePosition) {
accumulatedResult += S.slice(nextSourcePosition, position) + replacement;
nextSourcePosition = position + matched.length;
}
}
return accumulatedResult + S.slice(nextSourcePosition);
}];
});
var es_string_replace = {};
'use strict'; // @@search logic
fixRegexpWellKnownSymbolLogic('search', 1, function (SEARCH, nativeSearch, maybeCallNative) {
return [// `String.prototype.search` method
// https://tc39.es/ecma262/#sec-string.prototype.search
function search(regexp) {
var O = requireObjectCoercible(this);
var searcher = regexp == undefined ? undefined : regexp[SEARCH];
return searcher !== undefined ? searcher.call(regexp, O) : new RegExp(regexp)[SEARCH](String(O));
}, // `RegExp.prototype[@@search]` method
// https://tc39.es/ecma262/#sec-regexp.prototype-@@search
function (regexp) {
var res = maybeCallNative(nativeSearch, regexp, this);
if (res.done) return res.value;
var rx = anObject(regexp);
var S = String(this);
var previousLastIndex = rx.lastIndex;
if (!sameValue(previousLastIndex, 0)) rx.lastIndex = 0;
var result = regexpExecAbstract(rx, S);
if (!sameValue(rx.lastIndex, previousLastIndex)) rx.lastIndex = previousLastIndex;
return result === null ? -1 : result.index;
}];
});
var es_string_search = {};
'use strict';
var arrayPush = [].push;
var min$7 = Math.min;
var MAX_UINT32 = 0xFFFFFFFF; // babel-minify transpiles RegExp('x', 'y') -> /x/y and it causes SyntaxError
var SUPPORTS_Y = !fails(function () {
return !RegExp(MAX_UINT32, 'y');
}); // @@split logic
fixRegexpWellKnownSymbolLogic('split', 2, function (SPLIT, nativeSplit, maybeCallNative) {
var internalSplit;
if ('abbc'.split(/(b)*/)[1] == 'c' || // eslint-disable-next-line regexp/no-empty-group -- required for testing
'test'.split(/(?:)/, -1).length != 4 || 'ab'.split(/(?:ab)*/).length != 2 || '.'.split(/(.?)(.?)/).length != 4 || // eslint-disable-next-line regexp/no-assertion-capturing-group, regexp/no-empty-group -- required for testing
'.'.split(/()()/).length > 1 || ''.split(/.?/).length) {
      // based on the es5-shim implementation; needs rework
internalSplit = function internalSplit(separator, limit) {
var string = String(requireObjectCoercible(this));
var lim = limit === undefined ? MAX_UINT32 : limit >>> 0;
if (lim === 0) return [];
if (separator === undefined) return [string]; // If `separator` is not a regex, use native split
if (!isRegexp(separator)) {
return nativeSplit.call(string, separator, lim);
}
var output = [];
var flags = (separator.ignoreCase ? 'i' : '') + (separator.multiline ? 'm' : '') + (separator.unicode ? 'u' : '') + (separator.sticky ? 'y' : '');
var lastLastIndex = 0; // Make `global` and avoid `lastIndex` issues by working with a copy
var separatorCopy = new RegExp(separator.source, flags + 'g');
var match, lastIndex, lastLength;
while (match = regexpExec.call(separatorCopy, string)) {
lastIndex = separatorCopy.lastIndex;
if (lastIndex > lastLastIndex) {
output.push(string.slice(lastLastIndex, match.index));
if (match.length > 1 && match.index < string.length) arrayPush.apply(output, match.slice(1));
lastLength = match[0].length;
lastLastIndex = lastIndex;
if (output.length >= lim) break;
}
if (separatorCopy.lastIndex === match.index) separatorCopy.lastIndex++; // Avoid an infinite loop
}
if (lastLastIndex === string.length) {
if (lastLength || !separatorCopy.test('')) output.push('');
} else output.push(string.slice(lastLastIndex));
return output.length > lim ? output.slice(0, lim) : output;
}; // Chakra, V8
} else if ('0'.split(undefined, 0).length) {
internalSplit = function internalSplit(separator, limit) {
return separator === undefined && limit === 0 ? [] : nativeSplit.call(this, separator, limit);
};
} else internalSplit = nativeSplit;
return [// `String.prototype.split` method
// https://tc39.es/ecma262/#sec-string.prototype.split
function split(separator, limit) {
var O = requireObjectCoercible(this);
var splitter = separator == undefined ? undefined : separator[SPLIT];
return splitter !== undefined ? splitter.call(separator, O, limit) : internalSplit.call(String(O), separator, limit);
}, // `RegExp.prototype[@@split]` method
// https://tc39.es/ecma262/#sec-regexp.prototype-@@split
//
// NOTE: This cannot be properly polyfilled in engines that don't support
// the 'y' flag.
function (regexp, limit) {
var res = maybeCallNative(internalSplit, regexp, this, limit, internalSplit !== nativeSplit);
if (res.done) return res.value;
var rx = anObject(regexp);
var S = String(this);
var C = speciesConstructor(rx, RegExp);
var unicodeMatching = rx.unicode;
var flags = (rx.ignoreCase ? 'i' : '') + (rx.multiline ? 'm' : '') + (rx.unicode ? 'u' : '') + (SUPPORTS_Y ? 'y' : 'g'); // ^(? + rx + ) is needed, in combination with some S slicing, to
// simulate the 'y' flag.
var splitter = new C(SUPPORTS_Y ? rx : '^(?:' + rx.source + ')', flags);
var lim = limit === undefined ? MAX_UINT32 : limit >>> 0;
if (lim === 0) return [];
if (S.length === 0) return regexpExecAbstract(splitter, S) === null ? [S] : [];
var p = 0;
var q = 0;
var A = [];
while (q < S.length) {
splitter.lastIndex = SUPPORTS_Y ? q : 0;
var z = regexpExecAbstract(splitter, SUPPORTS_Y ? S : S.slice(q));
var e;
if (z === null || (e = min$7(toLength(splitter.lastIndex + (SUPPORTS_Y ? 0 : q)), S.length)) === p) {
q = advanceStringIndex(S, q, unicodeMatching);
} else {
A.push(S.slice(p, q));
if (A.length === lim) return A;
for (var i = 1; i <= z.length - 1; i++) {
A.push(z[i]);
if (A.length === lim) return A;
}
q = p = e;
}
}
A.push(S.slice(p));
return A;
}];
}, !SUPPORTS_Y);
var es_string_split = {};
'use strict';
var getOwnPropertyDescriptor$5 = objectGetOwnPropertyDescriptor.f;
var nativeStartsWith = ''.startsWith;
var min$8 = Math.min;
var CORRECT_IS_REGEXP_LOGIC$1 = correctIsRegexpLogic('startsWith'); // https://github.com/zloirock/core-js/pull/702
var MDN_POLYFILL_BUG$1 = !isPure && !CORRECT_IS_REGEXP_LOGIC$1 && !!function () {
var descriptor = getOwnPropertyDescriptor$5(String.prototype, 'startsWith');
return descriptor && !descriptor.writable;
}(); // `String.prototype.startsWith` method
// https://tc39.es/ecma262/#sec-string.prototype.startswith
_export({
target: 'String',
proto: true,
forced: !MDN_POLYFILL_BUG$1 && !CORRECT_IS_REGEXP_LOGIC$1
}, {
startsWith: function startsWith(searchString
/* , position = 0 */
) {
var that = String(requireObjectCoercible(this));
notARegexp(searchString);
var index = toLength(min$8(arguments.length > 1 ? arguments[1] : undefined, that.length));
var search = String(searchString);
return nativeStartsWith ? nativeStartsWith.call(that, search, index) : that.slice(index, index + search.length) === search;
}
});
var es_string_startsWith = {};
// a string of all valid unicode whitespaces
var whitespaces = "\t\n\x0B\f\r \xA0\u1680\u2000\u2001\u2002" + "\u2003\u2004\u2005\u2006\u2007\u2008\u2009\u200A\u202F\u205F\u3000\u2028\u2029\uFEFF";
var whitespace = '[' + whitespaces + ']';
var ltrim = RegExp('^' + whitespace + whitespace + '*');
var rtrim = RegExp(whitespace + whitespace + '*$'); // `String.prototype.{ trim, trimStart, trimEnd, trimLeft, trimRight }` methods implementation
var createMethod$6 = function createMethod(TYPE) {
return function ($this) {
var string = String(requireObjectCoercible($this));
if (TYPE & 1) string = string.replace(ltrim, '');
if (TYPE & 2) string = string.replace(rtrim, '');
return string;
};
};
var stringTrim = {
// `String.prototype.{ trimLeft, trimStart }` methods
// https://tc39.es/ecma262/#sec-string.prototype.trimstart
start: createMethod$6(1),
// `String.prototype.{ trimRight, trimEnd }` methods
// https://tc39.es/ecma262/#sec-string.prototype.trimend
end: createMethod$6(2),
// `String.prototype.trim` method
// https://tc39.es/ecma262/#sec-string.prototype.trim
trim: createMethod$6(3)
};
var stringTrim_1 = stringTrim.start;
var stringTrim_2 = stringTrim.end;
var stringTrim_3 = stringTrim.trim;
var non = "\u200B\x85\u180E"; // check that a method works with the correct list
// of whitespaces and has a correct name
var stringTrimForced = function stringTrimForced(METHOD_NAME) {
return fails(function () {
return !!whitespaces[METHOD_NAME]() || non[METHOD_NAME]() != non || whitespaces[METHOD_NAME].name !== METHOD_NAME;
});
};
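// `non` holds U+200B, U+0085 and U+180E, code points that are not WhiteSpace or
// LineTerminator in the current spec, so a compliant method must leave them in
// place; the polyfill is forced when the native method trims them, fails to strip
// a real whitespace character, or reports the wrong `name`.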
'use strict';
var $trim = stringTrim.trim; // `String.prototype.trim` method
// https://tc39.es/ecma262/#sec-string.prototype.trim
_export({
target: 'String',
proto: true,
forced: stringTrimForced('trim')
}, {
trim: function trim() {
return $trim(this);
}
});
var es_string_trim = {};
'use strict';
var $trimStart = stringTrim.start;
var FORCED$4 = stringTrimForced('trimStart');
var trimStart = FORCED$4 ? function trimStart() {
return $trimStart(this);
} : ''.trimStart; // `String.prototype.{ trimStart, trimLeft }` methods
// https://tc39.es/ecma262/#sec-string.prototype.trimstart
// https://tc39.es/ecma262/#String.prototype.trimleft
_export({
target: 'String',
proto: true,
forced: FORCED$4
}, {
trimStart: trimStart,
trimLeft: trimStart
});
var es_string_trimStart = {};
'use strict';
var $trimEnd = stringTrim.end;
var FORCED$5 = stringTrimForced('trimEnd');
var trimEnd = FORCED$5 ? function trimEnd() {
return $trimEnd(this);
} : ''.trimEnd; // `String.prototype.{ trimEnd, trimRight }` methods
// https://tc39.es/ecma262/#sec-string.prototype.trimend
// https://tc39.es/ecma262/#String.prototype.trimright
_export({
target: 'String',
proto: true,
forced: FORCED$5
}, {
trimEnd: trimEnd,
trimRight: trimEnd
});
var es_string_trimEnd = {};
'use strict';
var charAt$1 = stringMultibyte.charAt;
var STRING_ITERATOR = 'String Iterator';
var setInternalState$3 = internalState.set;
var getInternalState$3 = internalState.getterFor(STRING_ITERATOR); // `String.prototype[@@iterator]` method
// https://tc39.es/ecma262/#sec-string.prototype-@@iterator
defineIterator(String, 'String', function (iterated) {
setInternalState$3(this, {
type: STRING_ITERATOR,
string: String(iterated),
index: 0
}); // `%StringIteratorPrototype%.next` method
// https://tc39.es/ecma262/#sec-%stringiteratorprototype%.next
}, function next() {
var state = getInternalState$3(this);
var string = state.string;
var index = state.index;
var point;
if (index >= string.length) return {
value: undefined,
done: true
};
point = charAt$1(string, index);
state.index += point.length;
return {
value: point,
done: false
};
});
var es_string_iterator = {};
var quot = /"/g; // B.2.3.2.1 CreateHTML(string, tag, attribute, value)
// https://tc39.es/ecma262/#sec-createhtml
var createHtml = function createHtml(string, tag, attribute, value) {
var S = String(requireObjectCoercible(string));
var p1 = '<' + tag;
if (attribute !== '') p1 += ' ' + attribute + '="' + String(value).replace(quot, '&quot;') + '"';
return p1 + '>' + S + '</' + tag + '>';
};
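// Example: createHtml('hi', 'a', 'href', 'x"y') returns '<a href="x&quot;y">hi</a>';
// per Annex B, only double quotes in the attribute value are escaped.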
// detect broken native String HTML methods: forced when the produced markup
// contains uppercase characters or a double quote in the attribute value is left
// unescaped (or the method is absent)
var stringHtmlForced = function stringHtmlForced(METHOD_NAME) {
return fails(function () {
var test = ''[METHOD_NAME]('"');
return test !== test.toLowerCase() || test.split('"').length > 3;
});
};
'use strict'; // `String.prototype.anchor` method
// https://tc39.es/ecma262/#sec-string.prototype.anchor
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('anchor')
}, {
anchor: function anchor(name) {
return createHtml(this, 'a', 'name', name);
}
});
var es_string_anchor = {};
'use strict'; // `String.prototype.big` method
// https://tc39.es/ecma262/#sec-string.prototype.big
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('big')
}, {
big: function big() {
return createHtml(this, 'big', '', '');
}
});
var es_string_big = {};
'use strict'; // `String.prototype.blink` method
// https://tc39.es/ecma262/#sec-string.prototype.blink
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('blink')
}, {
blink: function blink() {
return createHtml(this, 'blink', '', '');
}
});
var es_string_blink = {};
'use strict'; // `String.prototype.bold` method
// https://tc39.es/ecma262/#sec-string.prototype.bold
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('bold')
}, {
bold: function bold() {
return createHtml(this, 'b', '', '');
}
});
var es_string_bold = {};
'use strict'; // `String.prototype.fixed` method
// https://tc39.es/ecma262/#sec-string.prototype.fixed
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('fixed')
}, {
fixed: function fixed() {
return createHtml(this, 'tt', '', '');
}
});
var es_string_fixed = {};
'use strict'; // `String.prototype.fontcolor` method
// https://tc39.es/ecma262/#sec-string.prototype.fontcolor
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('fontcolor')
}, {
fontcolor: function fontcolor(color) {
return createHtml(this, 'font', 'color', color);
}
});
var es_string_fontcolor = {};
'use strict'; // `String.prototype.fontsize` method
// https://tc39.es/ecma262/#sec-string.prototype.fontsize
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('fontsize')
}, {
fontsize: function fontsize(size) {
return createHtml(this, 'font', 'size', size);
}
});
var es_string_fontsize = {};
'use strict'; // `String.prototype.italics` method
// https://tc39.es/ecma262/#sec-string.prototype.italics
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('italics')
}, {
italics: function italics() {
return createHtml(this, 'i', '', '');
}
});
var es_string_italics = {};
'use strict'; // `String.prototype.link` method
// https://tc39.es/ecma262/#sec-string.prototype.link
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('link')
}, {
link: function link(url) {
return createHtml(this, 'a', 'href', url);
}
});
var es_string_link = {};
'use strict'; // `String.prototype.small` method
// https://tc39.es/ecma262/#sec-string.prototype.small
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('small')
}, {
small: function small() {
return createHtml(this, 'small', '', '');
}
});
var es_string_small = {};
'use strict'; // `String.prototype.strike` method
// https://tc39.es/ecma262/#sec-string.prototype.strike
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('strike')
}, {
strike: function strike() {
return createHtml(this, 'strike', '', '');
}
});
var es_string_strike = {};
'use strict'; // `String.prototype.sub` method
// https://tc39.es/ecma262/#sec-string.prototype.sub
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('sub')
}, {
sub: function sub() {
return createHtml(this, 'sub', '', '');
}
});
var es_string_sub = {};
'use strict'; // `String.prototype.sup` method
// https://tc39.es/ecma262/#sec-string.prototype.sup
_export({
target: 'String',
proto: true,
forced: stringHtmlForced('sup')
}, {
sup: function sup() {
return createHtml(this, 'sup', '', '');
}
});
var es_string_sup = {};
'use strict';
var REPLACE$1 = wellKnownSymbol('replace');
var RegExpPrototype$1 = RegExp.prototype;
var max$4 = Math.max;
var stringIndexOf = function stringIndexOf(string, searchValue, fromIndex) {
if (fromIndex > string.length) return -1;
if (searchValue === '') return fromIndex;
return string.indexOf(searchValue, fromIndex);
}; // `String.prototype.replaceAll` method
// https://tc39.es/ecma262/#sec-string.prototype.replaceall
_export({
target: 'String',
proto: true
}, {
replaceAll: function replaceAll(searchValue, replaceValue) {
var O = requireObjectCoercible(this);
var IS_REG_EXP, flags, replacer, string, searchString, functionalReplace, searchLength, advanceBy, replacement;
var position = 0;
var endOfLastMatch = 0;
var result = '';
if (searchValue != null) {
IS_REG_EXP = isRegexp(searchValue);
if (IS_REG_EXP) {
flags = String(requireObjectCoercible('flags' in RegExpPrototype$1 ? searchValue.flags : regexpFlags.call(searchValue)));
if (!~flags.indexOf('g')) throw TypeError('`.replaceAll` does not allow non-global regexes');
}
replacer = searchValue[REPLACE$1];
if (replacer !== undefined) {
return replacer.call(searchValue, O, replaceValue);
} else if (isPure && IS_REG_EXP) {
return String(O).replace(searchValue, replaceValue);
}
}
string = String(O);
searchString = String(searchValue);
functionalReplace = typeof replaceValue === 'function';
if (!functionalReplace) replaceValue = String(replaceValue);
searchLength = searchString.length;
advanceBy = max$4(1, searchLength);
position = stringIndexOf(string, searchString, 0);
while (position !== -1) {
if (functionalReplace) {
replacement = String(replaceValue(searchString, position, string));
} else {
replacement = getSubstitution(searchString, string, position, [], undefined, replaceValue);
}
result += string.slice(endOfLastMatch, position) + replacement;
endOfLastMatch = position + searchLength;
position = stringIndexOf(string, searchString, position + advanceBy);
}
if (endOfLastMatch < string.length) {
result += string.slice(endOfLastMatch);
}
return result;
}
});
var es_string_replaceAll = {};
var inheritIfRequired = function inheritIfRequired($this, dummy, Wrapper) {
var NewTarget, NewTargetPrototype;
if ( // it can work only with native `setPrototypeOf`
objectSetPrototypeOf && // we haven't completely correct pre-ES6 way for getting `new.target`, so use this
typeof (NewTarget = dummy.constructor) == 'function' && NewTarget !== Wrapper && isObject(NewTargetPrototype = NewTarget.prototype) && NewTargetPrototype !== Wrapper.prototype) objectSetPrototypeOf($this, NewTargetPrototype);
return $this;
};
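// inheritIfRequired restores subclass semantics when a wrapped native constructor
// is invoked through `new Subclass(...)`: if the (approximated) new.target differs
// from the wrapper and has its own prototype object, the freshly created instance
// is re-linked to that prototype via setPrototypeOf.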
var defineProperty$5 = objectDefineProperty.f;
var getOwnPropertyNames = objectGetOwnPropertyNames.f;
var setInternalState$4 = internalState.set;
var MATCH$2 = wellKnownSymbol('match');
var NativeRegExp = global_1.RegExp;
var RegExpPrototype$2 = NativeRegExp.prototype;
var re1 = /a/g;
var re2 = /a/g; // "new" should create a new object, old webkit bug
var CORRECT_NEW = new NativeRegExp(re1) !== re1;
var UNSUPPORTED_Y$2 = regexpStickyHelpers.UNSUPPORTED_Y;
var FORCED$6 = descriptors && isForced_1('RegExp', !CORRECT_NEW || UNSUPPORTED_Y$2 || fails(function () {
  re2[MATCH$2] = false; // RegExp constructor can alter flags and IsRegExp works correctly with @@match
return NativeRegExp(re1) != re1 || NativeRegExp(re2) == re2 || NativeRegExp(re1, 'i') != '/a/i';
})); // `RegExp` constructor
// https://tc39.es/ecma262/#sec-regexp-constructor
if (FORCED$6) {
var RegExpWrapper = function RegExp(pattern, flags) {
var thisIsRegExp = this instanceof RegExpWrapper;
var patternIsRegExp = isRegexp(pattern);
var flagsAreUndefined = flags === undefined;
var sticky;
if (!thisIsRegExp && patternIsRegExp && pattern.constructor === RegExpWrapper && flagsAreUndefined) {
return pattern;
}
if (CORRECT_NEW) {
if (patternIsRegExp && !flagsAreUndefined) pattern = pattern.source;
} else if (pattern instanceof RegExpWrapper) {
if (flagsAreUndefined) flags = regexpFlags.call(pattern);
pattern = pattern.source;
}
if (UNSUPPORTED_Y$2) {
sticky = !!flags && flags.indexOf('y') > -1;
if (sticky) flags = flags.replace(/y/g, '');
}
var result = inheritIfRequired(CORRECT_NEW ? new NativeRegExp(pattern, flags) : NativeRegExp(pattern, flags), thisIsRegExp ? this : RegExpPrototype$2, RegExpWrapper);
if (UNSUPPORTED_Y$2 && sticky) setInternalState$4(result, {
sticky: sticky
});
return result;
};
var proxy = function proxy(key) {
key in RegExpWrapper || defineProperty$5(RegExpWrapper, key, {
configurable: true,
get: function get() {
return NativeRegExp[key];
},
set: function set(it) {
NativeRegExp[key] = it;
}
});
};
var keys$1 = getOwnPropertyNames(NativeRegExp);
var index = 0;
while (keys$1.length > index) {
proxy(keys$1[index++]);
}
RegExpPrototype$2.constructor = RegExpWrapper;
RegExpWrapper.prototype = RegExpPrototype$2;
redefine(global_1, 'RegExp', RegExpWrapper);
} // https://tc39.es/ecma262/#sec-get-regexp-@@species
setSpecies('RegExp');
var es_regexp_constructor = {};
var UNSUPPORTED_Y$3 = regexpStickyHelpers.UNSUPPORTED_Y; // `RegExp.prototype.flags` getter
// https://tc39.es/ecma262/#sec-get-regexp.prototype.flags
if (descriptors && (/./g.flags != 'g' || UNSUPPORTED_Y$3)) {
objectDefineProperty.f(RegExp.prototype, 'flags', {
configurable: true,
get: regexpFlags
});
}
var es_regexp_flags = {};
var UNSUPPORTED_Y$4 = regexpStickyHelpers.UNSUPPORTED_Y;
var defineProperty$6 = objectDefineProperty.f;
var getInternalState$4 = internalState.get;
var RegExpPrototype$3 = RegExp.prototype; // `RegExp.prototype.sticky` getter
// https://tc39.es/ecma262/#sec-get-regexp.prototype.sticky
if (descriptors && UNSUPPORTED_Y$4) {
defineProperty$6(RegExp.prototype, 'sticky', {
configurable: true,
get: function get() {
if (this === RegExpPrototype$3) return undefined; // We can't use InternalStateModule.getterFor because
// we don't add metadata for regexps created by a literal.
if (this instanceof RegExp) {
return !!getInternalState$4(this).sticky;
}
throw TypeError('Incompatible receiver, RegExp required');
}
});
}
var es_regexp_sticky = {};
'use strict'; // TODO: Remove from `core-js@4` since it's moved to entry points
var DELEGATES_TO_EXEC = function () {
var execCalled = false;
var re = /[ac]/;
re.exec = function () {
execCalled = true;
return /./.exec.apply(this, arguments);
};
return re.test('abc') === true && execCalled;
}();
var nativeTest = /./.test; // `RegExp.prototype.test` method
// https://tc39.es/ecma262/#sec-regexp.prototype.test
_export({
target: 'RegExp',
proto: true,
forced: !DELEGATES_TO_EXEC
}, {
test: function test(str) {
if (typeof this.exec !== 'function') {
return nativeTest.call(this, str);
}
var result = this.exec(str);
if (result !== null && !isObject(result)) {
throw new Error('RegExp exec method returned something other than an Object or null');
}
return !!result;
}
});
var es_regexp_test = {};
'use strict';
var TO_STRING = 'toString';
var RegExpPrototype$4 = RegExp.prototype;
var nativeToString = RegExpPrototype$4[TO_STRING];
var NOT_GENERIC = fails(function () {
return nativeToString.call({
source: 'a',
flags: 'b'
}) != '/a/b';
}); // FF44- RegExp#toString has a wrong name
var INCORRECT_NAME = nativeToString.name != TO_STRING; // `RegExp.prototype.toString` method
// https://tc39.es/ecma262/#sec-regexp.prototype.tostring
if (NOT_GENERIC || INCORRECT_NAME) {
redefine(RegExp.prototype, TO_STRING, function toString() {
var R = anObject(this);
var p = String(R.source);
var rf = R.flags;
var f = String(rf === undefined && R instanceof RegExp && !('flags' in RegExpPrototype$4) ? regexpFlags.call(R) : rf);
return '/' + p + '/' + f;
}, {
unsafe: true
});
}
var es_regexp_toString = {};
var trim = stringTrim.trim;
var $parseInt = global_1.parseInt;
var hex = /^[+-]?0[Xx]/;
var FORCED$7 = $parseInt(whitespaces + '08') !== 8 || $parseInt(whitespaces + '0x16') !== 22; // `parseInt` method
// https://tc39.es/ecma262/#sec-parseint-string-radix
var numberParseInt = FORCED$7 ? function parseInt(string, radix) {
var S = trim(String(string));
return $parseInt(S, radix >>> 0 || (hex.test(S) ? 16 : 10));
} : $parseInt;
// https://tc39.es/ecma262/#sec-parseint-string-radix
_export({
global: true,
forced: parseInt != numberParseInt
}, {
parseInt: numberParseInt
});
var es_parseInt = {};
var trim$1 = stringTrim.trim;
var $parseFloat = global_1.parseFloat;
var FORCED$8 = 1 / $parseFloat(whitespaces + '-0') !== -Infinity; // `parseFloat` method
// https://tc39.es/ecma262/#sec-parsefloat-string
var numberParseFloat = FORCED$8 ? function parseFloat(string) {
var trimmedString = trim$1(String(string));
var result = $parseFloat(trimmedString);
return result === 0 && trimmedString.charAt(0) == '-' ? -0 : result;
} : $parseFloat;
// https://tc39.es/ecma262/#sec-parsefloat-string
_export({
global: true,
forced: parseFloat != numberParseFloat
}, {
parseFloat: numberParseFloat
});
var es_parseFloat = {};
'use strict';
var getOwnPropertyNames$1 = objectGetOwnPropertyNames.f;
var getOwnPropertyDescriptor$6 = objectGetOwnPropertyDescriptor.f;
var defineProperty$7 = objectDefineProperty.f;
var trim$2 = stringTrim.trim;
var NUMBER = 'Number';
var NativeNumber = global_1[NUMBER];
var NumberPrototype = NativeNumber.prototype; // Opera ~12 has broken Object#toString
var BROKEN_CLASSOF = classofRaw(objectCreate(NumberPrototype)) == NUMBER; // `ToNumber` abstract operation
// https://tc39.es/ecma262/#sec-tonumber
var toNumber = function toNumber(argument) {
var it = toPrimitive(argument, false);
var first, third, radix, maxCode, digits, length, index, code;
if (typeof it == 'string' && it.length > 2) {
it = trim$2(it);
first = it.charCodeAt(0);
if (first === 43 || first === 45) {
third = it.charCodeAt(2);
if (third === 88 || third === 120) return NaN; // Number('+0x1') should be NaN, old V8 fix
} else if (first === 48) {
switch (it.charCodeAt(1)) {
case 66:
case 98:
radix = 2;
maxCode = 49;
break;
// fast equal of /^0b[01]+$/i
case 79:
case 111:
radix = 8;
maxCode = 55;
break;
// fast equal of /^0o[0-7]+$/i
default:
return +it;
}
digits = it.slice(2);
length = digits.length;
for (index = 0; index < length; index++) {
        code = digits.charCodeAt(index); // parseInt stops at the first invalid digit,
        // but ToNumber should return NaN if the string contains any invalid digit
if (code < 48 || code > maxCode) return NaN;
}
return parseInt(digits, radix);
}
}
return +it;
}; // `Number` constructor
// https://tc39.es/ecma262/#sec-number-constructor
if (isForced_1(NUMBER, !NativeNumber(' 0o1') || !NativeNumber('0b1') || NativeNumber('+0x1'))) {
var NumberWrapper = function Number(value) {
var it = arguments.length < 1 ? 0 : value;
var dummy = this;
return dummy instanceof NumberWrapper // check on 1..constructor(foo) case
&& (BROKEN_CLASSOF ? fails(function () {
NumberPrototype.valueOf.call(dummy);
}) : classofRaw(dummy) != NUMBER) ? inheritIfRequired(new NativeNumber(toNumber(it)), dummy, NumberWrapper) : toNumber(it);
};
for (var keys$2 = descriptors ? getOwnPropertyNames$1(NativeNumber) : ( // ES3:
'MAX_VALUE,MIN_VALUE,NaN,NEGATIVE_INFINITY,POSITIVE_INFINITY,' + // ES2015 (in case, if modules with ES2015 Number statics required before):
'EPSILON,isFinite,isInteger,isNaN,isSafeInteger,MAX_SAFE_INTEGER,' + 'MIN_SAFE_INTEGER,parseFloat,parseInt,isInteger,' + // ESNext
'fromString,range').split(','), j = 0, key; keys$2.length > j; j++) {
if (has(NativeNumber, key = keys$2[j]) && !has(NumberWrapper, key)) {
defineProperty$7(NumberWrapper, key, getOwnPropertyDescriptor$6(NativeNumber, key));
}
}
NumberWrapper.prototype = NumberPrototype;
NumberPrototype.constructor = NumberWrapper;
redefine(global_1, NUMBER, NumberWrapper);
}
var es_number_constructor = {};
// https://tc39.es/ecma262/#sec-number.epsilon
_export({
target: 'Number',
stat: true
}, {
EPSILON: Math.pow(2, -52)
});
var es_number_epsilon = {};
var globalIsFinite = global_1.isFinite; // `Number.isFinite` method
// https://tc39.es/ecma262/#sec-number.isfinite
var numberIsFinite = Number.isFinite || function isFinite(it) {
return typeof it == 'number' && globalIsFinite(it);
};
// https://tc39.es/ecma262/#sec-number.isfinite
_export({
target: 'Number',
stat: true
}, {
isFinite: numberIsFinite
});
var es_number_isFinite = {};
var floor$2 = Math.floor; // `Number.isInteger` method implementation
// https://tc39.es/ecma262/#sec-number.isinteger
var isInteger = function isInteger(it) {
return !isObject(it) && isFinite(it) && floor$2(it) === it;
};
// https://tc39.es/ecma262/#sec-number.isinteger
_export({
target: 'Number',
stat: true
}, {
isInteger: isInteger
});
var es_number_isInteger = {};
// https://tc39.es/ecma262/#sec-number.isnan
_export({
target: 'Number',
stat: true
}, {
isNaN: function isNaN(number) {
// eslint-disable-next-line no-self-compare -- NaN check
return number != number;
}
});
var es_number_isNan = {};
var abs = Math.abs; // `Number.isSafeInteger` method
// https://tc39.es/ecma262/#sec-number.issafeinteger
_export({
target: 'Number',
stat: true
}, {
isSafeInteger: function isSafeInteger(number) {
return isInteger(number) && abs(number) <= 0x1FFFFFFFFFFFFF;
}
});
var es_number_isSafeInteger = {};
// https://tc39.es/ecma262/#sec-number.max_safe_integer
_export({
target: 'Number',
stat: true
}, {
MAX_SAFE_INTEGER: 0x1FFFFFFFFFFFFF
});
var es_number_maxSafeInteger = {};
// https://tc39.es/ecma262/#sec-number.min_safe_integer
_export({
target: 'Number',
stat: true
}, {
MIN_SAFE_INTEGER: -0x1FFFFFFFFFFFFF
});
var es_number_minSafeInteger = {};
// https://tc39.es/ecma262/#sec-number.parseFloat
_export({
target: 'Number',
stat: true,
forced: Number.parseFloat != numberParseFloat
}, {
parseFloat: numberParseFloat
});
var es_number_parseFloat = {};
// https://tc39.es/ecma262/#sec-number.parseint
_export({
target: 'Number',
stat: true,
forced: Number.parseInt != numberParseInt
}, {
parseInt: numberParseInt
});
var es_number_parseInt = {};
// https://tc39.es/ecma262/#sec-thisnumbervalue
var thisNumberValue = function thisNumberValue(value) {
if (typeof value != 'number' && classofRaw(value) != 'Number') {
throw TypeError('Incorrect invocation');
}
return +value;
};
'use strict';
var nativeToFixed = 1.0.toFixed;
var floor$3 = Math.floor;
var pow = function pow(x, n, acc) {
return n === 0 ? acc : n % 2 === 1 ? pow(x, n - 1, acc * x) : pow(x * x, n / 2, acc);
};
var log = function log(x) {
var n = 0;
var x2 = x;
while (x2 >= 4096) {
n += 12;
x2 /= 4096;
}
while (x2 >= 2) {
n += 1;
x2 /= 2;
}
return n;
};
var multiply = function multiply(data, n, c) {
var index = -1;
var c2 = c;
while (++index < 6) {
c2 += n * data[index];
data[index] = c2 % 1e7;
c2 = floor$3(c2 / 1e7);
}
};
var divide = function divide(data, n) {
var index = 6;
var c = 0;
while (--index >= 0) {
c += data[index];
data[index] = floor$3(c / n);
c = c % n * 1e7;
}
};
var dataToString = function dataToString(data) {
var index = 6;
var s = '';
while (--index >= 0) {
if (s !== '' || index === 0 || data[index] !== 0) {
var t = String(data[index]);
s = s === '' ? t : s + stringRepeat.call('0', 7 - t.length) + t;
}
}
return s;
};
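// The helpers above treat `data` as a little-endian array of six base-1e7 digits,
// a small fixed-size big integer for the toFixed polyfill below: multiply(data, n, c)
// computes data = data * n + c in place, divide(data, n) floors data / n, and
// dataToString prints the digits most-significant first. E.g. multiply(data, 0, 42)
// followed by multiply(data, 1e7, 0) leaves data representing 420000000, and
// dataToString(data) === '420000000'.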
var FORCED$9 = nativeToFixed && (0.00008.toFixed(3) !== '0.000' || 0.9.toFixed(0) !== '1' || 1.255.toFixed(2) !== '1.25' || 1000000000000000128.0.toFixed(0) !== '1000000000000000128') || !fails(function () {
// V8 ~ Android 4.3-
nativeToFixed.call({});
}); // `Number.prototype.toFixed` method
// https://tc39.es/ecma262/#sec-number.prototype.tofixed
_export({
target: 'Number',
proto: true,
forced: FORCED$9
}, {
toFixed: function toFixed(fractionDigits) {
var number = thisNumberValue(this);
var fractDigits = toInteger(fractionDigits);
var data = [0, 0, 0, 0, 0, 0];
var sign = '';
var result = '0';
var e, z, j, k;
if (fractDigits < 0 || fractDigits > 20) throw RangeError('Incorrect fraction digits'); // eslint-disable-next-line no-self-compare -- NaN check
if (number != number) return 'NaN';
if (number <= -1e21 || number >= 1e21) return String(number);
if (number < 0) {
sign = '-';
number = -number;
}
if (number > 1e-21) {
e = log(number * pow(2, 69, 1)) - 69;
z = e < 0 ? number * pow(2, -e, 1) : number / pow(2, e, 1);
z *= 0x10000000000000;
e = 52 - e;
if (e > 0) {
multiply(data, 0, z);
j = fractDigits;
while (j >= 7) {
multiply(data, 1e7, 0);
j -= 7;
}
multiply(data, pow(10, j, 1), 0);
j = e - 1;
while (j >= 23) {
divide(data, 1 << 23);
j -= 23;
}
divide(data, 1 << j);
multiply(data, 1, 1);
divide(data, 2);
result = dataToString(data);
} else {
multiply(data, 0, z);
multiply(data, 1 << -e, 0);
result = dataToString(data) + stringRepeat.call('0', fractDigits);
}
}
if (fractDigits > 0) {
k = result.length;
result = sign + (k <= fractDigits ? '0.' + stringRepeat.call('0', fractDigits - k) + result : result.slice(0, k - fractDigits) + '.' + result.slice(k - fractDigits));
} else {
result = sign + result;
}
return result;
}
});
var es_number_toFixed = {};
'use strict';
var nativeToPrecision = 1.0.toPrecision;
var FORCED$a = fails(function () {
// IE7-
return nativeToPrecision.call(1, undefined) !== '1';
}) || !fails(function () {
// V8 ~ Android 4.3-
nativeToPrecision.call({});
}); // `Number.prototype.toPrecision` method
// https://tc39.es/ecma262/#sec-number.prototype.toprecision
_export({
target: 'Number',
proto: true,
forced: FORCED$a
}, {
toPrecision: function toPrecision(precision) {
return precision === undefined ? nativeToPrecision.call(thisNumberValue(this)) : nativeToPrecision.call(thisNumberValue(this), precision);
}
});
var es_number_toPrecision = {};
var log$1 = Math.log; // `Math.log1p` method implementation
// https://tc39.es/ecma262/#sec-math.log1p
var mathLog1p = Math.log1p || function log1p(x) {
return (x = +x) > -1e-8 && x < 1e-8 ? x - x * x / 2 : log$1(1 + x);
};
var nativeAcosh = Math.acosh;
var log$2 = Math.log;
var sqrt = Math.sqrt;
var LN2 = Math.LN2;
var FORCED$b = !nativeAcosh // V8 bug: https://code.google.com/p/v8/issues/detail?id=3509
|| Math.floor(nativeAcosh(Number.MAX_VALUE)) != 710 // Tor Browser bug: Math.acosh(Infinity) -> NaN
|| nativeAcosh(Infinity) != Infinity; // `Math.acosh` method
// https://tc39.es/ecma262/#sec-math.acosh
_export({
target: 'Math',
stat: true,
forced: FORCED$b
}, {
acosh: function acosh(x) {
return (x = +x) < 1 ? NaN : x > 94906265.62425156 ? log$2(x) + LN2 : mathLog1p(x - 1 + sqrt(x - 1) * sqrt(x + 1));
}
});
var es_math_acosh = {};
var nativeAsinh = Math.asinh;
var log$3 = Math.log;
var sqrt$1 = Math.sqrt;
function asinh(x) {
return !isFinite(x = +x) || x == 0 ? x : x < 0 ? -asinh(-x) : log$3(x + sqrt$1(x * x + 1));
} // `Math.asinh` method
// https://tc39.es/ecma262/#sec-math.asinh
// Tor Browser bug: Math.asinh(0) -> -0
_export({
target: 'Math',
stat: true,
forced: !(nativeAsinh && 1 / nativeAsinh(0) > 0)
}, {
asinh: asinh
});
var es_math_asinh = {};
var nativeAtanh = Math.atanh;
var log$4 = Math.log; // `Math.atanh` method
// https://tc39.es/ecma262/#sec-math.atanh
// Tor Browser bug: Math.atanh(-0) -> 0
_export({
target: 'Math',
stat: true,
forced: !(nativeAtanh && 1 / nativeAtanh(-0) < 0)
}, {
atanh: function atanh(x) {
return (x = +x) == 0 ? x : log$4((1 + x) / (1 - x)) / 2;
}
});
var es_math_atanh = {};
// `Math.sign` method implementation
// https://tc39.es/ecma262/#sec-math.sign
var mathSign = Math.sign || function sign(x) {
// eslint-disable-next-line no-self-compare -- NaN check
return (x = +x) == 0 || x != x ? x : x < 0 ? -1 : 1;
};
var abs$1 = Math.abs;
var pow$1 = Math.pow; // `Math.cbrt` method
// https://tc39.es/ecma262/#sec-math.cbrt
_export({
target: 'Math',
stat: true
}, {
cbrt: function cbrt(x) {
return mathSign(x = +x) * pow$1(abs$1(x), 1 / 3);
}
});
var es_math_cbrt = {};
var floor$4 = Math.floor;
var log$5 = Math.log;
var LOG2E = Math.LOG2E; // `Math.clz32` method
// https://tc39.es/ecma262/#sec-math.clz32
_export({
target: 'Math',
stat: true
}, {
clz32: function clz32(x) {
return (x >>>= 0) ? 31 - floor$4(log$5(x + 0.5) * LOG2E) : 32;
}
});
var es_math_clz32 = {};
var nativeExpm1 = Math.expm1;
var exp = Math.exp; // `Math.expm1` method implementation
// https://tc39.es/ecma262/#sec-math.expm1
var mathExpm1 = !nativeExpm1 // Old FF bug
|| nativeExpm1(10) > 22025.465794806719 || nativeExpm1(10) < 22025.4657948067165168 // Tor Browser bug
|| nativeExpm1(-2e-17) != -2e-17 ? function expm1(x) {
return (x = +x) == 0 ? x : x > -1e-6 && x < 1e-6 ? x + x * x / 2 : exp(x) - 1;
} : nativeExpm1;
var nativeCosh = Math.cosh;
var abs$2 = Math.abs;
var E = Math.E; // `Math.cosh` method
// https://tc39.es/ecma262/#sec-math.cosh
_export({
target: 'Math',
stat: true,
forced: !nativeCosh || nativeCosh(710) === Infinity
}, {
cosh: function cosh(x) {
var t = mathExpm1(abs$2(x) - 1) + 1;
return (t + 1 / (t * E * E)) * (E / 2);
}
});
var es_math_cosh = {};
// https://tc39.es/ecma262/#sec-math.expm1
_export({
target: 'Math',
stat: true,
forced: mathExpm1 != Math.expm1
}, {
expm1: mathExpm1
});
var es_math_expm1 = {};
var abs$3 = Math.abs;
var pow$2 = Math.pow;
var EPSILON = pow$2(2, -52);
var EPSILON32 = pow$2(2, -23);
var MAX32 = pow$2(2, 127) * (2 - EPSILON32);
var MIN32 = pow$2(2, -126);
var roundTiesToEven = function roundTiesToEven(n) {
return n + 1 / EPSILON - 1 / EPSILON;
}; // `Math.fround` method implementation
// https://tc39.es/ecma262/#sec-math.fround
var mathFround = Math.fround || function fround(x) {
var $abs = abs$3(x);
var $sign = mathSign(x);
var a, result;
if ($abs < MIN32) return $sign * roundTiesToEven($abs / MIN32 / EPSILON32) * MIN32 * EPSILON32;
a = (1 + EPSILON32 / EPSILON) * $abs;
result = a - (a - $abs); // eslint-disable-next-line no-self-compare -- NaN check
if (result > MAX32 || result != result) return $sign * Infinity;
return $sign * result;
};
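// roundTiesToEven(n) relies on double rounding: adding and then subtracting
// 1 / EPSILON (= 2^52) rounds any |n| < 2^52 to the nearest integer with
// ties-to-even, the rounding mode required when narrowing to float32; the
// subnormal branch rescales |x| into subnormal-ulp units so this integer rounding
// lands exactly on the float32 grid.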
// https://tc39.es/ecma262/#sec-math.fround
_export({
target: 'Math',
stat: true
}, {
fround: mathFround
});
var es_math_fround = {};
var $hypot = Math.hypot;
var abs$4 = Math.abs;
var sqrt$2 = Math.sqrt; // Chrome 77 bug
// https://bugs.chromium.org/p/v8/issues/detail?id=9546
var BUGGY = !!$hypot && $hypot(Infinity, NaN) !== Infinity; // `Math.hypot` method
// https://tc39.es/ecma262/#sec-math.hypot
_export({
target: 'Math',
stat: true,
forced: BUGGY
}, {
// eslint-disable-next-line no-unused-vars -- required for `.length`
hypot: function hypot(value1, value2) {
var sum = 0;
var i = 0;
var aLen = arguments.length;
var larg = 0;
var arg, div;
while (i < aLen) {
arg = abs$4(arguments[i++]);
if (larg < arg) {
div = larg / arg;
sum = sum * div * div + 1;
larg = arg;
} else if (arg > 0) {
div = arg / larg;
sum += div * div;
} else sum += arg;
}
return larg === Infinity ? Infinity : larg * sqrt$2(sum);
}
});
var es_math_hypot = {};
var nativeImul = Math.imul;
var FORCED$c = fails(function () {
return nativeImul(0xFFFFFFFF, 5) != -5 || nativeImul.length != 2;
}); // `Math.imul` method
// https://tc39.es/ecma262/#sec-math.imul
// some WebKit versions fail with big numbers, some have the wrong arity
_export({
target: 'Math',
stat: true,
forced: FORCED$c
}, {
imul: function imul(x, y) {
var UINT16 = 0xFFFF;
var xn = +x;
var yn = +y;
var xl = UINT16 & xn;
var yl = UINT16 & yn;
return 0 | xl * yl + ((UINT16 & xn >>> 16) * yl + xl * (UINT16 & yn >>> 16) << 16 >>> 0);
}
});
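// The fallback multiplies the low 16-bit halves exactly and adds the two cross
// products shifted left by 16, truncating to a signed 32-bit result, so e.g.
// Math.imul(0xFFFFFFFF, 5) === -5 and Math.imul(0x7FFFFFFF, 2) === -2.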
var es_math_imul = {};
var log$6 = Math.log;
var LOG10E = Math.LOG10E; // `Math.log10` method
// https://tc39.es/ecma262/#sec-math.log10
_export({
target: 'Math',
stat: true
}, {
log10: function log10(x) {
return log$6(x) * LOG10E;
}
});
var es_math_log10 = {};
// https://tc39.es/ecma262/#sec-math.log1p
_export({
target: 'Math',
stat: true
}, {
log1p: mathLog1p
});
var es_math_log1p = {};
var log$7 = Math.log;
var LN2$1 = Math.LN2; // `Math.log2` method
// https://tc39.es/ecma262/#sec-math.log2
_export({
target: 'Math',
stat: true
}, {
log2: function log2(x) {
return log$7(x) / LN2$1;
}
});
var es_math_log2 = {};
// https://tc39.es/ecma262/#sec-math.sign
_export({
target: 'Math',
stat: true
}, {
sign: mathSign
});
var es_math_sign = {};
var abs$5 = Math.abs;
var exp$1 = Math.exp;
var E$1 = Math.E;
var FORCED$d = fails(function () {
return Math.sinh(-2e-17) != -2e-17;
}); // `Math.sinh` method
// https://tc39.es/ecma262/#sec-math.sinh
// V8 near Chromium 38 has a problem with very small numbers
_export({
target: 'Math',
stat: true,
forced: FORCED$d
}, {
sinh: function sinh(x) {
return abs$5(x = +x) < 1 ? (mathExpm1(x) - mathExpm1(-x)) / 2 : (exp$1(x - 1) - exp$1(-x - 1)) * (E$1 / 2);
}
});
var es_math_sinh = {};
var exp$2 = Math.exp; // `Math.tanh` method
// https://tc39.es/ecma262/#sec-math.tanh
_export({
target: 'Math',
stat: true
}, {
tanh: function tanh(x) {
var a = mathExpm1(x = +x);
var b = mathExpm1(-x);
return a == Infinity ? 1 : b == Infinity ? -1 : (a - b) / (exp$2(x) + exp$2(-x));
}
});
var es_math_tanh = {};
// https://tc39.es/ecma262/#sec-math-@@tostringtag
setToStringTag(Math, 'Math', true);
var es_math_toStringTag = {};
var ceil$2 = Math.ceil;
var floor$5 = Math.floor; // `Math.trunc` method
// https://tc39.es/ecma262/#sec-math.trunc
_export({
target: 'Math',
stat: true
}, {
trunc: function trunc(it) {
return (it > 0 ? floor$5 : ceil$2)(it);
}
});
var es_math_trunc = {};
// https://tc39.es/ecma262/#sec-date.now
_export({
target: 'Date',
stat: true
}, {
now: function now() {
return new Date().getTime();
}
});
var es_date_now = {};
'use strict';
var FORCED$e = fails(function () {
return new Date(NaN).toJSON() !== null || Date.prototype.toJSON.call({
toISOString: function toISOString() {
return 1;
}
}) !== 1;
}); // `Date.prototype.toJSON` method
// https://tc39.es/ecma262/#sec-date.prototype.tojson
_export({
target: 'Date',
proto: true,
forced: FORCED$e
}, {
// eslint-disable-next-line no-unused-vars -- required for `.length`
toJSON: function toJSON(key) {
var O = toObject(this);
var pv = toPrimitive(O);
return typeof pv == 'number' && !isFinite(pv) ? null : O.toISOString();
}
});
var es_date_toJson = {};
'use strict';
var padStart = stringPad.start;
var abs$6 = Math.abs;
var DatePrototype = Date.prototype;
var getTime = DatePrototype.getTime;
var nativeDateToISOString = DatePrototype.toISOString; // `Date.prototype.toISOString` method implementation
// https://tc39.es/ecma262/#sec-date.prototype.toisostring
// PhantomJS / old WebKit fails here:
var dateToIsoString = fails(function () {
return nativeDateToISOString.call(new Date(-5e13 - 1)) != '0385-07-25T07:06:39.999Z';
}) || !fails(function () {
nativeDateToISOString.call(new Date(NaN));
}) ? function toISOString() {
if (!isFinite(getTime.call(this))) throw RangeError('Invalid time value');
var date = this;
var year = date.getUTCFullYear();
var milliseconds = date.getUTCMilliseconds();
var sign = year < 0 ? '-' : year > 9999 ? '+' : '';
return sign + padStart(abs$6(year), sign ? 6 : 4, 0) + '-' + padStart(date.getUTCMonth() + 1, 2, 0) + '-' + padStart(date.getUTCDate(), 2, 0) + 'T' + padStart(date.getUTCHours(), 2, 0) + ':' + padStart(date.getUTCMinutes(), 2, 0) + ':' + padStart(date.getUTCSeconds(), 2, 0) + '.' + padStart(milliseconds, 3, 0) + 'Z';
} : nativeDateToISOString;
// https://tc39.es/ecma262/#sec-date.prototype.toisostring
// PhantomJS / old WebKit has a broken implementation
_export({
target: 'Date',
proto: true,
forced: Date.prototype.toISOString !== dateToIsoString
}, {
toISOString: dateToIsoString
});
var es_date_toIsoString = {};
var DatePrototype$1 = Date.prototype;
var INVALID_DATE = 'Invalid Date';
var TO_STRING$1 = 'toString';
var nativeDateToString = DatePrototype$1[TO_STRING$1];
var getTime$1 = DatePrototype$1.getTime; // `Date.prototype.toString` method
// https://tc39.es/ecma262/#sec-date.prototype.tostring
if (new Date(NaN) + '' != INVALID_DATE) {
redefine(DatePrototype$1, TO_STRING$1, function toString() {
var value = getTime$1.call(this); // eslint-disable-next-line no-self-compare -- NaN check
return value === value ? nativeDateToString.call(this) : INVALID_DATE;
});
}
var es_date_toString = {};
'use strict';
var dateToPrimitive = function dateToPrimitive(hint) {
if (hint !== 'string' && hint !== 'number' && hint !== 'default') {
throw TypeError('Incorrect hint');
}
return toPrimitive(anObject(this), hint !== 'number');
};
var TO_PRIMITIVE$1 = wellKnownSymbol('toPrimitive');
var DatePrototype$2 = Date.prototype; // `Date.prototype[@@toPrimitive]` method
// https://tc39.es/ecma262/#sec-date.prototype-@@toprimitive
if (!(TO_PRIMITIVE$1 in DatePrototype$2)) {
createNonEnumerableProperty(DatePrototype$2, TO_PRIMITIVE$1, dateToPrimitive);
}
var es_date_toPrimitive = {};
var $stringify$1 = getBuiltIn('JSON', 'stringify');
var re = /[\uD800-\uDFFF]/g;
var low = /^[\uD800-\uDBFF]$/;
var hi = /^[\uDC00-\uDFFF]$/;
var fix = function fix(match, offset, string) {
var prev = string.charAt(offset - 1);
var next = string.charAt(offset + 1);
if (low.test(match) && !hi.test(next) || hi.test(match) && !low.test(prev)) {
return "\\u" + match.charCodeAt(0).toString(16);
}
return match;
};
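// `fix` keeps JSON.stringify output well-formed: a lead surrogate (U+D800-U+DBFF)
// with no trail surrogate after it, or a trail surrogate (U+DC00-U+DFFF) with no
// lead before it, is rewritten as a \uXXXX escape, e.g.
// JSON.stringify('\uDEAD') === '"\\udead"'.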
var FORCED$f = fails(function () {
return $stringify$1("\uDF06\uD834") !== "\"\\udf06\\ud834\"" || $stringify$1("\uDEAD") !== "\"\\udead\"";
});
if ($stringify$1) {
// `JSON.stringify` method
// https://tc39.es/ecma262/#sec-json.stringify
// https://github.com/tc39/proposal-well-formed-stringify
_export({
target: 'JSON',
stat: true,
forced: FORCED$f
}, {
// eslint-disable-next-line no-unused-vars -- required for `.length`
stringify: function stringify(it, replacer, space) {
var result = $stringify$1.apply(null, arguments);
return typeof result == 'string' ? result.replace(re, fix) : result;
}
});
}
var es_json_stringify = {};
// https://tc39.es/ecma262/#sec-json-@@tostringtag
setToStringTag(global_1.JSON, 'JSON', true);
var es_json_toStringTag = {};
var nativePromiseConstructor = global_1.Promise;
var redefineAll = function redefineAll(target, src, options) {
for (var key in src) {
redefine(target, key, src[key], options);
}
return target;
};
var anInstance = function anInstance(it, Constructor, name) {
if (!(it instanceof Constructor)) {
throw TypeError('Incorrect ' + (name ? name + ' ' : '') + 'invocation');
}
return it;
};
var engineIsIos = /(iphone|ipod|ipad).*applewebkit/i.test(engineUserAgent);
var location = global_1.location;
var set$1 = global_1.setImmediate;
var clear = global_1.clearImmediate;
var process$2 = global_1.process;
var MessageChannel = global_1.MessageChannel;
var Dispatch = global_1.Dispatch;
var counter = 0;
var queue = {};
var ONREADYSTATECHANGE = 'onreadystatechange';
var defer, channel, port;
var run = function run(id) {
// eslint-disable-next-line no-prototype-builtins -- safe
if (queue.hasOwnProperty(id)) {
var fn = queue[id];
delete queue[id];
fn();
}
};
var runner = function runner(id) {
return function () {
run(id);
};
};
var listener = function listener(event) {
run(event.data);
};
var post = function post(id) {
  // old engines do not have location.origin
global_1.postMessage(id + '', location.protocol + '//' + location.host);
}; // Node.js 0.9+ & IE10+ have setImmediate, otherwise:
if (!set$1 || !clear) {
set$1 = function setImmediate(fn) {
var args = [];
var i = 1;
while (arguments.length > i) {
args.push(arguments[i++]);
}
queue[++counter] = function () {
// eslint-disable-next-line no-new-func -- spec requirement
(typeof fn == 'function' ? fn : Function(fn)).apply(undefined, args);
};
defer(counter);
return counter;
};
clear = function clearImmediate(id) {
delete queue[id];
}; // Node.js 0.8-
if (engineIsNode) {
defer = function defer(id) {
process$2.nextTick(runner(id));
}; // Sphere (JS game engine) Dispatch API
} else if (Dispatch && Dispatch.now) {
defer = function defer(id) {
Dispatch.now(runner(id));
}; // Browsers with MessageChannel, includes WebWorkers
// except iOS - https://github.com/zloirock/core-js/issues/624
} else if (MessageChannel && !engineIsIos) {
channel = new MessageChannel();
port = channel.port2;
channel.port1.onmessage = listener;
defer = functionBindContext(port.postMessage, port, 1); // Browsers with postMessage, skip WebWorkers
// IE8 has postMessage, but it's sync & typeof its postMessage is 'object'
} else if (global_1.addEventListener && typeof postMessage == 'function' && !global_1.importScripts && location && location.protocol !== 'file:' && !fails(post)) {
defer = post;
global_1.addEventListener('message', listener, false); // IE8-
} else if (ONREADYSTATECHANGE in documentCreateElement('script')) {
defer = function defer(id) {
html.appendChild(documentCreateElement('script'))[ONREADYSTATECHANGE] = function () {
html.removeChild(this);
run(id);
};
    }; // Remaining old browsers
} else {
defer = function defer(id) {
setTimeout(runner(id), 0);
};
}
}
var task = {
set: set$1,
clear: clear
};
var task_1 = task.set;
var task_2 = task.clear;
var engineIsWebosWebkit = /web0s(?!.*chrome)/i.test(engineUserAgent);
var getOwnPropertyDescriptor$7 = objectGetOwnPropertyDescriptor.f;
var macrotask = task.set;
var MutationObserver = global_1.MutationObserver || global_1.WebKitMutationObserver;
var document$2 = global_1.document;
var process$3 = global_1.process;
var Promise$1 = global_1.Promise; // Node.js 11 shows ExperimentalWarning on getting `queueMicrotask`
var queueMicrotaskDescriptor = getOwnPropertyDescriptor$7(global_1, 'queueMicrotask');
var queueMicrotask = queueMicrotaskDescriptor && queueMicrotaskDescriptor.value;
var flush, head, last, notify, toggle, node, promise, then; // modern engines have queueMicrotask method
if (!queueMicrotask) {
flush = function flush() {
var parent, fn;
if (engineIsNode && (parent = process$3.domain)) parent.exit();
while (head) {
fn = head.fn;
head = head.next;
try {
fn();
} catch (error) {
if (head) notify();else last = undefined;
throw error;
}
}
last = undefined;
if (parent) parent.enter();
}; // browsers with MutationObserver, except iOS - https://github.com/zloirock/core-js/issues/339
// also except WebOS Webkit https://github.com/zloirock/core-js/issues/898
if (!engineIsIos && !engineIsNode && !engineIsWebosWebkit && MutationObserver && document$2) {
toggle = true;
node = document$2.createTextNode('');
new MutationObserver(flush).observe(node, {
characterData: true
});
notify = function notify() {
node.data = toggle = !toggle;
    }; // environments with a possibly non-conformant, but existing Promise
} else if (Promise$1 && Promise$1.resolve) {
// Promise.resolve without an argument throws an error in LG WebOS 2
promise = Promise$1.resolve(undefined);
then = promise.then;
notify = function notify() {
then.call(promise, flush);
}; // Node.js without promises
} else if (engineIsNode) {
notify = function notify() {
process$3.nextTick(flush);
}; // for other environments - macrotask based on:
// - setImmediate
// - MessageChannel
    // - window.postMessage
// - onreadystatechange
// - setTimeout
} else {
notify = function notify() {
// strange IE + webpack dev server bug - use .call(global)
macrotask.call(global_1, flush);
};
}
}
var microtask = queueMicrotask || function (fn) {
var task = {
fn: fn,
next: undefined
};
if (last) last.next = task;
if (!head) {
head = task;
notify();
}
last = task;
};
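// Fallback microtask queue: each call appends a { fn, next } node to a singly
// linked list, `notify()` is scheduled only when the list was empty, and `flush`
// runs the nodes head-to-tail, re-arming `notify` if a callback throws so the
// remaining tasks still run on a later tick.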
'use strict';
var PromiseCapability = function PromiseCapability(C) {
var resolve, reject;
this.promise = new C(function ($$resolve, $$reject) {
if (resolve !== undefined || reject !== undefined) throw TypeError('Bad Promise constructor');
resolve = $$resolve;
reject = $$reject;
});
this.resolve = aFunction$1(resolve);
this.reject = aFunction$1(reject);
}; // 25.4.1.5 NewPromiseCapability(C)
var f$7 = function f(C) {
return new PromiseCapability(C);
};
var newPromiseCapability = {
f: f$7
};
var promiseResolve = function promiseResolve(C, x) {
anObject(C);
if (isObject(x) && x.constructor === C) return x;
var promiseCapability = newPromiseCapability.f(C);
var resolve = promiseCapability.resolve;
resolve(x);
return promiseCapability.promise;
};
var hostReportErrors = function hostReportErrors(a, b) {
var console = global_1.console;
if (console && console.error) {
arguments.length === 1 ? console.error(a) : console.error(a, b);
}
};
var perform = function perform(exec) {
try {
return {
error: false,
value: exec()
};
} catch (error) {
return {
error: true,
value: error
};
}
};
'use strict';
var task$1 = task.set;
var SPECIES$6 = wellKnownSymbol('species');
var PROMISE = 'Promise';
var getInternalState$5 = internalState.get;
var setInternalState$5 = internalState.set;
var getInternalPromiseState = internalState.getterFor(PROMISE);
var PromiseConstructor = nativePromiseConstructor;
var TypeError$1 = global_1.TypeError;
var document$3 = global_1.document;
var process$4 = global_1.process;
var $fetch = getBuiltIn('fetch');
var newPromiseCapability$1 = newPromiseCapability.f;
var newGenericPromiseCapability = newPromiseCapability$1;
var DISPATCH_EVENT = !!(document$3 && document$3.createEvent && global_1.dispatchEvent);
var NATIVE_REJECTION_EVENT = typeof PromiseRejectionEvent == 'function';
var UNHANDLED_REJECTION = 'unhandledrejection';
var REJECTION_HANDLED = 'rejectionhandled';
var PENDING = 0;
var FULFILLED = 1;
var REJECTED = 2;
var HANDLED = 1;
var UNHANDLED = 2;
var Internal, OwnPromiseCapability, PromiseWrapper, nativeThen;
var FORCED$g = isForced_1(PROMISE, function () {
var GLOBAL_CORE_JS_PROMISE = inspectSource(PromiseConstructor) !== String(PromiseConstructor);
if (!GLOBAL_CORE_JS_PROMISE) {
// V8 6.6 (Node 10 and Chrome 66) have a bug with resolving custom thenables
// https://bugs.chromium.org/p/chromium/issues/detail?id=830565
// We can't detect it synchronously, so just check versions
if (engineV8Version === 66) return true; // Unhandled-rejection tracking support: a Node.js Promise without it fails the @@species test
if (!engineIsNode && !NATIVE_REJECTION_EVENT) return true;
} // We need Promise#finally in the pure version for preventing prototype pollution
if (isPure && !PromiseConstructor.prototype['finally']) return true; // We can't use @@species feature detection in V8 since it causes
// deoptimization and performance degradation
// https://github.com/zloirock/core-js/issues/679
if (engineV8Version >= 51 && /native code/.test(PromiseConstructor)) return false; // Detect correctness of subclassing with @@species support
var promise = PromiseConstructor.resolve(1);
var FakePromise = function FakePromise(exec) {
exec(function () {
/* empty */
}, function () {
/* empty */
});
};
var constructor = promise.constructor = {};
constructor[SPECIES$6] = FakePromise;
return !(promise.then(function () {
/* empty */
}) instanceof FakePromise);
});
var INCORRECT_ITERATION$1 = FORCED$g || !checkCorrectnessOfIteration(function (iterable) {
PromiseConstructor.all(iterable)['catch'](function () {
/* empty */
});
}); // helpers
var isThenable = function isThenable(it) {
var then;
return isObject(it) && typeof (then = it.then) == 'function' ? then : false;
};
var notify$1 = function notify(state, isReject) {
if (state.notified) return;
state.notified = true;
var chain = state.reactions;
microtask(function () {
var value = state.value;
var ok = state.state == FULFILLED;
var index = 0; // variable length - can't use forEach
while (chain.length > index) {
var reaction = chain[index++];
var handler = ok ? reaction.ok : reaction.fail;
var resolve = reaction.resolve;
var reject = reaction.reject;
var domain = reaction.domain;
var result, then, exited;
try {
if (handler) {
if (!ok) {
if (state.rejection === UNHANDLED) onHandleUnhandled(state);
state.rejection = HANDLED;
}
if (handler === true) result = value;else {
if (domain) domain.enter();
result = handler(value); // can throw
if (domain) {
domain.exit();
exited = true;
}
}
if (result === reaction.promise) {
reject(TypeError$1('Promise-chain cycle'));
} else if (then = isThenable(result)) {
then.call(result, resolve, reject);
} else resolve(result);
} else reject(value);
} catch (error) {
if (domain && !exited) domain.exit();
reject(error);
}
}
state.reactions = [];
state.notified = false;
if (isReject && !state.rejection) onUnhandled(state);
});
};
var dispatchEvent = function dispatchEvent(name, promise, reason) {
var event, handler;
if (DISPATCH_EVENT) {
event = document$3.createEvent('Event');
event.promise = promise;
event.reason = reason;
event.initEvent(name, false, true);
global_1.dispatchEvent(event);
} else event = {
promise: promise,
reason: reason
};
if (!NATIVE_REJECTION_EVENT && (handler = global_1['on' + name])) handler(event);else if (name === UNHANDLED_REJECTION) hostReportErrors('Unhandled promise rejection', reason);
};
var onUnhandled = function onUnhandled(state) {
task$1.call(global_1, function () {
var promise = state.facade;
var value = state.value;
var IS_UNHANDLED = isUnhandled(state);
var result;
if (IS_UNHANDLED) {
result = perform(function () {
if (engineIsNode) {
process$4.emit('unhandledRejection', value, promise);
} else dispatchEvent(UNHANDLED_REJECTION, promise, value);
}); // Browsers should not trigger the `rejectionHandled` event if the rejection was handled here; Node.js should
state.rejection = engineIsNode || isUnhandled(state) ? UNHANDLED : HANDLED;
if (result.error) throw result.value;
}
});
};
var isUnhandled = function isUnhandled(state) {
return state.rejection !== HANDLED && !state.parent;
};
var onHandleUnhandled = function onHandleUnhandled(state) {
task$1.call(global_1, function () {
var promise = state.facade;
if (engineIsNode) {
process$4.emit('rejectionHandled', promise);
} else dispatchEvent(REJECTION_HANDLED, promise, state.value);
});
};
var bind = function bind(fn, state, unwrap) {
return function (value) {
fn(state, value, unwrap);
};
};
var internalReject = function internalReject(state, value, unwrap) {
if (state.done) return;
state.done = true;
if (unwrap) state = unwrap;
state.value = value;
state.state = REJECTED;
notify$1(state, true);
};
var internalResolve = function internalResolve(state, value, unwrap) {
if (state.done) return;
state.done = true;
if (unwrap) state = unwrap;
try {
if (state.facade === value) throw TypeError$1("Promise can't be resolved itself");
var then = isThenable(value);
if (then) {
microtask(function () {
var wrapper = {
done: false
};
try {
then.call(value, bind(internalResolve, wrapper, state), bind(internalReject, wrapper, state));
} catch (error) {
internalReject(wrapper, error, state);
}
});
} else {
state.value = value;
state.state = FULFILLED;
notify$1(state, false);
}
} catch (error) {
internalReject({
done: false
}, error, state);
}
}; // constructor polyfill
if (FORCED$g) {
// 25.4.3.1 Promise(executor)
PromiseConstructor = function Promise(executor) {
anInstance(this, PromiseConstructor, PROMISE);
aFunction$1(executor);
Internal.call(this);
var state = getInternalState$5(this);
try {
executor(bind(internalResolve, state), bind(internalReject, state));
} catch (error) {
internalReject(state, error);
}
}; // eslint-disable-next-line no-unused-vars -- required for `.length`
Internal = function Promise(executor) {
setInternalState$5(this, {
type: PROMISE,
done: false,
notified: false,
parent: false,
reactions: [],
rejection: false,
state: PENDING,
value: undefined
});
};
Internal.prototype = redefineAll(PromiseConstructor.prototype, {
// `Promise.prototype.then` method
// https://tc39.es/ecma262/#sec-promise.prototype.then
then: function then(onFulfilled, onRejected) {
var state = getInternalPromiseState(this);
var reaction = newPromiseCapability$1(speciesConstructor(this, PromiseConstructor));
reaction.ok = typeof onFulfilled == 'function' ? onFulfilled : true;
reaction.fail = typeof onRejected == 'function' && onRejected;
reaction.domain = engineIsNode ? process$4.domain : undefined;
state.parent = true;
state.reactions.push(reaction);
if (state.state != PENDING) notify$1(state, false);
return reaction.promise;
},
// `Promise.prototype.catch` method
// https://tc39.es/ecma262/#sec-promise.prototype.catch
'catch': function _catch(onRejected) {
return this.then(undefined, onRejected);
}
});
OwnPromiseCapability = function OwnPromiseCapability() {
var promise = new Internal();
var state = getInternalState$5(promise);
this.promise = promise;
this.resolve = bind(internalResolve, state);
this.reject = bind(internalReject, state);
};
newPromiseCapability.f = newPromiseCapability$1 = function newPromiseCapability(C) {
return C === PromiseConstructor || C === PromiseWrapper ? new OwnPromiseCapability(C) : newGenericPromiseCapability(C);
};
if (!isPure && typeof nativePromiseConstructor == 'function') {
nativeThen = nativePromiseConstructor.prototype.then; // wrap native Promise#then for native async functions
redefine(nativePromiseConstructor.prototype, 'then', function then(onFulfilled, onRejected) {
var that = this;
return new PromiseConstructor(function (resolve, reject) {
nativeThen.call(that, resolve, reject);
}).then(onFulfilled, onRejected); // https://github.com/zloirock/core-js/issues/640
}, {
unsafe: true
}); // wrap fetch result
if (typeof $fetch == 'function') _export({
global: true,
enumerable: true,
forced: true
}, {
// eslint-disable-next-line no-unused-vars -- required for `.length`
fetch: function fetch(input
/* , init */
) {
return promiseResolve(PromiseConstructor, $fetch.apply(global_1, arguments));
}
});
}
}
_export({
global: true,
wrap: true,
forced: FORCED$g
}, {
Promise: PromiseConstructor
});
setToStringTag(PromiseConstructor, PROMISE, false, true);
setSpecies(PROMISE);
PromiseWrapper = getBuiltIn(PROMISE); // statics
_export({
target: PROMISE,
stat: true,
forced: FORCED$g
}, {
// `Promise.reject` method
// https://tc39.es/ecma262/#sec-promise.reject
reject: function reject(r) {
var capability = newPromiseCapability$1(this);
capability.reject.call(undefined, r);
return capability.promise;
}
});
_export({
target: PROMISE,
stat: true,
forced: isPure || FORCED$g
}, {
// `Promise.resolve` method
// https://tc39.es/ecma262/#sec-promise.resolve
resolve: function resolve(x) {
return promiseResolve(isPure && this === PromiseWrapper ? PromiseConstructor : this, x);
}
});
_export({
target: PROMISE,
stat: true,
forced: INCORRECT_ITERATION$1
}, {
// `Promise.all` method
// https://tc39.es/ecma262/#sec-promise.all
all: function all(iterable) {
var C = this;
var capability = newPromiseCapability$1(C);
var resolve = capability.resolve;
var reject = capability.reject;
var result = perform(function () {
var $promiseResolve = aFunction$1(C.resolve);
var values = [];
var counter = 0;
var remaining = 1;
iterate(iterable, function (promise) {
var index = counter++;
var alreadyCalled = false;
values.push(undefined);
remaining++;
$promiseResolve.call(C, promise).then(function (value) {
if (alreadyCalled) return;
alreadyCalled = true;
values[index] = value;
--remaining || resolve(values);
}, reject);
});
--remaining || resolve(values);
});
if (result.error) reject(result.value);
return capability.promise;
},
// `Promise.race` method
// https://tc39.es/ecma262/#sec-promise.race
race: function race(iterable) {
var C = this;
var capability = newPromiseCapability$1(C);
var reject = capability.reject;
var result = perform(function () {
var $promiseResolve = aFunction$1(C.resolve);
iterate(iterable, function (promise) {
$promiseResolve.call(C, promise).then(capability.resolve, reject);
});
});
if (result.error) reject(result.value);
return capability.promise;
}
});
var es_promise = {};
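// Illustrative sketch (comment only): the behaviour provided by the constructor
// and statics exported above, whether backed by the native Promise or the
// polyfill.
//
//   var P = Promise; // possibly the wrapped constructor exported above
//   P.all([P.resolve(1), 2]).then(function (values) {
//     // values is [1, 2] - non-promise inputs are resolved via P.resolve
//   });
//   P.race([P.resolve('fast'), new P(function () { /* never settles */ })])
//     .then(function (winner) {
//       // winner === 'fast'
//     });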
'use strict'; // `Promise.allSettled` method
// https://tc39.es/ecma262/#sec-promise.allsettled
_export({
target: 'Promise',
stat: true
}, {
allSettled: function allSettled(iterable) {
var C = this;
var capability = newPromiseCapability.f(C);
var resolve = capability.resolve;
var reject = capability.reject;
var result = perform(function () {
var promiseResolve = aFunction$1(C.resolve);
var values = [];
var counter = 0;
var remaining = 1;
iterate(iterable, function (promise) {
var index = counter++;
var alreadyCalled = false;
values.push(undefined);
remaining++;
promiseResolve.call(C, promise).then(function (value) {
if (alreadyCalled) return;
alreadyCalled = true;
values[index] = {
status: 'fulfilled',
value: value
};
--remaining || resolve(values);
}, function (error) {
if (alreadyCalled) return;
alreadyCalled = true;
values[index] = {
status: 'rejected',
reason: error
};
--remaining || resolve(values);
});
});
--remaining || resolve(values);
});
if (result.error) reject(result.value);
return capability.promise;
}
});
var es_promise_allSettled = {};
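// Illustrative sketch (comment only): `Promise.allSettled` as installed above
// never rejects; it reports each input's outcome as a status record.
//
//   Promise.allSettled([Promise.resolve(1), Promise.reject(new Error('no'))])
//     .then(function (results) {
//       // results[0] -> { status: 'fulfilled', value: 1 }
//       // results[1] -> { status: 'rejected', reason: Error('no') }
//     });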
'use strict';
var PROMISE_ANY_ERROR = 'No one promise resolved'; // `Promise.any` method
// https://tc39.es/ecma262/#sec-promise.any
_export({
target: 'Promise',
stat: true
}, {
any: function any(iterable) {
var C = this;
var capability = newPromiseCapability.f(C);
var resolve = capability.resolve;
var reject = capability.reject;
var result = perform(function () {
var promiseResolve = aFunction$1(C.resolve);
var errors = [];
var counter = 0;
var remaining = 1;
var alreadyResolved = false;
iterate(iterable, function (promise) {
var index = counter++;
var alreadyRejected = false;
errors.push(undefined);
remaining++;
promiseResolve.call(C, promise).then(function (value) {
if (alreadyRejected || alreadyResolved) return;
alreadyResolved = true;
resolve(value);
}, function (error) {
if (alreadyRejected || alreadyResolved) return;
alreadyRejected = true;
errors[index] = error;
--remaining || reject(new (getBuiltIn('AggregateError'))(errors, PROMISE_ANY_ERROR));
});
});
--remaining || reject(new (getBuiltIn('AggregateError'))(errors, PROMISE_ANY_ERROR));
});
if (result.error) reject(result.value);
return capability.promise;
}
});
var es_promise_any = {};
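// Illustrative sketch (comment only): `Promise.any` as installed above resolves
// with the first fulfilled input and only rejects (with an AggregateError) when
// every input rejects.
//
//   Promise.any([Promise.reject(new Error('a')), Promise.resolve('b')])
//     .then(function (value) {
//       // value === 'b'
//     });
//   Promise.any([Promise.reject(1), Promise.reject(2)])
//     ['catch'](function (error) {
//       // error instanceof AggregateError; error.errors is [1, 2]
//     });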
'use strict'; // Safari bug https://bugs.webkit.org/show_bug.cgi?id=200829
var NON_GENERIC = !!nativePromiseConstructor && fails(function () {
nativePromiseConstructor.prototype['finally'].call({
then: function then() {
/* empty */
}
}, function () {
/* empty */
});
}); // `Promise.prototype.finally` method
// https://tc39.es/ecma262/#sec-promise.prototype.finally
_export({
target: 'Promise',
proto: true,
real: true,
forced: NON_GENERIC
}, {
'finally': function _finally(onFinally) {
var C = speciesConstructor(this, getBuiltIn('Promise'));
var isFunction = typeof onFinally == 'function';
return this.then(isFunction ? function (x) {
return promiseResolve(C, onFinally()).then(function () {
return x;
});
} : onFinally, isFunction ? function (e) {
return promiseResolve(C, onFinally()).then(function () {
throw e;
});
} : onFinally);
}
}); // patch native Promise.prototype for native async functions
if (!isPure && typeof nativePromiseConstructor == 'function' && !nativePromiseConstructor.prototype['finally']) {
redefine(nativePromiseConstructor.prototype, 'finally', getBuiltIn('Promise').prototype['finally']);
}
var es_promise_finally = {};
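// Illustrative sketch (comment only): `Promise.prototype.finally` as patched
// above runs its callback regardless of the outcome and passes the original
// settlement through.
//
//   Promise.resolve(7)
//     ['finally'](function () { /* cleanup; the return value is ignored */ })
//     .then(function (value) {
//       // value === 7 - the fulfillment value is preserved
//     });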
'use strict';
var collection = function collection(CONSTRUCTOR_NAME, wrapper, common) {
var IS_MAP = CONSTRUCTOR_NAME.indexOf('Map') !== -1;
var IS_WEAK = CONSTRUCTOR_NAME.indexOf('Weak') !== -1;
var ADDER = IS_MAP ? 'set' : 'add';
var NativeConstructor = global_1[CONSTRUCTOR_NAME];
var NativePrototype = NativeConstructor && NativeConstructor.prototype;
var Constructor = NativeConstructor;
var exported = {};
var fixMethod = function fixMethod(KEY) {
var nativeMethod = NativePrototype[KEY];
redefine(NativePrototype, KEY, KEY == 'add' ? function add(value) {
nativeMethod.call(this, value === 0 ? 0 : value);
return this;
} : KEY == 'delete' ? function (key) {
return IS_WEAK && !isObject(key) ? false : nativeMethod.call(this, key === 0 ? 0 : key);
} : KEY == 'get' ? function get(key) {
return IS_WEAK && !isObject(key) ? undefined : nativeMethod.call(this, key === 0 ? 0 : key);
} : KEY == 'has' ? function has(key) {
return IS_WEAK && !isObject(key) ? false : nativeMethod.call(this, key === 0 ? 0 : key);
} : function set(key, value) {
nativeMethod.call(this, key === 0 ? 0 : key, value);
return this;
});
};
var REPLACE = isForced_1(CONSTRUCTOR_NAME, typeof NativeConstructor != 'function' || !(IS_WEAK || NativePrototype.forEach && !fails(function () {
new NativeConstructor().entries().next();
})));
if (REPLACE) {
// create collection constructor
Constructor = common.getConstructor(wrapper, CONSTRUCTOR_NAME, IS_MAP, ADDER);
internalMetadata.REQUIRED = true;
} else if (isForced_1(CONSTRUCTOR_NAME, true)) {
var instance = new Constructor(); // early implementations do not support chaining
var HASNT_CHAINING = instance[ADDER](IS_WEAK ? {} : -0, 1) != instance; // V8 ~ Chromium 40- weak collections throw on primitives, but should return false
var THROWS_ON_PRIMITIVES = fails(function () {
instance.has(1);
}); // most early implementations don't support iterables; most modern ones don't close the iterator correctly
// eslint-disable-next-line no-new -- required for testing
var ACCEPT_ITERABLES = checkCorrectnessOfIteration(function (iterable) {
new NativeConstructor(iterable);
}); // in early implementations -0 and +0 are not the same
var BUGGY_ZERO = !IS_WEAK && fails(function () {
// V8 ~ Chromium 42- fails only with 5+ elements
var $instance = new NativeConstructor();
var index = 5;
while (index--) {
$instance[ADDER](index, index);
}
return !$instance.has(-0);
});
if (!ACCEPT_ITERABLES) {
Constructor = wrapper(function (dummy, iterable) {
anInstance(dummy, Constructor, CONSTRUCTOR_NAME);
var that = inheritIfRequired(new NativeConstructor(), dummy, Constructor);
if (iterable != undefined) iterate(iterable, that[ADDER], {
that: that,
AS_ENTRIES: IS_MAP
});
return that;
});
Constructor.prototype = NativePrototype;
NativePrototype.constructor = Constructor;
}
if (THROWS_ON_PRIMITIVES || BUGGY_ZERO) {
fixMethod('delete');
fixMethod('has');
IS_MAP && fixMethod('get');
}
if (BUGGY_ZERO || HASNT_CHAINING) fixMethod(ADDER); // weak collections should not contain a .clear method
if (IS_WEAK && NativePrototype.clear) delete NativePrototype.clear;
}
exported[CONSTRUCTOR_NAME] = Constructor;
_export({
global: true,
forced: Constructor != NativeConstructor
}, exported);
setToStringTag(Constructor, CONSTRUCTOR_NAME);
if (!IS_WEAK) common.setStrong(Constructor, CONSTRUCTOR_NAME, IS_MAP);
return Constructor;
};
'use strict';
var defineProperty$8 = objectDefineProperty.f;
var fastKey = internalMetadata.fastKey;
var setInternalState$6 = internalState.set;
var internalStateGetterFor = internalState.getterFor;
var collectionStrong = {
getConstructor: function getConstructor(wrapper, CONSTRUCTOR_NAME, IS_MAP, ADDER) {
var C = wrapper(function (that, iterable) {
anInstance(that, C, CONSTRUCTOR_NAME);
setInternalState$6(that, {
type: CONSTRUCTOR_NAME,
index: objectCreate(null),
first: undefined,
last: undefined,
size: 0
});
if (!descriptors) that.size = 0;
if (iterable != undefined) iterate(iterable, that[ADDER], {
that: that,
AS_ENTRIES: IS_MAP
});
});
var getInternalState = internalStateGetterFor(CONSTRUCTOR_NAME);
var define = function define(that, key, value) {
var state = getInternalState(that);
var entry = getEntry(that, key);
var previous, index; // change existing entry
if (entry) {
entry.value = value; // create new entry
} else {
state.last = entry = {
index: index = fastKey(key, true),
key: key,
value: value,
previous: previous = state.last,
next: undefined,
removed: false
};
if (!state.first) state.first = entry;
if (previous) previous.next = entry;
if (descriptors) state.size++;else that.size++; // add to index
if (index !== 'F') state.index[index] = entry;
}
return that;
};
var getEntry = function getEntry(that, key) {
var state = getInternalState(that); // fast case
var index = fastKey(key);
var entry;
if (index !== 'F') return state.index[index]; // frozen object case
for (entry = state.first; entry; entry = entry.next) {
if (entry.key == key) return entry;
}
};
redefineAll(C.prototype, {
// 23.1.3.1 Map.prototype.clear()
// 23.2.3.2 Set.prototype.clear()
clear: function clear() {
var that = this;
var state = getInternalState(that);
var data = state.index;
var entry = state.first;
while (entry) {
entry.removed = true;
if (entry.previous) entry.previous = entry.previous.next = undefined;
delete data[entry.index];
entry = entry.next;
}
state.first = state.last = undefined;
if (descriptors) state.size = 0;else that.size = 0;
},
// 23.1.3.3 Map.prototype.delete(key)
// 23.2.3.4 Set.prototype.delete(value)
'delete': function _delete(key) {
var that = this;
var state = getInternalState(that);
var entry = getEntry(that, key);
if (entry) {
var next = entry.next;
var prev = entry.previous;
delete state.index[entry.index];
entry.removed = true;
if (prev) prev.next = next;
if (next) next.previous = prev;
if (state.first == entry) state.first = next;
if (state.last == entry) state.last = prev;
if (descriptors) state.size--;else that.size--;
}
return !!entry;
},
// 23.2.3.6 Set.prototype.forEach(callbackfn, thisArg = undefined)
// 23.1.3.5 Map.prototype.forEach(callbackfn, thisArg = undefined)
forEach: function forEach(callbackfn
/* , that = undefined */
) {
var state = getInternalState(this);
var boundFunction = functionBindContext(callbackfn, arguments.length > 1 ? arguments[1] : undefined, 3);
var entry;
while (entry = entry ? entry.next : state.first) {
boundFunction(entry.value, entry.key, this); // revert to the last existing entry
while (entry && entry.removed) {
entry = entry.previous;
}
}
},
// 23.1.3.7 Map.prototype.has(key)
// 23.2.3.7 Set.prototype.has(value)
has: function has(key) {
return !!getEntry(this, key);
}
});
redefineAll(C.prototype, IS_MAP ? {
// 23.1.3.6 Map.prototype.get(key)
get: function get(key) {
var entry = getEntry(this, key);
return entry && entry.value;
},
// 23.1.3.9 Map.prototype.set(key, value)
set: function set(key, value) {
return define(this, key === 0 ? 0 : key, value);
}
} : {
// 23.2.3.1 Set.prototype.add(value)
add: function add(value) {
return define(this, value = value === 0 ? 0 : value, value);
}
});
if (descriptors) defineProperty$8(C.prototype, 'size', {
get: function get() {
return getInternalState(this).size;
}
});
return C;
},
setStrong: function setStrong(C, CONSTRUCTOR_NAME, IS_MAP) {
var ITERATOR_NAME = CONSTRUCTOR_NAME + ' Iterator';
var getInternalCollectionState = internalStateGetterFor(CONSTRUCTOR_NAME);
var getInternalIteratorState = internalStateGetterFor(ITERATOR_NAME); // add .keys, .values, .entries, [@@iterator]
// 23.1.3.4, 23.1.3.8, 23.1.3.11, 23.1.3.12, 23.2.3.5, 23.2.3.8, 23.2.3.10, 23.2.3.11
defineIterator(C, CONSTRUCTOR_NAME, function (iterated, kind) {
setInternalState$6(this, {
type: ITERATOR_NAME,
target: iterated,
state: getInternalCollectionState(iterated),
kind: kind,
last: undefined
});
}, function () {
var state = getInternalIteratorState(this);
var kind = state.kind;
var entry = state.last; // revert to the last existing entry
while (entry && entry.removed) {
entry = entry.previous;
} // get next entry
if (!state.target || !(state.last = entry = entry ? entry.next : state.state.first)) {
// or finish the iteration
state.target = undefined;
return {
value: undefined,
done: true
};
} // return step by kind
if (kind == 'keys') return {
value: entry.key,
done: false
};
if (kind == 'values') return {
value: entry.value,
done: false
};
return {
value: [entry.key, entry.value],
done: false
};
}, IS_MAP ? 'entries' : 'values', !IS_MAP, true); // add [@@species], 23.1.2.2, 23.2.2.2
setSpecies(CONSTRUCTOR_NAME);
}
};
var collectionStrong_1 = collectionStrong.getConstructor;
var collectionStrong_2 = collectionStrong.setStrong;
'use strict'; // `Map` constructor
// https://tc39.es/ecma262/#sec-map-objects
var es_map = collection('Map', function (init) {
return function Map() {
return init(this, arguments.length ? arguments[0] : undefined);
};
}, collectionStrong);
'use strict'; // `Set` constructor
// https://tc39.es/ecma262/#sec-set-objects
var es_set = collection('Set', function (init) {
return function Set() {
return init(this, arguments.length ? arguments[0] : undefined);
};
}, collectionStrong);
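// Illustrative sketch (comment only): whether native or polyfilled via
// `collection` + `collectionStrong`, the exported Map/Set keep insertion order
// and treat -0 and +0 as the same key.
//
//   var m = new Map([['a', 1]]);
//   m.set(-0, 'zero');
//   m.get(0) === 'zero'; // true - -0 is normalised to +0
//   m.size === 2;        // true
//
//   var s = new Set([1, 2, 2, 3]);
//   s.size === 3; // true - duplicates collapse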
'use strict';
var getWeakData = internalMetadata.getWeakData;
var setInternalState$7 = internalState.set;
var internalStateGetterFor$1 = internalState.getterFor;
var find = arrayIteration.find;
var findIndex = arrayIteration.findIndex;
var id$1 = 0; // fallback for uncaught frozen keys
var uncaughtFrozenStore = function uncaughtFrozenStore(store) {
return store.frozen || (store.frozen = new UncaughtFrozenStore());
};
var UncaughtFrozenStore = function UncaughtFrozenStore() {
this.entries = [];
};
var findUncaughtFrozen = function findUncaughtFrozen(store, key) {
return find(store.entries, function (it) {
return it[0] === key;
});
};
UncaughtFrozenStore.prototype = {
get: function get(key) {
var entry = findUncaughtFrozen(this, key);
if (entry) return entry[1];
},
has: function has(key) {
return !!findUncaughtFrozen(this, key);
},
set: function set(key, value) {
var entry = findUncaughtFrozen(this, key);
if (entry) entry[1] = value;else this.entries.push([key, value]);
},
'delete': function _delete(key) {
var index = findIndex(this.entries, function (it) {
return it[0] === key;
});
if (~index) this.entries.splice(index, 1);
return !!~index;
}
};
var collectionWeak = {
getConstructor: function getConstructor(wrapper, CONSTRUCTOR_NAME, IS_MAP, ADDER) {
var C = wrapper(function (that, iterable) {
anInstance(that, C, CONSTRUCTOR_NAME);
setInternalState$7(that, {
type: CONSTRUCTOR_NAME,
id: id$1++,
frozen: undefined
});
if (iterable != undefined) iterate(iterable, that[ADDER], {
that: that,
AS_ENTRIES: IS_MAP
});
});
var getInternalState = internalStateGetterFor$1(CONSTRUCTOR_NAME);
var define = function define(that, key, value) {
var state = getInternalState(that);
var data = getWeakData(anObject(key), true);
if (data === true) uncaughtFrozenStore(state).set(key, value);else data[state.id] = value;
return that;
};
redefineAll(C.prototype, {
// 23.3.3.2 WeakMap.prototype.delete(key)
// 23.4.3.3 WeakSet.prototype.delete(value)
'delete': function _delete(key) {
var state = getInternalState(this);
if (!isObject(key)) return false;
var data = getWeakData(key);
if (data === true) return uncaughtFrozenStore(state)['delete'](key);
return data && has(data, state.id) && delete data[state.id];
},
// 23.3.3.4 WeakMap.prototype.has(key)
// 23.4.3.4 WeakSet.prototype.has(value)
has: function has$1(key) {
var state = getInternalState(this);
if (!isObject(key)) return false;
var data = getWeakData(key);
if (data === true) return uncaughtFrozenStore(state).has(key);
return data && has(data, state.id);
}
});
redefineAll(C.prototype, IS_MAP ? {
// 23.3.3.3 WeakMap.prototype.get(key)
get: function get(key) {
var state = getInternalState(this);
if (isObject(key)) {
var data = getWeakData(key);
if (data === true) return uncaughtFrozenStore(state).get(key);
return data ? data[state.id] : undefined;
}
},
// 23.3.3.5 WeakMap.prototype.set(key, value)
set: function set(key, value) {
return define(this, key, value);
}
} : {
// 23.4.3.1 WeakSet.prototype.add(value)
add: function add(value) {
return define(this, value, true);
}
});
return C;
}
};
var collectionWeak_1 = collectionWeak.getConstructor;
var es_weakMap = createCommonjsModule(function (module) {
'use strict';
var enforceInternalState = internalState.enforce;
var IS_IE11 = !global_1.ActiveXObject && 'ActiveXObject' in global_1;
var isExtensible = Object.isExtensible;
var InternalWeakMap;
var wrapper = function wrapper(init) {
return function WeakMap() {
return init(this, arguments.length ? arguments[0] : undefined);
};
}; // `WeakMap` constructor
// https://tc39.es/ecma262/#sec-weakmap-constructor
var $WeakMap = module.exports = collection('WeakMap', wrapper, collectionWeak); // IE11 WeakMap frozen keys fix
// We can't use feature detection because it crashes some old IE builds
// https://github.com/zloirock/core-js/issues/485
if (nativeWeakMap && IS_IE11) {
InternalWeakMap = collectionWeak.getConstructor(wrapper, 'WeakMap', true);
internalMetadata.REQUIRED = true;
var WeakMapPrototype = $WeakMap.prototype;
var nativeDelete = WeakMapPrototype['delete'];
var nativeHas = WeakMapPrototype.has;
var nativeGet = WeakMapPrototype.get;
var nativeSet = WeakMapPrototype.set;
redefineAll(WeakMapPrototype, {
'delete': function _delete(key) {
if (isObject(key) && !isExtensible(key)) {
var state = enforceInternalState(this);
if (!state.frozen) state.frozen = new InternalWeakMap();
return nativeDelete.call(this, key) || state.frozen['delete'](key);
}
return nativeDelete.call(this, key);
},
has: function has(key) {
if (isObject(key) && !isExtensible(key)) {
var state = enforceInternalState(this);
if (!state.frozen) state.frozen = new InternalWeakMap();
return nativeHas.call(this, key) || state.frozen.has(key);
}
return nativeHas.call(this, key);
},
get: function get(key) {
if (isObject(key) && !isExtensible(key)) {
var state = enforceInternalState(this);
if (!state.frozen) state.frozen = new InternalWeakMap();
return nativeHas.call(this, key) ? nativeGet.call(this, key) : state.frozen.get(key);
}
return nativeGet.call(this, key);
},
set: function set(key, value) {
if (isObject(key) && !isExtensible(key)) {
var state = enforceInternalState(this);
if (!state.frozen) state.frozen = new InternalWeakMap();
nativeHas.call(this, key) ? nativeSet.call(this, key, value) : state.frozen.set(key, value);
} else nativeSet.call(this, key, value);
return this;
}
});
}
});
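// Illustrative sketch (comment only): the IE11 patch above keeps WeakMap usable
// with frozen keys by spilling them into an internal fallback store.
//
//   var key = Object.freeze({});
//   var wm = new WeakMap();
//   wm.set(key, 'value');
//   wm.get(key) === 'value'; // true, even where the key cannot carry metadata
//   wm.has({}) === false;    // a different object is not a key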
'use strict'; // `WeakSet` constructor
// https://tc39.es/ecma262/#sec-weakset-constructor
collection('WeakSet', function (init) {
return function WeakSet() {
return init(this, arguments.length ? arguments[0] : undefined);
};
}, collectionWeak);
var es_weakSet = {};
var arrayBufferNative = typeof ArrayBuffer !== 'undefined' && typeof DataView !== 'undefined';
// https://tc39.es/ecma262/#sec-toindex
var toIndex = function toIndex(it) {
if (it === undefined) return 0;
var number = toInteger(it);
var length = toLength(number);
if (number !== length) throw RangeError('Wrong length or index');
return length;
};
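// Illustrative sketch (comment only): `toIndex` coerces a length argument the
// way ArrayBuffer and typed array constructors do.
//
//   toIndex(undefined); // 0
//   toIndex(8);         // 8
//   toIndex(-1);        // throws RangeError('Wrong length or index')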
// IEEE754 conversions based on https://github.com/feross/ieee754
var abs$7 = Math.abs;
var pow$3 = Math.pow;
var floor$6 = Math.floor;
var log$8 = Math.log;
var LN2$2 = Math.LN2;
var pack = function pack(number, mantissaLength, bytes) {
var buffer = new Array(bytes);
var exponentLength = bytes * 8 - mantissaLength - 1;
var eMax = (1 << exponentLength) - 1;
var eBias = eMax >> 1;
var rt = mantissaLength === 23 ? pow$3(2, -24) - pow$3(2, -77) : 0;
var sign = number < 0 || number === 0 && 1 / number < 0 ? 1 : 0;
var index = 0;
var exponent, mantissa, c;
number = abs$7(number); // eslint-disable-next-line no-self-compare -- NaN check
if (number != number || number === Infinity) {
// eslint-disable-next-line no-self-compare -- NaN check
mantissa = number != number ? 1 : 0;
exponent = eMax;
} else {
exponent = floor$6(log$8(number) / LN2$2);
if (number * (c = pow$3(2, -exponent)) < 1) {
exponent--;
c *= 2;
}
if (exponent + eBias >= 1) {
number += rt / c;
} else {
number += rt * pow$3(2, 1 - eBias);
}
if (number * c >= 2) {
exponent++;
c /= 2;
}
if (exponent + eBias >= eMax) {
mantissa = 0;
exponent = eMax;
} else if (exponent + eBias >= 1) {
mantissa = (number * c - 1) * pow$3(2, mantissaLength);
exponent = exponent + eBias;
} else {
mantissa = number * pow$3(2, eBias - 1) * pow$3(2, mantissaLength);
exponent = 0;
}
}
for (; mantissaLength >= 8; buffer[index++] = mantissa & 255, mantissa /= 256, mantissaLength -= 8) {
;
}
exponent = exponent << mantissaLength | mantissa;
exponentLength += mantissaLength;
for (; exponentLength > 0; buffer[index++] = exponent & 255, exponent /= 256, exponentLength -= 8) {
;
}
buffer[--index] |= sign * 128;
return buffer;
};
var unpack = function unpack(buffer, mantissaLength) {
var bytes = buffer.length;
var exponentLength = bytes * 8 - mantissaLength - 1;
var eMax = (1 << exponentLength) - 1;
var eBias = eMax >> 1;
var nBits = exponentLength - 7;
var index = bytes - 1;
var sign = buffer[index--];
var exponent = sign & 127;
var mantissa;
sign >>= 7;
for (; nBits > 0; exponent = exponent * 256 + buffer[index], index--, nBits -= 8) {
;
}
mantissa = exponent & (1 << -nBits) - 1;
exponent >>= -nBits;
nBits += mantissaLength;
for (; nBits > 0; mantissa = mantissa * 256 + buffer[index], index--, nBits -= 8) {
;
}
if (exponent === 0) {
exponent = 1 - eBias;
} else if (exponent === eMax) {
return mantissa ? NaN : sign ? -Infinity : Infinity;
} else {
mantissa = mantissa + pow$3(2, mantissaLength);
exponent = exponent - eBias;
}
return (sign ? -1 : 1) * mantissa * pow$3(2, exponent - mantissaLength);
};
var ieee754 = {
pack: pack,
unpack: unpack
};
var ieee754_1 = ieee754.pack;
var ieee754_2 = ieee754.unpack;
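// Illustrative sketch (comment only): `pack`/`unpack` above implement IEEE 754
// encoding with the bytes in little-endian order. For a 32-bit float the
// mantissa is 23 bits wide and the value occupies 4 bytes.
//
//   ieee754.pack(1.5, 23, 4);
//   // -> [0, 0, 192, 63], i.e. the float32 bit pattern 0x3FC00000, little-endian
//   ieee754.unpack([0, 0, 192, 63], 23);
//   // -> 1.5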
'use strict';
var getOwnPropertyNames$2 = objectGetOwnPropertyNames.f;
var defineProperty$9 = objectDefineProperty.f;
var getInternalState$6 = internalState.get;
var setInternalState$8 = internalState.set;
var ARRAY_BUFFER = 'ArrayBuffer';
var DATA_VIEW = 'DataView';
var PROTOTYPE$2 = 'prototype';
var WRONG_LENGTH = 'Wrong length';
var WRONG_INDEX = 'Wrong index';
var NativeArrayBuffer = global_1[ARRAY_BUFFER];
var $ArrayBuffer = NativeArrayBuffer;
var $DataView = global_1[DATA_VIEW];
var $DataViewPrototype = $DataView && $DataView[PROTOTYPE$2];
var ObjectPrototype$2 = Object.prototype;
var RangeError$1 = global_1.RangeError;
var packIEEE754 = ieee754.pack;
var unpackIEEE754 = ieee754.unpack;
var packInt8 = function packInt8(number) {
return [number & 0xFF];
};
var packInt16 = function packInt16(number) {
return [number & 0xFF, number >> 8 & 0xFF];
};
var packInt32 = function packInt32(number) {
return [number & 0xFF, number >> 8 & 0xFF, number >> 16 & 0xFF, number >> 24 & 0xFF];
};
var unpackInt32 = function unpackInt32(buffer) {
return buffer[3] << 24 | buffer[2] << 16 | buffer[1] << 8 | buffer[0];
};
var packFloat32 = function packFloat32(number) {
return packIEEE754(number, 23, 4);
};
var packFloat64 = function packFloat64(number) {
return packIEEE754(number, 52, 8);
};
var addGetter = function addGetter(Constructor, key) {
defineProperty$9(Constructor[PROTOTYPE$2], key, {
get: function get() {
return getInternalState$6(this)[key];
}
});
};
var get$1 = function get(view, count, index, isLittleEndian) {
var intIndex = toIndex(index);
var store = getInternalState$6(view);
if (intIndex + count > store.byteLength) throw RangeError$1(WRONG_INDEX);
var bytes = getInternalState$6(store.buffer).bytes;
var start = intIndex + store.byteOffset;
var pack = bytes.slice(start, start + count);
return isLittleEndian ? pack : pack.reverse();
};
var set$2 = function set(view, count, index, conversion, value, isLittleEndian) {
var intIndex = toIndex(index);
var store = getInternalState$6(view);
if (intIndex + count > store.byteLength) throw RangeError$1(WRONG_INDEX);
var bytes = getInternalState$6(store.buffer).bytes;
var start = intIndex + store.byteOffset;
var pack = conversion(+value);
for (var i = 0; i < count; i++) {
bytes[start + i] = pack[isLittleEndian ? i : count - i - 1];
}
};
if (!arrayBufferNative) {
$ArrayBuffer = function ArrayBuffer(length) {
anInstance(this, $ArrayBuffer, ARRAY_BUFFER);
var byteLength = toIndex(length);
setInternalState$8(this, {
bytes: arrayFill.call(new Array(byteLength), 0),
byteLength: byteLength
});
if (!descriptors) this.byteLength = byteLength;
};
$DataView = function DataView(buffer, byteOffset, byteLength) {
anInstance(this, $DataView, DATA_VIEW);
anInstance(buffer, $ArrayBuffer, DATA_VIEW);
var bufferLength = getInternalState$6(buffer).byteLength;
var offset = toInteger(byteOffset);
if (offset < 0 || offset > bufferLength) throw RangeError$1('Wrong offset');
byteLength = byteLength === undefined ? bufferLength - offset : toLength(byteLength);
if (offset + byteLength > bufferLength) throw RangeError$1(WRONG_LENGTH);
setInternalState$8(this, {
buffer: buffer,
byteLength: byteLength,
byteOffset: offset
});
if (!descriptors) {
this.buffer = buffer;
this.byteLength = byteLength;
this.byteOffset = offset;
}
};
if (descriptors) {
addGetter($ArrayBuffer, 'byteLength');
addGetter($DataView, 'buffer');
addGetter($DataView, 'byteLength');
addGetter($DataView, 'byteOffset');
}
redefineAll($DataView[PROTOTYPE$2], {
getInt8: function getInt8(byteOffset) {
return get$1(this, 1, byteOffset)[0] << 24 >> 24;
},
getUint8: function getUint8(byteOffset) {
return get$1(this, 1, byteOffset)[0];
},
getInt16: function getInt16(byteOffset
/* , littleEndian */
) {
var bytes = get$1(this, 2, byteOffset, arguments.length > 1 ? arguments[1] : undefined);
return (bytes[1] << 8 | bytes[0]) << 16 >> 16;
},
getUint16: function getUint16(byteOffset
/* , littleEndian */
) {
var bytes = get$1(this, 2, byteOffset, arguments.length > 1 ? arguments[1] : undefined);
return bytes[1] << 8 | bytes[0];
},
getInt32: function getInt32(byteOffset
/* , littleEndian */
) {
return unpackInt32(get$1(this, 4, byteOffset, arguments.length > 1 ? arguments[1] : undefined));
},
getUint32: function getUint32(byteOffset
/* , littleEndian */
) {
return unpackInt32(get$1(this, 4, byteOffset, arguments.length > 1 ? arguments[1] : undefined)) >>> 0;
},
getFloat32: function getFloat32(byteOffset
/* , littleEndian */
) {
return unpackIEEE754(get$1(this, 4, byteOffset, arguments.length > 1 ? arguments[1] : undefined), 23);
},
getFloat64: function getFloat64(byteOffset
/* , littleEndian */
) {
return unpackIEEE754(get$1(this, 8, byteOffset, arguments.length > 1 ? arguments[1] : undefined), 52);
},
setInt8: function setInt8(byteOffset, value) {
set$2(this, 1, byteOffset, packInt8, value);
},
setUint8: function setUint8(byteOffset, value) {
set$2(this, 1, byteOffset, packInt8, value);
},
setInt16: function setInt16(byteOffset, value
/* , littleEndian */
) {
set$2(this, 2, byteOffset, packInt16, value, arguments.length > 2 ? arguments[2] : undefined);
},
setUint16: function setUint16(byteOffset, value
/* , littleEndian */
) {
set$2(this, 2, byteOffset, packInt16, value, arguments.length > 2 ? arguments[2] : undefined);
},
setInt32: function setInt32(byteOffset, value
/* , littleEndian */
) {
set$2(this, 4, byteOffset, packInt32, value, arguments.length > 2 ? arguments[2] : undefined);
},
setUint32: function setUint32(byteOffset, value
/* , littleEndian */
) {
set$2(this, 4, byteOffset, packInt32, value, arguments.length > 2 ? arguments[2] : undefined);
},
setFloat32: function setFloat32(byteOffset, value
/* , littleEndian */
) {
set$2(this, 4, byteOffset, packFloat32, value, arguments.length > 2 ? arguments[2] : undefined);
},
setFloat64: function setFloat64(byteOffset, value
/* , littleEndian */
) {
set$2(this, 8, byteOffset, packFloat64, value, arguments.length > 2 ? arguments[2] : undefined);
}
});
} else {
/* eslint-disable no-new -- required for testing */
if (!fails(function () {
NativeArrayBuffer(1);
}) || !fails(function () {
new NativeArrayBuffer(-1);
}) || fails(function () {
new NativeArrayBuffer();
new NativeArrayBuffer(1.5);
new NativeArrayBuffer(NaN);
return NativeArrayBuffer.name != ARRAY_BUFFER;
})) {
/* eslint-enable no-new -- required for testing */
$ArrayBuffer = function ArrayBuffer(length) {
anInstance(this, $ArrayBuffer);
return new NativeArrayBuffer(toIndex(length));
};
var ArrayBufferPrototype = $ArrayBuffer[PROTOTYPE$2] = NativeArrayBuffer[PROTOTYPE$2];
for (var keys$3 = getOwnPropertyNames$2(NativeArrayBuffer), j$1 = 0, key$1; keys$3.length > j$1;) {
if (!((key$1 = keys$3[j$1++]) in $ArrayBuffer)) {
createNonEnumerableProperty($ArrayBuffer, key$1, NativeArrayBuffer[key$1]);
}
}
ArrayBufferPrototype.constructor = $ArrayBuffer;
} // WebKit bug - the same parent prototype for typed arrays and data view
if (objectSetPrototypeOf && objectGetPrototypeOf($DataViewPrototype) !== ObjectPrototype$2) {
objectSetPrototypeOf($DataViewPrototype, ObjectPrototype$2);
} // iOS Safari 7.x bug
var testView = new $DataView(new $ArrayBuffer(2));
var nativeSetInt8 = $DataViewPrototype.setInt8;
testView.setInt8(0, 2147483648);
testView.setInt8(1, 2147483649);
if (testView.getInt8(0) || !testView.getInt8(1)) redefineAll($DataViewPrototype, {
setInt8: function setInt8(byteOffset, value) {
nativeSetInt8.call(this, byteOffset, value << 24 >> 24);
},
setUint8: function setUint8(byteOffset, value) {
nativeSetInt8.call(this, byteOffset, value << 24 >> 24);
}
}, {
unsafe: true
});
}
setToStringTag($ArrayBuffer, ARRAY_BUFFER);
setToStringTag($DataView, DATA_VIEW);
var arrayBuffer = {
ArrayBuffer: $ArrayBuffer,
DataView: $DataView
};
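// Illustrative sketch (comment only): the exported ArrayBuffer/DataView pair
// behaves like the native one, including explicit endianness on multi-byte
// access.
//
//   var view = new arrayBuffer.DataView(new arrayBuffer.ArrayBuffer(4));
//   view.setUint16(0, 0x1234, true);     // little-endian: bytes become 0x34, 0x12
//   view.getUint8(0) === 0x34;           // true
//   view.getUint16(0, false) === 0x3412; // big-endian read of the same bytes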
'use strict';
var ARRAY_BUFFER$1 = 'ArrayBuffer';
var ArrayBuffer$1 = arrayBuffer[ARRAY_BUFFER$1];
var NativeArrayBuffer$1 = global_1[ARRAY_BUFFER$1]; // `ArrayBuffer` constructor
// https://tc39.es/ecma262/#sec-arraybuffer-constructor
_export({
global: true,
forced: NativeArrayBuffer$1 !== ArrayBuffer$1
}, {
ArrayBuffer: ArrayBuffer$1
});
setSpecies(ARRAY_BUFFER$1);
var es_arrayBuffer_constructor = {};
'use strict';
var defineProperty$a = objectDefineProperty.f;
var Int8Array$1 = global_1.Int8Array;
var Int8ArrayPrototype = Int8Array$1 && Int8Array$1.prototype;
var Uint8ClampedArray$1 = global_1.Uint8ClampedArray;
var Uint8ClampedArrayPrototype = Uint8ClampedArray$1 && Uint8ClampedArray$1.prototype;
var TypedArray = Int8Array$1 && objectGetPrototypeOf(Int8Array$1);
var TypedArrayPrototype = Int8ArrayPrototype && objectGetPrototypeOf(Int8ArrayPrototype);
var ObjectPrototype$3 = Object.prototype;
var isPrototypeOf = ObjectPrototype$3.isPrototypeOf;
var TO_STRING_TAG$3 = wellKnownSymbol('toStringTag');
var TYPED_ARRAY_TAG = uid('TYPED_ARRAY_TAG'); // Fixing native typed arrays in Opera Presto crashes the browser, see #595
var NATIVE_ARRAY_BUFFER_VIEWS = arrayBufferNative && !!objectSetPrototypeOf && classof(global_1.opera) !== 'Opera';
var TYPED_ARRAY_TAG_REQUIRED = false;
var NAME$1;
var TypedArrayConstructorsList = {
Int8Array: 1,
Uint8Array: 1,
Uint8ClampedArray: 1,
Int16Array: 2,
Uint16Array: 2,
Int32Array: 4,
Uint32Array: 4,
Float32Array: 4,
Float64Array: 8
};
var BigIntArrayConstructorsList = {
BigInt64Array: 8,
BigUint64Array: 8
};
var isView = function isView(it) {
if (!isObject(it)) return false;
var klass = classof(it);
return klass === 'DataView' || has(TypedArrayConstructorsList, klass) || has(BigIntArrayConstructorsList, klass);
};
var isTypedArray = function isTypedArray(it) {
if (!isObject(it)) return false;
var klass = classof(it);
return has(TypedArrayConstructorsList, klass) || has(BigIntArrayConstructorsList, klass);
};
var aTypedArray = function aTypedArray(it) {
if (isTypedArray(it)) return it;
throw TypeError('Target is not a typed array');
};
var aTypedArrayConstructor = function aTypedArrayConstructor(C) {
if (objectSetPrototypeOf) {
if (isPrototypeOf.call(TypedArray, C)) return C;
} else for (var ARRAY in TypedArrayConstructorsList) {
if (has(TypedArrayConstructorsList, NAME$1)) {
var TypedArrayConstructor = global_1[ARRAY];
if (TypedArrayConstructor && (C === TypedArrayConstructor || isPrototypeOf.call(TypedArrayConstructor, C))) {
return C;
}
}
}
throw TypeError('Target is not a typed array constructor');
};
var exportTypedArrayMethod = function exportTypedArrayMethod(KEY, property, forced) {
if (!descriptors) return;
if (forced) for (var ARRAY in TypedArrayConstructorsList) {
var TypedArrayConstructor = global_1[ARRAY];
if (TypedArrayConstructor && has(TypedArrayConstructor.prototype, KEY)) {
delete TypedArrayConstructor.prototype[KEY];
}
}
if (!TypedArrayPrototype[KEY] || forced) {
redefine(TypedArrayPrototype, KEY, forced ? property : NATIVE_ARRAY_BUFFER_VIEWS && Int8ArrayPrototype[KEY] || property);
}
};
var exportTypedArrayStaticMethod = function exportTypedArrayStaticMethod(KEY, property, forced) {
var ARRAY, TypedArrayConstructor;
if (!descriptors) return;
if (objectSetPrototypeOf) {
if (forced) for (ARRAY in TypedArrayConstructorsList) {
TypedArrayConstructor = global_1[ARRAY];
if (TypedArrayConstructor && has(TypedArrayConstructor, KEY)) {
delete TypedArrayConstructor[KEY];
}
}
if (!TypedArray[KEY] || forced) {
// V8 ~ Chrome 49-50 `%TypedArray%` methods are non-writable non-configurable
try {
return redefine(TypedArray, KEY, forced ? property : NATIVE_ARRAY_BUFFER_VIEWS && Int8Array$1[KEY] || property);
} catch (error) {
/* empty */
}
} else return;
}
for (ARRAY in TypedArrayConstructorsList) {
TypedArrayConstructor = global_1[ARRAY];
if (TypedArrayConstructor && (!TypedArrayConstructor[KEY] || forced)) {
redefine(TypedArrayConstructor, KEY, property);
}
}
};
for (NAME$1 in TypedArrayConstructorsList) {
if (!global_1[NAME$1]) NATIVE_ARRAY_BUFFER_VIEWS = false;
} // WebKit bug - typed array constructors' prototype is Object.prototype
if (!NATIVE_ARRAY_BUFFER_VIEWS || typeof TypedArray != 'function' || TypedArray === Function.prototype) {
// eslint-disable-next-line no-shadow -- safe
TypedArray = function TypedArray() {
throw TypeError('Incorrect invocation');
};
if (NATIVE_ARRAY_BUFFER_VIEWS) for (NAME$1 in TypedArrayConstructorsList) {
if (global_1[NAME$1]) objectSetPrototypeOf(global_1[NAME$1], TypedArray);
}
}
if (!NATIVE_ARRAY_BUFFER_VIEWS || !TypedArrayPrototype || TypedArrayPrototype === ObjectPrototype$3) {
TypedArrayPrototype = TypedArray.prototype;
if (NATIVE_ARRAY_BUFFER_VIEWS) for (NAME$1 in TypedArrayConstructorsList) {
if (global_1[NAME$1]) objectSetPrototypeOf(global_1[NAME$1].prototype, TypedArrayPrototype);
}
} // WebKit bug - one more object in Uint8ClampedArray prototype chain
if (NATIVE_ARRAY_BUFFER_VIEWS && objectGetPrototypeOf(Uint8ClampedArrayPrototype) !== TypedArrayPrototype) {
objectSetPrototypeOf(Uint8ClampedArrayPrototype, TypedArrayPrototype);
}
if (descriptors && !has(TypedArrayPrototype, TO_STRING_TAG$3)) {
TYPED_ARRAY_TAG_REQUIRED = true;
defineProperty$a(TypedArrayPrototype, TO_STRING_TAG$3, {
get: function get() {
return isObject(this) ? this[TYPED_ARRAY_TAG] : undefined;
}
});
for (NAME$1 in TypedArrayConstructorsList) {
if (global_1[NAME$1]) {
createNonEnumerableProperty(global_1[NAME$1], TYPED_ARRAY_TAG, NAME$1);
}
}
}
var arrayBufferViewCore = {
NATIVE_ARRAY_BUFFER_VIEWS: NATIVE_ARRAY_BUFFER_VIEWS,
TYPED_ARRAY_TAG: TYPED_ARRAY_TAG_REQUIRED && TYPED_ARRAY_TAG,
aTypedArray: aTypedArray,
aTypedArrayConstructor: aTypedArrayConstructor,
exportTypedArrayMethod: exportTypedArrayMethod,
exportTypedArrayStaticMethod: exportTypedArrayStaticMethod,
isView: isView,
isTypedArray: isTypedArray,
TypedArray: TypedArray,
TypedArrayPrototype: TypedArrayPrototype
};
var arrayBufferViewCore_1 = arrayBufferViewCore.NATIVE_ARRAY_BUFFER_VIEWS;
var arrayBufferViewCore_2 = arrayBufferViewCore.TYPED_ARRAY_TAG;
var arrayBufferViewCore_3 = arrayBufferViewCore.aTypedArray;
var arrayBufferViewCore_4 = arrayBufferViewCore.aTypedArrayConstructor;
var arrayBufferViewCore_5 = arrayBufferViewCore.exportTypedArrayMethod;
var arrayBufferViewCore_6 = arrayBufferViewCore.exportTypedArrayStaticMethod;
var arrayBufferViewCore_7 = arrayBufferViewCore.isView;
var arrayBufferViewCore_8 = arrayBufferViewCore.isTypedArray;
var arrayBufferViewCore_9 = arrayBufferViewCore.TypedArray;
var arrayBufferViewCore_10 = arrayBufferViewCore.TypedArrayPrototype;
var NATIVE_ARRAY_BUFFER_VIEWS$1 = arrayBufferViewCore.NATIVE_ARRAY_BUFFER_VIEWS; // `ArrayBuffer.isView` method
// https://tc39.es/ecma262/#sec-arraybuffer.isview
_export({
target: 'ArrayBuffer',
stat: true,
forced: !NATIVE_ARRAY_BUFFER_VIEWS$1
}, {
isView: arrayBufferViewCore.isView
});
var es_arrayBuffer_isView = {};
'use strict';
var ArrayBuffer$2 = arrayBuffer.ArrayBuffer;
var DataView$1 = arrayBuffer.DataView;
var nativeArrayBufferSlice = ArrayBuffer$2.prototype.slice;
var INCORRECT_SLICE = fails(function () {
return !new ArrayBuffer$2(2).slice(1, undefined).byteLength;
}); // `ArrayBuffer.prototype.slice` method
// https://tc39.es/ecma262/#sec-arraybuffer.prototype.slice
_export({
target: 'ArrayBuffer',
proto: true,
unsafe: true,
forced: INCORRECT_SLICE
}, {
slice: function slice(start, end) {
if (nativeArrayBufferSlice !== undefined && end === undefined) {
return nativeArrayBufferSlice.call(anObject(this), start); // FF fix
}
var length = anObject(this).byteLength;
var first = toAbsoluteIndex(start, length);
var fin = toAbsoluteIndex(end === undefined ? length : end, length);
var result = new (speciesConstructor(this, ArrayBuffer$2))(toLength(fin - first));
var viewSource = new DataView$1(this);
var viewTarget = new DataView$1(result);
var index = 0;
while (first < fin) {
viewTarget.setUint8(index++, viewSource.getUint8(first++));
}
return result;
}
});
var es_arrayBuffer_slice = {};
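// Illustrative sketch (comment only): the patched `slice` above copies the
// selected byte range into a new buffer, with negative indices counted from
// the end.
//
//   var buffer = new ArrayBuffer(8);
//   buffer.slice(2, 6).byteLength === 4; // true
//   buffer.slice(-2).byteLength === 2;   // true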
// https://tc39.es/ecma262/#sec-dataview-constructor
_export({
global: true,
forced: !arrayBufferNative
}, {
DataView: arrayBuffer.DataView
});
var es_dataView = {};
/* eslint-disable no-new -- required for testing */
var NATIVE_ARRAY_BUFFER_VIEWS$2 = arrayBufferViewCore.NATIVE_ARRAY_BUFFER_VIEWS;
var ArrayBuffer$3 = global_1.ArrayBuffer;
var Int8Array$2 = global_1.Int8Array;
var typedArrayConstructorsRequireWrappers = !NATIVE_ARRAY_BUFFER_VIEWS$2 || !fails(function () {
Int8Array$2(1);
}) || !fails(function () {
new Int8Array$2(-1);
}) || !checkCorrectnessOfIteration(function (iterable) {
new Int8Array$2();
new Int8Array$2(null);
new Int8Array$2(1.5);
new Int8Array$2(iterable);
}, true) || fails(function () {
// Safari (11+) bug - a reason why even Safari 13 should load a typed array polyfill
return new Int8Array$2(new ArrayBuffer$3(2), 1, undefined).length !== 1;
});
var toPositiveInteger = function toPositiveInteger(it) {
var result = toInteger(it);
if (result < 0) throw RangeError("The argument can't be less than 0");
return result;
};
var toOffset = function toOffset(it, BYTES) {
var offset = toPositiveInteger(it);
if (offset % BYTES) throw RangeError('Wrong offset');
return offset;
};
var aTypedArrayConstructor$1 = arrayBufferViewCore.aTypedArrayConstructor;
var typedArrayFrom = function from(source
/* , mapfn, thisArg */
) {
var O = toObject(source);
var argumentsLength = arguments.length;
var mapfn = argumentsLength > 1 ? arguments[1] : undefined;
var mapping = mapfn !== undefined;
var iteratorMethod = getIteratorMethod(O);
var i, length, result, step, iterator, next;
if (iteratorMethod != undefined && !isArrayIteratorMethod(iteratorMethod)) {
iterator = iteratorMethod.call(O);
next = iterator.next;
O = [];
while (!(step = next.call(iterator)).done) {
O.push(step.value);
}
}
if (mapping && argumentsLength > 2) {
mapfn = functionBindContext(mapfn, arguments[2], 2);
}
length = toLength(O.length);
result = new (aTypedArrayConstructor$1(this))(length);
for (i = 0; length > i; i++) {
result[i] = mapping ? mapfn(O[i], i) : O[i];
}
return result;
};
var typedArrayConstructor = createCommonjsModule(function (module) {
'use strict';
var getOwnPropertyNames = objectGetOwnPropertyNames.f;
var forEach = arrayIteration.forEach;
var getInternalState = internalState.get;
var setInternalState = internalState.set;
var nativeDefineProperty = objectDefineProperty.f;
var nativeGetOwnPropertyDescriptor = objectGetOwnPropertyDescriptor.f;
var round = Math.round;
var RangeError = global_1.RangeError;
var ArrayBuffer = arrayBuffer.ArrayBuffer;
var DataView = arrayBuffer.DataView;
var NATIVE_ARRAY_BUFFER_VIEWS = arrayBufferViewCore.NATIVE_ARRAY_BUFFER_VIEWS;
var TYPED_ARRAY_TAG = arrayBufferViewCore.TYPED_ARRAY_TAG;
var TypedArray = arrayBufferViewCore.TypedArray;
var TypedArrayPrototype = arrayBufferViewCore.TypedArrayPrototype;
var aTypedArrayConstructor = arrayBufferViewCore.aTypedArrayConstructor;
var isTypedArray = arrayBufferViewCore.isTypedArray;
var BYTES_PER_ELEMENT = 'BYTES_PER_ELEMENT';
var WRONG_LENGTH = 'Wrong length';
var fromList = function fromList(C, list) {
var index = 0;
var length = list.length;
var result = new (aTypedArrayConstructor(C))(length);
while (length > index) {
result[index] = list[index++];
}
return result;
};
var addGetter = function addGetter(it, key) {
nativeDefineProperty(it, key, {
get: function get() {
return getInternalState(this)[key];
}
});
};
var isArrayBuffer = function isArrayBuffer(it) {
var klass;
return it instanceof ArrayBuffer || (klass = classof(it)) == 'ArrayBuffer' || klass == 'SharedArrayBuffer';
};
var isTypedArrayIndex = function isTypedArrayIndex(target, key) {
return isTypedArray(target) && typeof key != 'symbol' && key in target && String(+key) == String(key);
};
var wrappedGetOwnPropertyDescriptor = function getOwnPropertyDescriptor(target, key) {
return isTypedArrayIndex(target, key = toPrimitive(key, true)) ? createPropertyDescriptor(2, target[key]) : nativeGetOwnPropertyDescriptor(target, key);
};
var wrappedDefineProperty = function defineProperty(target, key, descriptor) {
if (isTypedArrayIndex(target, key = toPrimitive(key, true)) && isObject(descriptor) && has(descriptor, 'value') && !has(descriptor, 'get') && !has(descriptor, 'set') // TODO: add validation descriptor w/o calling accessors
&& !descriptor.configurable && (!has(descriptor, 'writable') || descriptor.writable) && (!has(descriptor, 'enumerable') || descriptor.enumerable)) {
target[key] = descriptor.value;
return target;
}
return nativeDefineProperty(target, key, descriptor);
};
if (descriptors) {
if (!NATIVE_ARRAY_BUFFER_VIEWS) {
objectGetOwnPropertyDescriptor.f = wrappedGetOwnPropertyDescriptor;
objectDefineProperty.f = wrappedDefineProperty;
addGetter(TypedArrayPrototype, 'buffer');
addGetter(TypedArrayPrototype, 'byteOffset');
addGetter(TypedArrayPrototype, 'byteLength');
addGetter(TypedArrayPrototype, 'length');
}
_export({
target: 'Object',
stat: true,
forced: !NATIVE_ARRAY_BUFFER_VIEWS
}, {
getOwnPropertyDescriptor: wrappedGetOwnPropertyDescriptor,
defineProperty: wrappedDefineProperty
});
module.exports = function (TYPE, wrapper, CLAMPED) {
var BYTES = TYPE.match(/\d+$/)[0] / 8;
var CONSTRUCTOR_NAME = TYPE + (CLAMPED ? 'Clamped' : '') + 'Array';
var GETTER = 'get' + TYPE;
var SETTER = 'set' + TYPE;
var NativeTypedArrayConstructor = global_1[CONSTRUCTOR_NAME];
var TypedArrayConstructor = NativeTypedArrayConstructor;
var TypedArrayConstructorPrototype = TypedArrayConstructor && TypedArrayConstructor.prototype;
var exported = {};
var getter = function getter(that, index) {
var data = getInternalState(that);
return data.view[GETTER](index * BYTES + data.byteOffset, true);
};
var setter = function setter(that, index, value) {
var data = getInternalState(that);
if (CLAMPED) value = (value = round(value)) < 0 ? 0 : value > 0xFF ? 0xFF : value & 0xFF;
data.view[SETTER](index * BYTES + data.byteOffset, value, true);
};
var addElement = function addElement(that, index) {
nativeDefineProperty(that, index, {
get: function get() {
return getter(this, index);
},
set: function set(value) {
return setter(this, index, value);
},
enumerable: true
});
};
if (!NATIVE_ARRAY_BUFFER_VIEWS) {
TypedArrayConstructor = wrapper(function (that, data, offset, $length) {
anInstance(that, TypedArrayConstructor, CONSTRUCTOR_NAME);
var index = 0;
var byteOffset = 0;
var buffer, byteLength, length;
if (!isObject(data)) {
length = toIndex(data);
byteLength = length * BYTES;
buffer = new ArrayBuffer(byteLength);
} else if (isArrayBuffer(data)) {
buffer = data;
byteOffset = toOffset(offset, BYTES);
var $len = data.byteLength;
if ($length === undefined) {
if ($len % BYTES) throw RangeError(WRONG_LENGTH);
byteLength = $len - byteOffset;
if (byteLength < 0) throw RangeError(WRONG_LENGTH);
} else {
byteLength = toLength($length) * BYTES;
if (byteLength + byteOffset > $len) throw RangeError(WRONG_LENGTH);
}
length = byteLength / BYTES;
} else if (isTypedArray(data)) {
return fromList(TypedArrayConstructor, data);
} else {
return typedArrayFrom.call(TypedArrayConstructor, data);
}
setInternalState(that, {
buffer: buffer,
byteOffset: byteOffset,
byteLength: byteLength,
length: length,
view: new DataView(buffer)
});
while (index < length) {
addElement(that, index++);
}
});
if (objectSetPrototypeOf) objectSetPrototypeOf(TypedArrayConstructor, TypedArray);
TypedArrayConstructorPrototype = TypedArrayConstructor.prototype = objectCreate(TypedArrayPrototype);
} else if (typedArrayConstructorsRequireWrappers) {
TypedArrayConstructor = wrapper(function (dummy, data, typedArrayOffset, $length) {
anInstance(dummy, TypedArrayConstructor, CONSTRUCTOR_NAME);
return inheritIfRequired(function () {
if (!isObject(data)) return new NativeTypedArrayConstructor(toIndex(data));
if (isArrayBuffer(data)) return $length !== undefined ? new NativeTypedArrayConstructor(data, toOffset(typedArrayOffset, BYTES), $length) : typedArrayOffset !== undefined ? new NativeTypedArrayConstructor(data, toOffset(typedArrayOffset, BYTES)) : new NativeTypedArrayConstructor(data);
if (isTypedArray(data)) return fromList(TypedArrayConstructor, data);
return typedArrayFrom.call(TypedArrayConstructor, data);
}(), dummy, TypedArrayConstructor);
});
if (objectSetPrototypeOf) objectSetPrototypeOf(TypedArrayConstructor, TypedArray);
forEach(getOwnPropertyNames(NativeTypedArrayConstructor), function (key) {
if (!(key in TypedArrayConstructor)) {
createNonEnumerableProperty(TypedArrayConstructor, key, NativeTypedArrayConstructor[key]);
}
});
TypedArrayConstructor.prototype = TypedArrayConstructorPrototype;
}
if (TypedArrayConstructorPrototype.constructor !== TypedArrayConstructor) {
createNonEnumerableProperty(TypedArrayConstructorPrototype, 'constructor', TypedArrayConstructor);
}
if (TYPED_ARRAY_TAG) {
createNonEnumerableProperty(TypedArrayConstructorPrototype, TYPED_ARRAY_TAG, CONSTRUCTOR_NAME);
}
exported[CONSTRUCTOR_NAME] = TypedArrayConstructor;
_export({
global: true,
forced: TypedArrayConstructor != NativeTypedArrayConstructor,
sham: !NATIVE_ARRAY_BUFFER_VIEWS
}, exported);
if (!(BYTES_PER_ELEMENT in TypedArrayConstructor)) {
createNonEnumerableProperty(TypedArrayConstructor, BYTES_PER_ELEMENT, BYTES);
}
if (!(BYTES_PER_ELEMENT in TypedArrayConstructorPrototype)) {
createNonEnumerableProperty(TypedArrayConstructorPrototype, BYTES_PER_ELEMENT, BYTES);
}
setSpecies(CONSTRUCTOR_NAME);
};
} else module.exports = function () {
/* empty */
};
});
// https://tc39.es/ecma262/#sec-typedarray-objects
typedArrayConstructor('Int8', function (init) {
return function Int8Array(data, byteOffset, length) {
return init(this, data, byteOffset, length);
};
});
var es_typedArray_int8Array = {};
// https://tc39.es/ecma262/#sec-typedarray-objects
typedArrayConstructor('Uint8', function (init) {
return function Uint8Array(data, byteOffset, length) {
return init(this, data, byteOffset, length);
};
});
var es_typedArray_uint8Array = {};
// https://tc39.es/ecma262/#sec-typedarray-objects
typedArrayConstructor('Uint8', function (init) {
return function Uint8ClampedArray(data, byteOffset, length) {
return init(this, data, byteOffset, length);
};
}, true);
var es_typedArray_uint8ClampedArray = {};
// https://tc39.es/ecma262/#sec-typedarray-objects
typedArrayConstructor('Int16', function (init) {
return function Int16Array(data, byteOffset, length) {
return init(this, data, byteOffset, length);
};
});
var es_typedArray_int16Array = {};
// https://tc39.es/ecma262/#sec-typedarray-objects
typedArrayConstructor('Uint16', function (init) {
return function Uint16Array(data, byteOffset, length) {
return init(this, data, byteOffset, length);
};
});
var es_typedArray_uint16Array = {};
// https://tc39.es/ecma262/#sec-typedarray-objects
typedArrayConstructor('Int32', function (init) {
return function Int32Array(data, byteOffset, length) {
return init(this, data, byteOffset, length);
};
});
var es_typedArray_int32Array = {};
// https://tc39.es/ecma262/#sec-typedarray-objects
typedArrayConstructor('Uint32', function (init) {
return function Uint32Array(data, byteOffset, length) {
return init(this, data, byteOffset, length);
};
});
var es_typedArray_uint32Array = {};
// https://tc39.es/ecma262/#sec-typedarray-objects
typedArrayConstructor('Float32', function (init) {
return function Float32Array(data, byteOffset, length) {
return init(this, data, byteOffset, length);
};
});
var es_typedArray_float32Array = {};
// https://tc39.es/ecma262/#sec-typedarray-objects
typedArrayConstructor('Float64', function (init) {
return function Float64Array(data, byteOffset, length) {
return init(this, data, byteOffset, length);
};
});
var es_typedArray_float64Array = {};
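// Illustrative usage (not part of the bundle): the wrapped constructors registered above
// accept a length, an array-like/iterable, another typed array, or an ArrayBuffer with
// an optional byte offset and length, matching the native behavior, e.g.
//   new Uint8Array(4);                        // => Uint8Array [0, 0, 0, 0]
//   new Float32Array([1, 2, 3]);              // => Float32Array [1, 2, 3]
//   new Int16Array(new ArrayBuffer(8), 2, 3); // => 3 elements starting at byte offset 2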
'use strict';
var exportTypedArrayStaticMethod$1 = arrayBufferViewCore.exportTypedArrayStaticMethod; // `%TypedArray%.from` method
// https://tc39.es/ecma262/#sec-%typedarray%.from
exportTypedArrayStaticMethod$1('from', typedArrayFrom, typedArrayConstructorsRequireWrappers);
var es_typedArray_from = {};
'use strict';
var aTypedArrayConstructor$2 = arrayBufferViewCore.aTypedArrayConstructor;
var exportTypedArrayStaticMethod$2 = arrayBufferViewCore.exportTypedArrayStaticMethod; // `%TypedArray%.of` method
// https://tc39.es/ecma262/#sec-%typedarray%.of
exportTypedArrayStaticMethod$2('of', function of()
/* ...items */
{
var index = 0;
var length = arguments.length;
var result = new (aTypedArrayConstructor$2(this))(length);
while (length > index) {
result[index] = arguments[index++];
}
return result;
}, typedArrayConstructorsRequireWrappers);
var es_typedArray_of = {};
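// Illustrative usage (not part of the bundle) of the static methods exported above:
//   Uint8Array.of(1, 2, 3);                                     // => Uint8Array [1, 2, 3]
//   Float32Array.from([1, 2], function (x) { return x * 2; });  // => Float32Array [2, 4]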
'use strict';
var aTypedArray$1 = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$1 = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.copyWithin` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.copywithin
exportTypedArrayMethod$1('copyWithin', function copyWithin(target, start
/* , end */
) {
return arrayCopyWithin.call(aTypedArray$1(this), target, start, arguments.length > 2 ? arguments[2] : undefined);
});
var es_typedArray_copyWithin = {};
'use strict';
var $every$1 = arrayIteration.every;
var aTypedArray$2 = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$2 = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.every` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.every
exportTypedArrayMethod$2('every', function every(callbackfn
/* , thisArg */
) {
return $every$1(aTypedArray$2(this), callbackfn, arguments.length > 1 ? arguments[1] : undefined);
});
var es_typedArray_every = {};
'use strict';
var aTypedArray$3 = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$3 = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.fill` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.fill
// eslint-disable-next-line no-unused-vars -- required for `.length`
exportTypedArrayMethod$3('fill', function fill(value
/* , start, end */
) {
return arrayFill.apply(aTypedArray$3(this), arguments);
});
var es_typedArray_fill = {};
var aTypedArrayConstructor$3 = arrayBufferViewCore.aTypedArrayConstructor;
var typedArrayFromSpeciesAndList = function typedArrayFromSpeciesAndList(instance, list) {
var C = speciesConstructor(instance, instance.constructor);
var index = 0;
var length = list.length;
var result = new (aTypedArrayConstructor$3(C))(length);
while (length > index) {
result[index] = list[index++];
}
return result;
};
'use strict';
var $filter$1 = arrayIteration.filter;
var aTypedArray$4 = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$4 = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.filter` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.filter
exportTypedArrayMethod$4('filter', function filter(callbackfn
/* , thisArg */
) {
var list = $filter$1(aTypedArray$4(this), callbackfn, arguments.length > 1 ? arguments[1] : undefined);
return typedArrayFromSpeciesAndList(this, list);
});
var es_typedArray_filter = {};
'use strict';
var $find$1 = arrayIteration.find;
var aTypedArray$5 = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$5 = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.find` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.find
exportTypedArrayMethod$5('find', function find(predicate
/* , thisArg */
) {
return $find$1(aTypedArray$5(this), predicate, arguments.length > 1 ? arguments[1] : undefined);
});
var es_typedArray_find = {};
'use strict';
var $findIndex$1 = arrayIteration.findIndex;
var aTypedArray$6 = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$6 = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.findIndex` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.findindex
exportTypedArrayMethod$6('findIndex', function findIndex(predicate
/* , thisArg */
) {
return $findIndex$1(aTypedArray$6(this), predicate, arguments.length > 1 ? arguments[1] : undefined);
});
var es_typedArray_findIndex = {};
'use strict';
var $forEach$2 = arrayIteration.forEach;
var aTypedArray$7 = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$7 = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.forEach` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.foreach
exportTypedArrayMethod$7('forEach', function forEach(callbackfn
/* , thisArg */
) {
$forEach$2(aTypedArray$7(this), callbackfn, arguments.length > 1 ? arguments[1] : undefined);
});
var es_typedArray_forEach = {};
'use strict';
var $includes$1 = arrayIncludes.includes;
var aTypedArray$8 = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$8 = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.includes` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.includes
exportTypedArrayMethod$8('includes', function includes(searchElement
/* , fromIndex */
) {
return $includes$1(aTypedArray$8(this), searchElement, arguments.length > 1 ? arguments[1] : undefined);
});
var es_typedArray_includes = {};
'use strict';
var $indexOf$1 = arrayIncludes.indexOf;
var aTypedArray$9 = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$9 = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.indexOf` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.indexof
exportTypedArrayMethod$9('indexOf', function indexOf(searchElement
/* , fromIndex */
) {
return $indexOf$1(aTypedArray$9(this), searchElement, arguments.length > 1 ? arguments[1] : undefined);
});
var es_typedArray_indexOf = {};
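// Illustrative usage (not part of the bundle) of the search helpers exported above:
//   new Int8Array([1, 2, 3]).includes(2);                            // => true
//   new Int8Array([1, 2, 3]).indexOf(3);                             // => 2
//   new Int8Array([1, 2, 3]).find(function (x) { return x > 1; });   // => 2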
'use strict';
var ITERATOR$5 = wellKnownSymbol('iterator');
var Uint8Array$1 = global_1.Uint8Array;
var arrayValues = es_array_iterator.values;
var arrayKeys = es_array_iterator.keys;
var arrayEntries = es_array_iterator.entries;
var aTypedArray$a = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$a = arrayBufferViewCore.exportTypedArrayMethod;
var nativeTypedArrayIterator = Uint8Array$1 && Uint8Array$1.prototype[ITERATOR$5];
var CORRECT_ITER_NAME = !!nativeTypedArrayIterator && (nativeTypedArrayIterator.name == 'values' || nativeTypedArrayIterator.name == undefined);
var typedArrayValues = function values() {
return arrayValues.call(aTypedArray$a(this));
}; // `%TypedArray%.prototype.entries` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.entries
exportTypedArrayMethod$a('entries', function entries() {
return arrayEntries.call(aTypedArray$a(this));
}); // `%TypedArray%.prototype.keys` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.keys
exportTypedArrayMethod$a('keys', function keys() {
return arrayKeys.call(aTypedArray$a(this));
}); // `%TypedArray%.prototype.values` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.values
exportTypedArrayMethod$a('values', typedArrayValues, !CORRECT_ITER_NAME); // `%TypedArray%.prototype[@@iterator]` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype-@@iterator
exportTypedArrayMethod$a(ITERATOR$5, typedArrayValues, !CORRECT_ITER_NAME);
var es_typedArray_iterator = {};
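// Illustrative usage (not part of the bundle): the iterator methods above mirror the
// Array.prototype iterators, e.g.
//   Array.from(new Uint8Array([7, 9]).keys());    // => [0, 1]
//   Array.from(new Uint8Array([7, 9]).entries()); // => [[0, 7], [1, 9]]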
'use strict';
var aTypedArray$b = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$b = arrayBufferViewCore.exportTypedArrayMethod;
var $join = [].join; // `%TypedArray%.prototype.join` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.join
// eslint-disable-next-line no-unused-vars -- required for `.length`
exportTypedArrayMethod$b('join', function join(separator) {
return $join.apply(aTypedArray$b(this), arguments);
});
var es_typedArray_join = {};
'use strict';
var aTypedArray$c = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$c = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.lastIndexOf` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.lastindexof
// eslint-disable-next-line no-unused-vars -- required for `.length`
exportTypedArrayMethod$c('lastIndexOf', function lastIndexOf(searchElement
/* , fromIndex */
) {
return arrayLastIndexOf.apply(aTypedArray$c(this), arguments);
});
var es_typedArray_lastIndexOf = {};
'use strict';
var $map$1 = arrayIteration.map;
var aTypedArray$d = arrayBufferViewCore.aTypedArray;
var aTypedArrayConstructor$4 = arrayBufferViewCore.aTypedArrayConstructor;
var exportTypedArrayMethod$d = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.map` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.map
exportTypedArrayMethod$d('map', function map(mapfn
/* , thisArg */
) {
return $map$1(aTypedArray$d(this), mapfn, arguments.length > 1 ? arguments[1] : undefined, function (O, length) {
return new (aTypedArrayConstructor$4(speciesConstructor(O, O.constructor)))(length);
});
});
var es_typedArray_map = {};
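// Illustrative usage (not part of the bundle): `map` allocates its result through the
// species constructor resolved above, so the element type is preserved:
//   new Float64Array([1, 2, 3]).map(function (x) { return x * 10; }); // => Float64Array [10, 20, 30]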
'use strict';
var $reduce$1 = arrayReduce.left;
var aTypedArray$e = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$e = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.reduce` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.reduce
exportTypedArrayMethod$e('reduce', function reduce(callbackfn
/* , initialValue */
) {
return $reduce$1(aTypedArray$e(this), callbackfn, arguments.length, arguments.length > 1 ? arguments[1] : undefined);
});
var es_typedArray_reduce = {};
'use strict';
var $reduceRight$1 = arrayReduce.right;
var aTypedArray$f = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$f = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.reduceRight` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.reduceright
exportTypedArrayMethod$f('reduceRight', function reduceRight(callbackfn
/* , initialValue */
) {
return $reduceRight$1(aTypedArray$f(this), callbackfn, arguments.length, arguments.length > 1 ? arguments[1] : undefined);
});
var es_typedArray_reduceRight = {};
'use strict';
var aTypedArray$g = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$g = arrayBufferViewCore.exportTypedArrayMethod;
var floor$7 = Math.floor; // `%TypedArray%.prototype.reverse` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.reverse
exportTypedArrayMethod$g('reverse', function reverse() {
var that = this;
var length = aTypedArray$g(that).length;
var middle = floor$7(length / 2);
var index = 0;
var value;
while (index < middle) {
value = that[index];
that[index++] = that[--length];
that[length] = value;
}
return that;
});
var es_typedArray_reverse = {};
'use strict';
var aTypedArray$h = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$h = arrayBufferViewCore.exportTypedArrayMethod;
var FORCED$h = fails(function () {
/* global Int8Array -- safe */
new Int8Array(1).set({});
}); // `%TypedArray%.prototype.set` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.set
exportTypedArrayMethod$h('set', function set(arrayLike
/* , offset */
) {
aTypedArray$h(this);
var offset = toOffset(arguments.length > 1 ? arguments[1] : undefined, 1);
var length = this.length;
var src = toObject(arrayLike);
var len = toLength(src.length);
var index = 0;
if (len + offset > length) throw RangeError('Wrong length');
while (index < len) {
this[offset + index] = src[index++];
}
}, FORCED$h);
var es_typedArray_set = {};
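// Illustrative usage (not part of the bundle) of the `set` method exported above:
//   var target = new Uint8Array(4);
//   target.set([1, 2], 2);     // target is now [0, 0, 1, 2]
//   target.set([1, 2, 3], 2);  // throws RangeError('Wrong length') — source overflows target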
'use strict';
var aTypedArray$i = arrayBufferViewCore.aTypedArray;
var aTypedArrayConstructor$5 = arrayBufferViewCore.aTypedArrayConstructor;
var exportTypedArrayMethod$i = arrayBufferViewCore.exportTypedArrayMethod;
var $slice = [].slice;
var FORCED$i = fails(function () {
/* global Int8Array -- safe */
new Int8Array(1).slice();
}); // `%TypedArray%.prototype.slice` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.slice
exportTypedArrayMethod$i('slice', function slice(start, end) {
var list = $slice.call(aTypedArray$i(this), start, end);
var C = speciesConstructor(this, this.constructor);
var index = 0;
var length = list.length;
var result = new (aTypedArrayConstructor$5(C))(length);
while (length > index) {
result[index] = list[index++];
}
return result;
}, FORCED$i);
var es_typedArray_slice = {};
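// Illustrative usage (not part of the bundle): `slice` copies the selected elements into a
// fresh typed array created via the species constructor, leaving the source buffer untouched:
//   new Int32Array([1, 2, 3, 4]).slice(1, 3); // => Int32Array [2, 3]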
'use strict';
var $some$1 = arrayIteration.some;
var aTypedArray$j = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$j = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.some` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.some
exportTypedArrayMethod$j('some', function some(callbackfn
/* , thisArg */
) {
return $some$1(aTypedArray$j(this), callbackfn, arguments.length > 1 ? arguments[1] : undefined);
});
var es_typedArray_some = {};
'use strict';
var aTypedArray$k = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$k = arrayBufferViewCore.exportTypedArrayMethod;
var $sort = [].sort; // `%TypedArray%.prototype.sort` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.sort
exportTypedArrayMethod$k('sort', function sort(comparefn) {
return $sort.call(aTypedArray$k(this), comparefn);
});
var es_typedArray_sort = {};
'use strict';
var aTypedArray$l = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$l = arrayBufferViewCore.exportTypedArrayMethod; // `%TypedArray%.prototype.subarray` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.subarray
exportTypedArrayMethod$l('subarray', function subarray(begin, end) {
var O = aTypedArray$l(this);
var length = O.length;
var beginIndex = toAbsoluteIndex(begin, length);
return new (speciesConstructor(O, O.constructor))(O.buffer, O.byteOffset + beginIndex * O.BYTES_PER_ELEMENT, toLength((end === undefined ? length : toAbsoluteIndex(end, length)) - beginIndex));
});
var es_typedArray_subarray = {};
'use strict';
var Int8Array$3 = global_1.Int8Array;
var aTypedArray$m = arrayBufferViewCore.aTypedArray;
var exportTypedArrayMethod$m = arrayBufferViewCore.exportTypedArrayMethod;
var $toLocaleString = [].toLocaleString;
var $slice$1 = [].slice; // iOS Safari 6.x fails here
var TO_LOCALE_STRING_BUG = !!Int8Array$3 && fails(function () {
$toLocaleString.call(new Int8Array$3(1));
});
var FORCED$j = fails(function () {
return [1, 2].toLocaleString() != new Int8Array$3([1, 2]).toLocaleString();
}) || !fails(function () {
Int8Array$3.prototype.toLocaleString.call([1, 2]);
}); // `%TypedArray%.prototype.toLocaleString` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.tolocalestring
exportTypedArrayMethod$m('toLocaleString', function toLocaleString() {
return $toLocaleString.apply(TO_LOCALE_STRING_BUG ? $slice$1.call(aTypedArray$m(this)) : aTypedArray$m(this), arguments);
}, FORCED$j);
var es_typedArray_toLocaleString = {};
'use strict';
var exportTypedArrayMethod$n = arrayBufferViewCore.exportTypedArrayMethod;
var Uint8Array$2 = global_1.Uint8Array;
var Uint8ArrayPrototype = Uint8Array$2 && Uint8Array$2.prototype || {};
var arrayToString = [].toString;
var arrayJoin = [].join;
if (fails(function () {
arrayToString.call({});
})) {
arrayToString = function toString() {
return arrayJoin.call(this);
};
}
var IS_NOT_ARRAY_METHOD = Uint8ArrayPrototype.toString != arrayToString; // `%TypedArray%.prototype.toString` method
// https://tc39.es/ecma262/#sec-%typedarray%.prototype.tostring
exportTypedArrayMethod$n('toString', arrayToString, IS_NOT_ARRAY_METHOD);
var es_typedArray_toString = {};
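// Illustrative usage (not part of the bundle): with the patch above, typed arrays
// stringify like plain arrays:
//   String(new Uint8Array([1, 2, 3])); // => '1,2,3'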
var nativeApply = getBuiltIn('Reflect', 'apply');
var functionApply = Function.apply; // in MS Edge the argumentsList argument is optional
var OPTIONAL_ARGUMENTS_LIST = !fails(function () {
nativeApply(function () {
/* empty */
});
}); // `Reflect.apply` method
// https://tc39.es/ecma262/#sec-reflect.apply
_export({
target: 'Reflect',
stat: true,
forced: OPTIONAL_ARGUMENTS_LIST
}, {
apply: function apply(target, thisArgument, argumentsList) {
aFunction$1(target);
anObject(argumentsList);
return nativeApply ? nativeApply(target, thisArgument, argumentsList) : functionApply.call(target, thisArgument, argumentsList);
}
});
var es_reflect_apply = {};
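// Illustrative usage (not part of the bundle) of the `Reflect.apply` polyfill above:
//   Reflect.apply(Math.max, null, [1, 5, 3]); // => 5
//   Reflect.apply(Math.max, null);            // throws TypeError — argumentsList is required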
var nativeConstruct = getBuiltIn('Reflect', 'construct'); // `Reflect.construct` method
// https://tc39.es/ecma262/#sec-reflect.construct
// MS Edge supports only 2 arguments, and the argumentsList argument is optional
// FF Nightly sets third argument as `new.target`, but does not create `this` from it
var NEW_TARGET_BUG = fails(function () {
function F() {
/* empty */
}
return !(nativeConstruct(function () {
/* empty */
}, [], F) instanceof F);
});
var ARGS_BUG = !fails(function () {
nativeConstruct(function () {
/* empty */
});
});
var FORCED$k = NEW_TARGET_BUG || ARGS_BUG;
_export({
target: 'Reflect',
stat: true,
forced: FORCED$k,
sham: FORCED$k
}, {
construct: function construct(Target, args
/* , newTarget */
) {
aFunction$1(Target);
anObject(args);
var newTarget = arguments.length < 3 ? Target : aFunction$1(arguments[2]);
if (ARGS_BUG && !NEW_TARGET_BUG) return nativeConstruct(Target, args, newTarget);
if (Target == newTarget) {
// w/o altered newTarget, optimization for 0-4 arguments
switch (args.length) {
case 0:
return new Target();
case 1:
return new Target(args[0]);
case 2:
return new Target(args[0], args[1]);
case 3:
return new Target(args[0], args[1], args[2]);
case 4:
return new Target(args[0], args[1], args[2], args[3]);
        } // w/o altered newTarget, many-arguments case
var $args = [null];
$args.push.apply($args, args);
return new (functionBind.apply(Target, $args))();
    } // with altered newTarget, built-in constructors are not supported
var proto = newTarget.prototype;
var instance = objectCreate(isObject(proto) ? proto : Object.prototype);
var result = Function.apply.call(Target, instance, args);
return isObject(result) ? result : instance;
}
});
var es_reflect_construct = {};
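// Illustrative usage (not part of the bundle) of the `Reflect.construct` polyfill above:
//   Reflect.construct(Date, [2020, 0, 1]) instanceof Date; // => true
//   // the optional third argument supplies `new.target`, e.g. for subclass prototype wiring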
var ERROR_INSTEAD_OF_FALSE = fails(function () {
/* global Reflect -- required for testing */
Reflect.defineProperty(objectDefineProperty.f({}, 1, {
value: 1
}), 1, {
value: 2
});
}); // `Reflect.defineProperty` method
// https://tc39.es/ecma262/#sec-reflect.defineproperty
_export({
target: 'Reflect',
stat: true,
forced: ERROR_INSTEAD_OF_FALSE,
sham: !descriptors
}, {
defineProperty: function defineProperty(target, propertyKey, attributes) {
anObject(target);
var key = toPrimitive(propertyKey, true);
anObject(attributes);
try {
objectDefineProperty.f(target, key, attributes);
return true;
} catch (error) {
return false;
}
}
});
var es_reflect_defineProperty = {};
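// Illustrative usage (not part of the bundle): unlike Object.defineProperty, the method
// above reports failure with a boolean instead of throwing:
//   Reflect.defineProperty(Object.freeze({}), 'x', { value: 1 }); // => false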
var getOwnPropertyDescriptor$8 = objectGetOwnPropertyDescriptor.f; // `Reflect.deleteProperty` method
// https://tc39.es/ecma262/#sec-reflect.deleteproperty
_export({
target: 'Reflect',
stat: true
}, {
deleteProperty: function deleteProperty(target, propertyKey) {
var descriptor = getOwnPropertyDescriptor$8(anObject(target), propertyKey);
return descriptor && !descriptor.configurable ? false : delete target[propertyKey];
}
});
var es_reflect_deleteProperty = {};
// https://tc39.es/ecma262/#sec-reflect.get
function get$2(target, propertyKey
/* , receiver */
) {
var receiver = arguments.length < 3 ? target : arguments[2];
var descriptor, prototype;
if (anObject(target) === receiver) return target[propertyKey];
if (descriptor = objectGetOwnPropertyDescriptor.f(target, propertyKey)) return has(descriptor, 'value') ? descriptor.value : descriptor.get === undefined ? undefined : descriptor.get.call(receiver);
if (isObject(prototype = objectGetPrototypeOf(target))) return get$2(prototype, propertyKey, receiver);
}
_export({
target: 'Reflect',
stat: true
}, {
get: get$2
});
var es_reflect_get = {};
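// Illustrative usage (not part of the bundle): the third argument is the receiver used as
// `this` when a getter is found along the prototype chain:
//   var source = { get x() { return this.y; } };
//   Reflect.get(source, 'x', { y: 42 }); // => 42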
// https://tc39.es/ecma262/#sec-reflect.getownpropertydescriptor
_export({
target: 'Reflect',
stat: true,
sham: !descriptors
}, {
getOwnPropertyDescriptor: function getOwnPropertyDescriptor(target, propertyKey) {
return objectGetOwnPropertyDescriptor.f(anObject(target), propertyKey);
}
});
var es_reflect_getOwnPropertyDescriptor = {};
// https://tc39.es/ecma262/#sec-reflect.getprototypeof
_export({
target: 'Reflect',
stat: true,
sham: !correctPrototypeGetter
}, {
getPrototypeOf: function getPrototypeOf(target) {
return objectGetPrototypeOf(anObject(target));
}
});
var es_reflect_getPrototypeOf = {};
// https://tc39.es/ecma262/#sec-reflect.has
_export({
target: 'Reflect',
stat: true
}, {
has: function has(target, propertyKey) {
return propertyKey in target;
}
});
var es_reflect_has = {};
var objectIsExtensible = Object.isExtensible; // `Reflect.isExtensible` method
// https://tc39.es/ecma262/#sec-reflect.isextensible
_export({
target: 'Reflect',
stat: true
}, {
isExtensible: function isExtensible(target) {
anObject(target);
return objectIsExtensible ? objectIsExtensible(target) : true;
}
});
var es_reflect_isExtensible = {};
// https://tc39.es/ecma262/#sec-reflect.ownkeys
_export({
target: 'Reflect',
stat: true
}, {
ownKeys: ownKeys
});
var es_reflect_ownKeys = {};
// https://tc39.es/ecma262/#sec-reflect.preventextensions
_export({
target: 'Reflect',
stat: true,
sham: !freezing
}, {
preventExtensions: function preventExtensions(target) {
anObject(target);
try {
var objectPreventExtensions = getBuiltIn('Object', 'preventExtensions');
if (objectPreventExtensions) objectPreventExtensions(target);
return true;
} catch (error) {
return false;
}
}
});
var es_reflect_preventExtensions = {};
// https://tc39.es/ecma262/#sec-reflect.set
function set$3(target, propertyKey, V
/* , receiver */
) {
var receiver = arguments.length < 4 ? target : arguments[3];
var ownDescriptor = objectGetOwnPropertyDescriptor.f(anObject(target), propertyKey);
var existingDescriptor, prototype;
if (!ownDescriptor) {
if (isObject(prototype = objectGetPrototypeOf(target))) {
return set$3(prototype, propertyKey, V, receiver);
}
ownDescriptor = createPropertyDescriptor(0);
}
if (has(ownDescriptor, 'value')) {
if (ownDescriptor.writable === false || !isObject(receiver)) return false;
if (existingDescriptor = objectGetOwnPropertyDescriptor.f(receiver, propertyKey)) {
if (existingDescriptor.get || existingDescriptor.set || existingDescriptor.writable === false) return false;
existingDescriptor.value = V;
objectDefineProperty.f(receiver, propertyKey, existingDescriptor);
} else objectDefineProperty.f(receiver, propertyKey, createPropertyDescriptor(0, V));
return true;
}
return ownDescriptor.set === undefined ? false : (ownDescriptor.set.call(receiver, V), true);
} // MS Edge 17-18: Reflect.set allows setting a property on an object
// with a non-writable property on the prototype
var MS_EDGE_BUG = fails(function () {
var Constructor = function Constructor() {
/* empty */
};
var object = objectDefineProperty.f(new Constructor(), 'a', {
configurable: true
});
/* global Reflect -- required for testing */
return Reflect.set(Constructor.prototype, 'a', 1, object) !== false;
});
_export({
target: 'Reflect',
stat: true,
forced: MS_EDGE_BUG
}, {
set: set$3
});
var es_reflect_set = {};
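// Illustrative usage (not part of the bundle) of the patched `Reflect.set` above:
//   var proto = Object.defineProperty({}, 'a', { value: 1, writable: false });
//   Reflect.set(proto, 'a', 2, Object.create(proto)); // => false (non-writable wins)
//   Reflect.set({}, 'b', 3);                          // => true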
// https://tc39.es/ecma262/#sec-reflect.setprototypeof
if (objectSetPrototypeOf) _export({
target: 'Reflect',
stat: true
}, {
setPrototypeOf: function setPrototypeOf(target, proto) {
anObject(target);
aPossiblePrototype(proto);
try {
objectSetPrototypeOf(target, proto);
return true;
} catch (error) {
return false;
}
}
});
var es_reflect_setPrototypeOf = {};
_export({
global: true
}, {
Reflect: {}
}); // Reflect[@@toStringTag] property
// https://tc39.es/ecma262/#sec-reflect-@@tostringtag
setToStringTag(global_1.Reflect, 'Reflect', true);
var es_reflect_toStringTag = {};
var es = path;
// iterable DOM collections
// flag - `iterable` interface - 'entries', 'keys', 'values', 'forEach' methods
var domIterables = {
CSSRuleList: 0,
CSSStyleDeclaration: 0,
CSSValueList: 0,
ClientRectList: 0,
DOMRectList: 0,
DOMStringList: 0,
DOMTokenList: 1,
DataTransferItemList: 0,
FileList: 0,
HTMLAllCollection: 0,
HTMLCollection: 0,
HTMLFormElement: 0,
HTMLSelectElement: 0,
MediaList: 0,
MimeTypeArray: 0,
NamedNodeMap: 0,
NodeList: 1,
PaintRequestList: 0,
Plugin: 0,
PluginArray: 0,
SVGLengthList: 0,
SVGNumberList: 0,
SVGPathSegList: 0,
SVGPointList: 0,
SVGStringList: 0,
SVGTransformList: 0,
SourceBufferList: 0,
StyleSheetList: 0,
TextTrackCueList: 0,
TextTrackList: 0,
TouchList: 0
};
var domIterables_1 = domIterables.CSSRuleList;
var domIterables_2 = domIterables.CSSStyleDeclaration;
var domIterables_3 = domIterables.CSSValueList;
var domIterables_4 = domIterables.ClientRectList;
var domIterables_5 = domIterables.DOMRectList;
var domIterables_6 = domIterables.DOMStringList;
var domIterables_7 = domIterables.DOMTokenList;
var domIterables_8 = domIterables.DataTransferItemList;
var domIterables_9 = domIterables.FileList;
var domIterables_10 = domIterables.HTMLAllCollection;
var domIterables_11 = domIterables.HTMLCollection;
var domIterables_12 = domIterables.HTMLFormElement;
var domIterables_13 = domIterables.HTMLSelectElement;
var domIterables_14 = domIterables.MediaList;
var domIterables_15 = domIterables.MimeTypeArray;
var domIterables_16 = domIterables.NamedNodeMap;
var domIterables_17 = domIterables.NodeList;
var domIterables_18 = domIterables.PaintRequestList;
var domIterables_19 = domIterables.Plugin;
var domIterables_20 = domIterables.PluginArray;
var domIterables_21 = domIterables.SVGLengthList;
var domIterables_22 = domIterables.SVGNumberList;
var domIterables_23 = domIterables.SVGPathSegList;
var domIterables_24 = domIterables.SVGPointList;
var domIterables_25 = domIterables.SVGStringList;
var domIterables_26 = domIterables.SVGTransformList;
var domIterables_27 = domIterables.SourceBufferList;
var domIterables_28 = domIterables.StyleSheetList;
var domIterables_29 = domIterables.TextTrackCueList;
var domIterables_30 = domIterables.TextTrackList;
var domIterables_31 = domIterables.TouchList;
for (var COLLECTION_NAME in domIterables) {
var Collection = global_1[COLLECTION_NAME];
var CollectionPrototype = Collection && Collection.prototype; // some Chrome versions have non-configurable methods on DOMTokenList
if (CollectionPrototype && CollectionPrototype.forEach !== arrayForEach) try {
createNonEnumerableProperty(CollectionPrototype, 'forEach', arrayForEach);
} catch (error) {
CollectionPrototype.forEach = arrayForEach;
}
}
var web_domCollections_forEach = {};
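// Illustrative usage (not part of the bundle): after the patch above, DOM collections such
// as NodeList expose Array.prototype.forEach semantics in browsers that lack it natively:
//   document.querySelectorAll('div').forEach(function (el) { /* ... */ });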
var ITERATOR$6 = wellKnownSymbol('iterator');
var TO_STRING_TAG$4 = wellKnownSymbol('toStringTag');
var ArrayValues = es_array_iterator.values;
for (var COLLECTION_NAME$1 in domIterables) {
var Collection$1 = global_1[COLLECTION_NAME$1];
var CollectionPrototype$1 = Collection$1 && Collection$1.prototype;
if (CollectionPrototype$1) {
// some Chrome versions have non-configurable methods on DOMTokenList
if (CollectionPrototype$1[ITERATOR$6] !== ArrayValues) try {
createNonEnumerableProperty(CollectionPrototype$1, ITERATOR$6, ArrayValues);
} catch (error) {
CollectionPrototype$1[ITERATOR$6] = ArrayValues;
}
if (!CollectionPrototype$1[TO_STRING_TAG$4]) {
createNonEnumerableProperty(CollectionPrototype$1, TO_STRING_TAG$4, COLLECTION_NAME$1);
}
if (domIterables[COLLECTION_NAME$1]) for (var METHOD_NAME in es_array_iterator) {
// some Chrome versions have non-configurable methods on DOMTokenList
if (CollectionPrototype$1[METHOD_NAME] !== es_array_iterator[METHOD_NAME]) try {
createNonEnumerableProperty(CollectionPrototype$1, METHOD_NAME, es_array_iterator[METHOD_NAME]);
} catch (error) {
CollectionPrototype$1[METHOD_NAME] = es_array_iterator[METHOD_NAME];
}
}
}
}
var web_domCollections_iterator = {};
var FORCED$l = !global_1.setImmediate || !global_1.clearImmediate; // http://w3c.github.io/setImmediate/
_export({
global: true,
bind: true,
enumerable: true,
forced: FORCED$l
}, {
// `setImmediate` method
// http://w3c.github.io/setImmediate/#si-setImmediate
setImmediate: task.set,
// `clearImmediate` method
// http://w3c.github.io/setImmediate/#si-clearImmediate
clearImmediate: task.clear
});
var web_immediate = {};
var process$5 = global_1.process; // `queueMicrotask` method
// https://html.spec.whatwg.org/multipage/timers-and-user-prompts.html#dom-queuemicrotask
_export({
global: true,
enumerable: true,
noTargetGet: true
}, {
queueMicrotask: function queueMicrotask(fn) {
var domain = engineIsNode && process$5.domain;
microtask(domain ? domain.bind(fn) : fn);
}
});
var web_queueMicrotask = {};
var slice$1 = [].slice;
var MSIE = /MSIE .\./.test(engineUserAgent); // <- dirty ie9- check
var wrap$1 = function wrap(scheduler) {
return function (handler, timeout
/* , ...arguments */
) {
var boundArgs = arguments.length > 2;
var args = boundArgs ? slice$1.call(arguments, 2) : undefined;
return scheduler(boundArgs ? function () {
// eslint-disable-next-line no-new-func -- spec requirement
(typeof handler == 'function' ? handler : Function(handler)).apply(this, args);
} : handler, timeout);
};
}; // ie9- setTimeout & setInterval additional parameters fix
// https://html.spec.whatwg.org/multipage/timers-and-user-prompts.html#timers
_export({
global: true,
bind: true,
forced: MSIE
}, {
// `setTimeout` method
// https://html.spec.whatwg.org/multipage/timers-and-user-prompts.html#dom-settimeout
setTimeout: wrap$1(global_1.setTimeout),
// `setInterval` method
// https://html.spec.whatwg.org/multipage/timers-and-user-prompts.html#dom-setinterval
setInterval: wrap$1(global_1.setInterval)
});
var web_timers = {};
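// Illustrative usage (not part of the bundle): the wrapper above forwards extra arguments
// to the handler, which IE9- dropped:
//   setTimeout(function (a, b) { console.log(a + b); }, 0, 1, 2); // logs 3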
var ITERATOR$7 = wellKnownSymbol('iterator');
var nativeUrl = !fails(function () {
var url = new URL('b?a=1&b=2&c=3', 'http://a');
var searchParams = url.searchParams;
var result = '';
url.pathname = 'c%20d';
searchParams.forEach(function (value, key) {
searchParams['delete']('b');
result += key + value;
});
return isPure && !url.toJSON || !searchParams.sort || url.href !== 'http://a/c%20d?a=1&c=3' || searchParams.get('c') !== '3' || String(new URLSearchParams('?a=1')) !== 'a=1' || !searchParams[ITERATOR$7] // throws in Edge
|| new URL('https://a@b').username !== 'a' || new URLSearchParams(new URLSearchParams('a=b')).get('a') !== 'b' // not punycoded in Edge
|| new URL('http://тест').host !== 'xn--e1aybc' // not escaped in Chrome 62-
|| new URL('http://a#б').hash !== '#%D0%B1' // fails in Chrome 66-
|| result !== 'a1c3' // throws in Safari
|| new URL('http://x', undefined).host !== 'x';
});
'use strict'; // based on https://github.com/bestiejs/punycode.js/blob/master/punycode.js
var maxInt = 2147483647; // aka. 0x7FFFFFFF or 2^31-1
var base = 36;
var tMin = 1;
var tMax = 26;
var skew = 38;
var damp = 700;
var initialBias = 72;
var initialN = 128; // 0x80
var delimiter = '-'; // '\x2D'
var regexNonASCII = /[^\0-\u007E]/; // non-ASCII chars
var regexSeparators = /[.\u3002\uFF0E\uFF61]/g; // RFC 3490 separators
var OVERFLOW_ERROR = 'Overflow: input needs wider integers to process';
var baseMinusTMin = base - tMin;
var floor$8 = Math.floor;
var stringFromCharCode = String.fromCharCode;
/**
* Creates an array containing the numeric code points of each Unicode
* character in the string. While JavaScript uses UCS-2 internally,
* this function will convert a pair of surrogate halves (each of which
* UCS-2 exposes as separate characters) into a single code point,
* matching UTF-16.
*/
var ucs2decode = function ucs2decode(string) {
var output = [];
var counter = 0;
var length = string.length;
while (counter < length) {
var value = string.charCodeAt(counter++);
if (value >= 0xD800 && value <= 0xDBFF && counter < length) {
// It's a high surrogate, and there is a next character.
var extra = string.charCodeAt(counter++);
if ((extra & 0xFC00) == 0xDC00) {
// Low surrogate.
output.push(((value & 0x3FF) << 10) + (extra & 0x3FF) + 0x10000);
} else {
// It's an unmatched surrogate; only append this code unit, in case the
// next code unit is the high surrogate of a surrogate pair.
output.push(value);
counter--;
}
} else {
output.push(value);
}
}
return output;
};
/**
* Converts a digit/integer into a basic code point.
*/
var digitToBasic = function digitToBasic(digit) {
// 0..25 map to ASCII a..z or A..Z
// 26..35 map to ASCII 0..9
return digit + 22 + 75 * (digit < 26);
};
/**
* Bias adaptation function as per section 3.4 of RFC 3492.
* https://tools.ietf.org/html/rfc3492#section-3.4
*/
var adapt = function adapt(delta, numPoints, firstTime) {
var k = 0;
delta = firstTime ? floor$8(delta / damp) : delta >> 1;
delta += floor$8(delta / numPoints);
for (; delta > baseMinusTMin * tMax >> 1; k += base) {
delta = floor$8(delta / baseMinusTMin);
}
return floor$8(k + (baseMinusTMin + 1) * delta / (delta + skew));
};
/**
* Converts a string of Unicode symbols (e.g. a domain name label) to a
* Punycode string of ASCII-only symbols.
*/
// eslint-disable-next-line max-statements -- TODO
var encode = function encode(input) {
var output = []; // Convert the input in UCS-2 to an array of Unicode code points.
input = ucs2decode(input); // Cache the length.
var inputLength = input.length; // Initialize the state.
var n = initialN;
var delta = 0;
var bias = initialBias;
var i, currentValue; // Handle the basic code points.
for (i = 0; i < input.length; i++) {
currentValue = input[i];
if (currentValue < 0x80) {
output.push(stringFromCharCode(currentValue));
}
}
var basicLength = output.length; // number of basic code points.
var handledCPCount = basicLength; // number of code points that have been handled;
// Finish the basic string with a delimiter unless it's empty.
if (basicLength) {
output.push(delimiter);
} // Main encoding loop:
while (handledCPCount < inputLength) {
// All non-basic code points < n have been handled already. Find the next larger one:
var m = maxInt;
for (i = 0; i < input.length; i++) {
currentValue = input[i];
if (currentValue >= n && currentValue < m) {
m = currentValue;
}
} // Increase `delta` enough to advance the decoder's <n,i> state to <m,0>, but guard against overflow.
var handledCPCountPlusOne = handledCPCount + 1;
if (m - n > floor$8((maxInt - delta) / handledCPCountPlusOne)) {
throw RangeError(OVERFLOW_ERROR);
}
delta += (m - n) * handledCPCountPlusOne;
n = m;
for (i = 0; i < input.length; i++) {
currentValue = input[i];
if (currentValue < n && ++delta > maxInt) {
throw RangeError(OVERFLOW_ERROR);
}
if (currentValue == n) {
// Represent delta as a generalized variable-length integer.
var q = delta;
for (var k = base;;
/* no condition */
k += base) {
var t = k <= bias ? tMin : k >= bias + tMax ? tMax : k - bias;
if (q < t) break;
var qMinusT = q - t;
var baseMinusT = base - t;
output.push(stringFromCharCode(digitToBasic(t + qMinusT % baseMinusT)));
q = floor$8(qMinusT / baseMinusT);
}
output.push(stringFromCharCode(digitToBasic(q)));
bias = adapt(delta, handledCPCountPlusOne, handledCPCount == basicLength);
delta = 0;
++handledCPCount;
}
}
++delta;
++n;
}
return output.join('');
};
var stringPunycodeToAscii = function stringPunycodeToAscii(input) {
var encoded = [];
var labels = input.toLowerCase().replace(regexSeparators, ".").split('.');
var i, label;
for (i = 0; i < labels.length; i++) {
label = labels[i];
encoded.push(regexNonASCII.test(label) ? 'xn--' + encode(label) : label);
}
return encoded.join('.');
};
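// Illustrative usage (not part of the bundle) of the Punycode helper above; the expected
// output matches the host check in the `nativeUrl` feature test earlier in this section:
//   stringPunycodeToAscii('тест');        // => 'xn--e1aybc'
//   stringPunycodeToAscii('example.com'); // => 'example.com' (ASCII labels pass through)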
var getIterator = function getIterator(it) {
var iteratorMethod = getIteratorMethod(it);
if (typeof iteratorMethod != 'function') {
throw TypeError(String(it) + ' is not iterable');
}
return anObject(iteratorMethod.call(it));
};
'use strict'; // TODO: in core-js@4, move /modules/ dependencies to public entries for better optimization by tools like `preset-env`
var $fetch$1 = getBuiltIn('fetch');
var Headers = getBuiltIn('Headers');
var ITERATOR$8 = wellKnownSymbol('iterator');
var URL_SEARCH_PARAMS = 'URLSearchParams';
var URL_SEARCH_PARAMS_ITERATOR = URL_SEARCH_PARAMS + 'Iterator';
var setInternalState$9 = internalState.set;
var getInternalParamsState = internalState.getterFor(URL_SEARCH_PARAMS);
var getInternalIteratorState = internalState.getterFor(URL_SEARCH_PARAMS_ITERATOR);
var plus = /\+/g;
var sequences = Array(4);
var percentSequence = function percentSequence(bytes) {
return sequences[bytes - 1] || (sequences[bytes - 1] = RegExp('((?:%[\\da-f]{2}){' + bytes + '})', 'gi'));
};
var percentDecode = function percentDecode(sequence) {
try {
return decodeURIComponent(sequence);
} catch (error) {
return sequence;
}
};
var deserialize = function deserialize(it) {
var result = it.replace(plus, ' ');
var bytes = 4;
try {
return decodeURIComponent(result);
} catch (error) {
while (bytes) {
result = result.replace(percentSequence(bytes--), percentDecode);
}
return result;
}
};
var find$1 = /[!'()~]|%20/g;
var replace$1 = {
'!': '%21',
"'": '%27',
'(': '%28',
')': '%29',
'~': '%7E',
'%20': '+'
};
var replacer = function replacer(match) {
return replace$1[match];
};
var serialize = function serialize(it) {
return encodeURIComponent(it).replace(find$1, replacer);
};
var parseSearchParams = function parseSearchParams(result, query) {
if (query) {
var attributes = query.split('&');
var index = 0;
var attribute, entry;
while (index < attributes.length) {
attribute = attributes[index++];
if (attribute.length) {
entry = attribute.split('=');
result.push({
key: deserialize(entry.shift()),
value: deserialize(entry.join('='))
});
}
}
}
};
var updateSearchParams = function updateSearchParams(query) {
this.entries.length = 0;
parseSearchParams(this.entries, query);
};
var validateArgumentsLength = function validateArgumentsLength(passed, required) {
if (passed < required) throw TypeError('Not enough arguments');
};
var URLSearchParamsIterator = createIteratorConstructor(function Iterator(params, kind) {
setInternalState$9(this, {
type: URL_SEARCH_PARAMS_ITERATOR,
iterator: getIterator(getInternalParamsState(params).entries),
kind: kind
});
}, 'Iterator', function next() {
var state = getInternalIteratorState(this);
var kind = state.kind;
var step = state.iterator.next();
var entry = step.value;
if (!step.done) {
step.value = kind === 'keys' ? entry.key : kind === 'values' ? entry.value : [entry.key, entry.value];
}
return step;
}); // `URLSearchParams` constructor
// https://url.spec.whatwg.org/#interface-urlsearchparams
var URLSearchParamsConstructor = function URLSearchParams()
/* init */
{
anInstance(this, URLSearchParamsConstructor, URL_SEARCH_PARAMS);
var init = arguments.length > 0 ? arguments[0] : undefined;
var that = this;
var entries = [];
var iteratorMethod, iterator, next, step, entryIterator, entryNext, first, second, key;
setInternalState$9(that, {
type: URL_SEARCH_PARAMS,
entries: entries,
updateURL: function updateURL() {
/* empty */
},
updateSearchParams: updateSearchParams
});
if (init !== undefined) {
if (isObject(init)) {
iteratorMethod = getIteratorMethod(init);
if (typeof iteratorMethod === 'function') {
iterator = iteratorMethod.call(init);
next = iterator.next;
while (!(step = next.call(iterator)).done) {
entryIterator = getIterator(anObject(step.value));
entryNext = entryIterator.next;
if ((first = entryNext.call(entryIterator)).done || (second = entryNext.call(entryIterator)).done || !entryNext.call(entryIterator).done) throw TypeError('Expected sequence with length 2');
entries.push({
key: first.value + '',
value: second.value + ''
});
}
} else for (key in init) {
if (has(init, key)) entries.push({
key: key,
value: init[key] + ''
});
}
} else {
parseSearchParams(entries, typeof init === 'string' ? init.charAt(0) === '?' ? init.slice(1) : init : init + '');
}
}
};
var URLSearchParamsPrototype = URLSearchParamsConstructor.prototype;
redefineAll(URLSearchParamsPrototype, {
// `URLSearchParams.prototype.append` method
// https://url.spec.whatwg.org/#dom-urlsearchparams-append
append: function append(name, value) {
validateArgumentsLength(arguments.length, 2);
var state = getInternalParamsState(this);
state.entries.push({
key: name + '',
value: value + ''
});
state.updateURL();
},
// `URLSearchParams.prototype.delete` method
// https://url.spec.whatwg.org/#dom-urlsearchparams-delete
'delete': function _delete(name) {
validateArgumentsLength(arguments.length, 1);
var state = getInternalParamsState(this);
var entries = state.entries;
var key = name + '';
var index = 0;
while (index < entries.length) {
if (entries[index].key === key) entries.splice(index, 1);else index++;
}
state.updateURL();
},
// `URLSearchParams.prototype.get` method
// https://url.spec.whatwg.org/#dom-urlsearchparams-get
get: function get(name) {
validateArgumentsLength(arguments.length, 1);
var entries = getInternalParamsState(this).entries;
var key = name + '';
var index = 0;
for (; index < entries.length; index++) {
if (entries[index].key === key) return entries[index].value;
}
return null;
},
// `URLSearchParams.prototype.getAll` method
// https://url.spec.whatwg.org/#dom-urlsearchparams-getall
getAll: function getAll(name) {
validateArgumentsLength(arguments.length, 1);
var entries = getInternalParamsState(this).entries;
var key = name + '';
var result = [];
var index = 0;
for (; index < entries.length; index++) {
if (entries[index].key === key) result.push(entries[index].value);
}
return result;
},
// `URLSearchParams.prototype.has` method
// https://url.spec.whatwg.org/#dom-urlsearchparams-has
has: function has(name) {
validateArgumentsLength(arguments.length, 1);
var entries = getInternalParamsState(this).entries;
var key = name + '';
var index = 0;
while (index < entries.length) {
if (entries[index++].key === key) return true;
}
return false;
},
// `URLSearchParams.prototype.set` method
// https://url.spec.whatwg.org/#dom-urlsearchparams-set
set: function set(name, value) {
validateArgumentsLength(arguments.length, 1);
var state = getInternalParamsState(this);
var entries = state.entries;
var found = false;
var key = name + '';
var val = value + '';
var index = 0;
var entry;
for (; index < entries.length; index++) {
entry = entries[index];
if (entry.key === key) {
if (found) entries.splice(index--, 1);else {
found = true;
entry.value = val;
}
}
}
if (!found) entries.push({
key: key,
value: val
});
state.updateURL();
},
// `URLSearchParams.prototype.sort` method
// https://url.spec.whatwg.org/#dom-urlsearchparams-sort
sort: function sort() {
var state = getInternalParamsState(this);
var entries = state.entries; // Array#sort is not stable in some engines
var slice = entries.slice();
var entry, entriesIndex, sliceIndex;
entries.length = 0;
for (sliceIndex = 0; sliceIndex < slice.length; sliceIndex++) {
entry = slice[sliceIndex];
for (entriesIndex = 0; entriesIndex < sliceIndex; entriesIndex++) {
if (entries[entriesIndex].key > entry.key) {
entries.splice(entriesIndex, 0, entry);
break;
}
}
if (entriesIndex === sliceIndex) entries.push(entry);
}
state.updateURL();
},
// `URLSearchParams.prototype.forEach` method
forEach: function forEach(callback
/* , thisArg */
) {
var entries = getInternalParamsState(this).entries;
var boundFunction = functionBindContext(callback, arguments.length > 1 ? arguments[1] : undefined, 3);
var index = 0;
var entry;
while (index < entries.length) {
entry = entries[index++];
boundFunction(entry.value, entry.key, this);
}
},
// `URLSearchParams.prototype.keys` method
keys: function keys() {
return new URLSearchParamsIterator(this, 'keys');
},
// `URLSearchParams.prototype.values` method
values: function values() {
return new URLSearchParamsIterator(this, 'values');
},
// `URLSearchParams.prototype.entries` method
entries: function entries() {
return new URLSearchParamsIterator(this, 'entries');
}
}, {
enumerable: true
}); // `URLSearchParams.prototype[@@iterator]` method
redefine(URLSearchParamsPrototype, ITERATOR$8, URLSearchParamsPrototype.entries); // `URLSearchParams.prototype.toString` method
// https://url.spec.whatwg.org/#urlsearchparams-stringification-behavior
redefine(URLSearchParamsPrototype, 'toString', function toString() {
var entries = getInternalParamsState(this).entries;
var result = [];
var index = 0;
var entry;
while (index < entries.length) {
entry = entries[index++];
result.push(serialize(entry.key) + '=' + serialize(entry.value));
}
return result.join('&');
}, {
enumerable: true
});
setToStringTag(URLSearchParamsConstructor, URL_SEARCH_PARAMS);
_export({
global: true,
forced: !nativeUrl
}, {
URLSearchParams: URLSearchParamsConstructor
}); // Wrap `fetch` for correct work with polyfilled `URLSearchParams`
// https://github.com/zloirock/core-js/issues/674
if (!nativeUrl && typeof $fetch$1 == 'function' && typeof Headers == 'function') {
_export({
global: true,
enumerable: true,
forced: true
}, {
fetch: function fetch(input
/* , init */
) {
var args = [input];
var init, body, headers;
if (arguments.length > 1) {
init = arguments[1];
if (isObject(init)) {
body = init.body;
if (classof(body) === URL_SEARCH_PARAMS) {
headers = init.headers ? new Headers(init.headers) : new Headers();
if (!headers.has('content-type')) {
headers.set('content-type', 'application/x-www-form-urlencoded;charset=UTF-8');
}
init = objectCreate(init, {
body: createPropertyDescriptor(0, String(body)),
headers: createPropertyDescriptor(0, headers)
});
}
}
args.push(init);
}
return $fetch$1.apply(this, args);
}
});
}
var web_urlSearchParams = {
URLSearchParams: URLSearchParamsConstructor,
getState: getInternalParamsState
};
var web_urlSearchParams_1 = web_urlSearchParams.URLSearchParams;
var web_urlSearchParams_2 = web_urlSearchParams.getState;
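// Illustrative usage (not part of the bundle) of the URLSearchParams polyfill above:
//   var params = new URLSearchParams('a=1&b=2');
//   params.append('a', '3');
//   params.getAll('a');  // => ['1', '3']
//   params.sort();
//   String(params);      // => 'a=1&a=3&b=2'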
'use strict'; // TODO: in core-js@4, move /modules/ dependencies to public entries for better optimization by tools like `preset-env`
var codeAt$1 = stringMultibyte.codeAt;
var NativeURL = global_1.URL;
var URLSearchParams$1 = web_urlSearchParams.URLSearchParams;
var getInternalSearchParamsState = web_urlSearchParams.getState;
var setInternalState$a = internalState.set;
var getInternalURLState = internalState.getterFor('URL');
var floor$9 = Math.floor;
var pow$4 = Math.pow;
var INVALID_AUTHORITY = 'Invalid authority';
var INVALID_SCHEME = 'Invalid scheme';
var INVALID_HOST = 'Invalid host';
var INVALID_PORT = 'Invalid port';
var ALPHA = /[A-Za-z]/;
var ALPHANUMERIC = /[\d+-.A-Za-z]/;
var DIGIT = /\d/;
var HEX_START = /^(0x|0X)/;
var OCT = /^[0-7]+$/;
var DEC = /^\d+$/;
var HEX = /^[\dA-Fa-f]+$/;
/* eslint-disable no-control-regex -- safe */
var FORBIDDEN_HOST_CODE_POINT = /[\u0000\t\u000A\u000D #%/:?@[\\]]/;
var FORBIDDEN_HOST_CODE_POINT_EXCLUDING_PERCENT = /[\u0000\t\u000A\u000D #/:?@[\\]]/;
var LEADING_AND_TRAILING_C0_CONTROL_OR_SPACE = /^[\u0000-\u001F ]+|[\u0000-\u001F ]+$/g;
var TAB_AND_NEW_LINE = /[\t\u000A\u000D]/g;
/* eslint-enable no-control-regex -- safe */
var EOF;
var parseHost = function parseHost(url, input) {
var result, codePoints, index;
if (input.charAt(0) == '[') {
if (input.charAt(input.length - 1) != ']') return INVALID_HOST;
result = parseIPv6(input.slice(1, -1));
if (!result) return INVALID_HOST;
url.host = result; // opaque host
} else if (!isSpecial(url)) {
if (FORBIDDEN_HOST_CODE_POINT_EXCLUDING_PERCENT.test(input)) return INVALID_HOST;
result = '';
codePoints = arrayFrom(input);
for (index = 0; index < codePoints.length; index++) {
result += percentEncode(codePoints[index], C0ControlPercentEncodeSet);
}
url.host = result;
} else {
input = stringPunycodeToAscii(input);
if (FORBIDDEN_HOST_CODE_POINT.test(input)) return INVALID_HOST;
result = parseIPv4(input);
if (result === null) return INVALID_HOST;
url.host = result;
}
};
var parseIPv4 = function parseIPv4(input) {
var parts = input.split('.');
var partsLength, numbers, index, part, radix, number, ipv4;
if (parts.length && parts[parts.length - 1] == '') {
parts.pop();
}
partsLength = parts.length;
if (partsLength > 4) return input;
numbers = [];
for (index = 0; index < partsLength; index++) {
part = parts[index];
if (part == '') return input;
radix = 10;
if (part.length > 1 && part.charAt(0) == '0') {
radix = HEX_START.test(part) ? 16 : 8;
part = part.slice(radix == 8 ? 1 : 2);
}
if (part === '') {
number = 0;
} else {
if (!(radix == 10 ? DEC : radix == 8 ? OCT : HEX).test(part)) return input;
number = parseInt(part, radix);
}
numbers.push(number);
}
for (index = 0; index < partsLength; index++) {
number = numbers[index];
if (index == partsLength - 1) {
if (number >= pow$4(256, 5 - partsLength)) return null;
} else if (number > 255) return null;
}
ipv4 = numbers.pop();
for (index = 0; index < numbers.length; index++) {
ipv4 += numbers[index] * pow$4(256, 3 - index);
}
return ipv4;
}; // eslint-disable-next-line max-statements -- TODO
var parseIPv6 = function parseIPv6(input) {
var address = [0, 0, 0, 0, 0, 0, 0, 0];
var pieceIndex = 0;
var compress = null;
var pointer = 0;
var value, length, numbersSeen, ipv4Piece, number, swaps, swap;
var char = function char() {
return input.charAt(pointer);
};
if (char() == ':') {
if (input.charAt(1) != ':') return;
pointer += 2;
pieceIndex++;
compress = pieceIndex;
}
while (char()) {
if (pieceIndex == 8) return;
if (char() == ':') {
if (compress !== null) return;
pointer++;
pieceIndex++;
compress = pieceIndex;
continue;
}
value = length = 0;
while (length < 4 && HEX.test(char())) {
value = value * 16 + parseInt(char(), 16);
pointer++;
length++;
}
if (char() == '.') {
if (length == 0) return;
pointer -= length;
if (pieceIndex > 6) return;
numbersSeen = 0;
while (char()) {
ipv4Piece = null;
if (numbersSeen > 0) {
if (char() == '.' && numbersSeen < 4) pointer++;else return;
}
if (!DIGIT.test(char())) return;
while (DIGIT.test(char())) {
number = parseInt(char(), 10);
if (ipv4Piece === null) ipv4Piece = number;else if (ipv4Piece == 0) return;else ipv4Piece = ipv4Piece * 10 + number;
if (ipv4Piece > 255) return;
pointer++;
}
address[pieceIndex] = address[pieceIndex] * 256 + ipv4Piece;
numbersSeen++;
if (numbersSeen == 2 || numbersSeen == 4) pieceIndex++;
}
if (numbersSeen != 4) return;
break;
} else if (char() == ':') {
pointer++;
if (!char()) return;
} else if (char()) return;
address[pieceIndex++] = value;
}
if (compress !== null) {
swaps = pieceIndex - compress;
pieceIndex = 7;
while (pieceIndex != 0 && swaps > 0) {
swap = address[pieceIndex];
address[pieceIndex--] = address[compress + swaps - 1];
address[compress + --swaps] = swap;
}
} else if (pieceIndex != 8) return;
return address;
};
var findLongestZeroSequence = function findLongestZeroSequence(ipv6) {
var maxIndex = null;
var maxLength = 1;
var currStart = null;
var currLength = 0;
var index = 0;
for (; index < 8; index++) {
if (ipv6[index] !== 0) {
if (currLength > maxLength) {
maxIndex = currStart;
maxLength = currLength;
}
currStart = null;
currLength = 0;
} else {
if (currStart === null) currStart = index;
++currLength;
}
}
if (currLength > maxLength) {
maxIndex = currStart;
maxLength = currLength;
}
return maxIndex;
};
var serializeHost = function serializeHost(host) {
var result, index, compress, ignore0; // ipv4
if (typeof host == 'number') {
result = [];
for (index = 0; index < 4; index++) {
result.unshift(host % 256);
host = floor$9(host / 256);
}
return result.join('.'); // ipv6
} else if (typeof host == 'object') {
result = '';
compress = findLongestZeroSequence(host);
for (index = 0; index < 8; index++) {
if (ignore0 && host[index] === 0) continue;
if (ignore0) ignore0 = false;
if (compress === index) {
result += index ? ':' : '::';
ignore0 = true;
} else {
result += host[index].toString(16);
if (index < 7) result += ':';
}
}
return '[' + result + ']';
}
return host;
};
var C0ControlPercentEncodeSet = {};
var fragmentPercentEncodeSet = objectAssign({}, C0ControlPercentEncodeSet, {
' ': 1,
'"': 1,
'<': 1,
'>': 1,
'`': 1
});
var pathPercentEncodeSet = objectAssign({}, fragmentPercentEncodeSet, {
'#': 1,
'?': 1,
'{': 1,
'}': 1
});
var userinfoPercentEncodeSet = objectAssign({}, pathPercentEncodeSet, {
'/': 1,
':': 1,
';': 1,
'=': 1,
'@': 1,
'[': 1,
'\\': 1,
']': 1,
'^': 1,
'|': 1
});
var percentEncode = function percentEncode(char, set) {
var code = codeAt$1(char, 0);
return code > 0x20 && code < 0x7F && !has(set, char) ? char : encodeURIComponent(char);
};
var specialSchemes = {
ftp: 21,
file: null,
http: 80,
https: 443,
ws: 80,
wss: 443
};
var isSpecial = function isSpecial(url) {
return has(specialSchemes, url.scheme);
};
var includesCredentials = function includesCredentials(url) {
return url.username != '' || url.password != '';
};
var cannotHaveUsernamePasswordPort = function cannotHaveUsernamePasswordPort(url) {
return !url.host || url.cannotBeABaseURL || url.scheme == 'file';
};
var isWindowsDriveLetter = function isWindowsDriveLetter(string, normalized) {
var second;
return string.length == 2 && ALPHA.test(string.charAt(0)) && ((second = string.charAt(1)) == ':' || !normalized && second == '|');
};
var startsWithWindowsDriveLetter = function startsWithWindowsDriveLetter(string) {
var third;
return string.length > 1 && isWindowsDriveLetter(string.slice(0, 2)) && (string.length == 2 || (third = string.charAt(2)) === '/' || third === '\\' || third === '?' || third === '#');
};
var shortenURLsPath = function shortenURLsPath(url) {
var path = url.path;
var pathSize = path.length;
if (pathSize && (url.scheme != 'file' || pathSize != 1 || !isWindowsDriveLetter(path[0], true))) {
path.pop();
}
};
var isSingleDot = function isSingleDot(segment) {
return segment === '.' || segment.toLowerCase() === '%2e';
};
var isDoubleDot = function isDoubleDot(segment) {
segment = segment.toLowerCase();
return segment === '..' || segment === '%2e.' || segment === '.%2e' || segment === '%2e%2e';
}; // States:
var SCHEME_START = {};
var SCHEME = {};
var NO_SCHEME = {};
var SPECIAL_RELATIVE_OR_AUTHORITY = {};
var PATH_OR_AUTHORITY = {};
var RELATIVE = {};
var RELATIVE_SLASH = {};
var SPECIAL_AUTHORITY_SLASHES = {};
var SPECIAL_AUTHORITY_IGNORE_SLASHES = {};
var AUTHORITY = {};
var HOST = {};
var HOSTNAME = {};
var PORT = {};
var FILE = {};
var FILE_SLASH = {};
var FILE_HOST = {};
var PATH_START = {};
var PATH = {};
var CANNOT_BE_A_BASE_URL_PATH = {};
var QUERY = {};
var FRAGMENT = {}; // eslint-disable-next-line max-statements -- TODO
var parseURL = function parseURL(url, input, stateOverride, base) {
var state = stateOverride || SCHEME_START;
var pointer = 0;
var buffer = '';
var seenAt = false;
var seenBracket = false;
var seenPasswordToken = false;
var codePoints, char, bufferCodePoints, failure;
if (!stateOverride) {
url.scheme = '';
url.username = '';
url.password = '';
url.host = null;
url.port = null;
url.path = [];
url.query = null;
url.fragment = null;
url.cannotBeABaseURL = false;
input = input.replace(LEADING_AND_TRAILING_C0_CONTROL_OR_SPACE, '');
}
input = input.replace(TAB_AND_NEW_LINE, '');
codePoints = arrayFrom(input);
while (pointer <= codePoints.length) {
char = codePoints[pointer];
switch (state) {
case SCHEME_START:
if (char && ALPHA.test(char)) {
buffer += char.toLowerCase();
state = SCHEME;
} else if (!stateOverride) {
state = NO_SCHEME;
continue;
} else return INVALID_SCHEME;
break;
case SCHEME:
if (char && (ALPHANUMERIC.test(char) || char == '+' || char == '-' || char == '.')) {
buffer += char.toLowerCase();
} else if (char == ':') {
if (stateOverride && (isSpecial(url) != has(specialSchemes, buffer) || buffer == 'file' && (includesCredentials(url) || url.port !== null) || url.scheme == 'file' && !url.host)) return;
url.scheme = buffer;
if (stateOverride) {
if (isSpecial(url) && specialSchemes[url.scheme] == url.port) url.port = null;
return;
}
buffer = '';
if (url.scheme == 'file') {
state = FILE;
} else if (isSpecial(url) && base && base.scheme == url.scheme) {
state = SPECIAL_RELATIVE_OR_AUTHORITY;
} else if (isSpecial(url)) {
state = SPECIAL_AUTHORITY_SLASHES;
} else if (codePoints[pointer + 1] == '/') {
state = PATH_OR_AUTHORITY;
pointer++;
} else {
url.cannotBeABaseURL = true;
url.path.push('');
state = CANNOT_BE_A_BASE_URL_PATH;
}
} else if (!stateOverride) {
buffer = '';
state = NO_SCHEME;
pointer = 0;
continue;
} else return INVALID_SCHEME;
break;
case NO_SCHEME:
if (!base || base.cannotBeABaseURL && char != '#') return INVALID_SCHEME;
if (base.cannotBeABaseURL && char == '#') {
url.scheme = base.scheme;
url.path = base.path.slice();
url.query = base.query;
url.fragment = '';
url.cannotBeABaseURL = true;
state = FRAGMENT;
break;
}
state = base.scheme == 'file' ? FILE : RELATIVE;
continue;
case SPECIAL_RELATIVE_OR_AUTHORITY:
if (char == '/' && codePoints[pointer + 1] == '/') {
state = SPECIAL_AUTHORITY_IGNORE_SLASHES;
pointer++;
} else {
state = RELATIVE;
continue;
}
break;
case PATH_OR_AUTHORITY:
if (char == '/') {
state = AUTHORITY;
break;
} else {
state = PATH;
continue;
}
case RELATIVE:
url.scheme = base.scheme;
if (char == EOF) {
url.username = base.username;
url.password = base.password;
url.host = base.host;
url.port = base.port;
url.path = base.path.slice();
url.query = base.query;
} else if (char == '/' || char == '\\' && isSpecial(url)) {
state = RELATIVE_SLASH;
} else if (char == '?') {
url.username = base.username;
url.password = base.password;
url.host = base.host;
url.port = base.port;
url.path = base.path.slice();
url.query = '';
state = QUERY;
} else if (char == '#') {
url.username = base.username;
url.password = base.password;
url.host = base.host;
url.port = base.port;
url.path = base.path.slice();
url.query = base.query;
url.fragment = '';
state = FRAGMENT;
} else {
url.username = base.username;
url.password = base.password;
url.host = base.host;
url.port = base.port;
url.path = base.path.slice();
url.path.pop();
state = PATH;
continue;
}
break;
case RELATIVE_SLASH:
if (isSpecial(url) && (char == '/' || char == '\\')) {
state = SPECIAL_AUTHORITY_IGNORE_SLASHES;
} else if (char == '/') {
state = AUTHORITY;
} else {
url.username = base.username;
url.password = base.password;
url.host = base.host;
url.port = base.port;
state = PATH;
continue;
}
break;
case SPECIAL_AUTHORITY_SLASHES:
state = SPECIAL_AUTHORITY_IGNORE_SLASHES;
if (char != '/' || buffer.charAt(pointer + 1) != '/') continue;
pointer++;
break;
case SPECIAL_AUTHORITY_IGNORE_SLASHES:
if (char != '/' && char != '\\') {
state = AUTHORITY;
continue;
}
break;
case AUTHORITY:
if (char == '@') {
if (seenAt) buffer = '%40' + buffer;
seenAt = true;
bufferCodePoints = arrayFrom(buffer);
for (var i = 0; i < bufferCodePoints.length; i++) {
var codePoint = bufferCodePoints[i];
if (codePoint == ':' && !seenPasswordToken) {
seenPasswordToken = true;
continue;
}
var encodedCodePoints = percentEncode(codePoint, userinfoPercentEncodeSet);
if (seenPasswordToken) url.password += encodedCodePoints;
else url.username += encodedCodePoints;
}
buffer = '';
} else if (char == EOF || char == '/' || char == '?' || char == '#' || char == '\\' && isSpecial(url)) {
if (seenAt && buffer == '') return INVALID_AUTHORITY;
pointer -= arrayFrom(buffer).length + 1;
buffer = '';
state = HOST;
} else buffer += char;
break;
case HOST:
case HOSTNAME:
if (stateOverride && url.scheme == 'file') {
state = FILE_HOST;
continue;
} else if (char == ':' && !seenBracket) {
if (buffer == '') return INVALID_HOST;
failure = parseHost(url, buffer);
if (failure) return failure;
buffer = '';
state = PORT;
if (stateOverride == HOSTNAME) return;
} else if (char == EOF || char == '/' || char == '?' || char == '#' || char == '\\' && isSpecial(url)) {
if (isSpecial(url) && buffer == '') return INVALID_HOST;
if (stateOverride && buffer == '' && (includesCredentials(url) || url.port !== null)) return;
failure = parseHost(url, buffer);
if (failure) return failure;
buffer = '';
state = PATH_START;
if (stateOverride) return;
continue;
} else {
if (char == '[') seenBracket = true;
else if (char == ']') seenBracket = false;
buffer += char;
}
break;
case PORT:
if (DIGIT.test(char)) {
buffer += char;
} else if (char == EOF || char == '/' || char == '?' || char == '#' || char == '\\' && isSpecial(url) || stateOverride) {
if (buffer != '') {
var port = parseInt(buffer, 10);
if (port > 0xFFFF) return INVALID_PORT;
url.port = isSpecial(url) && port === specialSchemes[url.scheme] ? null : port;
buffer = '';
}
if (stateOverride) return;
state = PATH_START;
continue;
} else return INVALID_PORT;
break;
case FILE:
url.scheme = 'file';
if (char == '/' || char == '\\') {
state = FILE_SLASH;
} else if (base && base.scheme == 'file') {
if (char == EOF) {
url.host = base.host;
url.path = base.path.slice();
url.query = base.query;
} else if (char == '?') {
url.host = base.host;
url.path = base.path.slice();
url.query = '';
state = QUERY;
} else if (char == '#') {
url.host = base.host;
url.path = base.path.slice();
url.query = base.query;
url.fragment = '';
state = FRAGMENT;
} else {
if (!startsWithWindowsDriveLetter(codePoints.slice(pointer).join(''))) {
url.host = base.host;
url.path = base.path.slice();
shortenURLsPath(url);
}
state = PATH;
continue;
}
} else {
state = PATH;
continue;
}
break;
case FILE_SLASH:
if (char == '/' || char == '\\') {
state = FILE_HOST;
break;
}
if (base && base.scheme == 'file' && !startsWithWindowsDriveLetter(codePoints.slice(pointer).join(''))) {
if (isWindowsDriveLetter(base.path[0], true)) url.path.push(base.path[0]);
else url.host = base.host;
}
state = PATH;
continue;
case FILE_HOST:
if (char == EOF || char == '/' || char == '\\' || char == '?' || char == '#') {
if (!stateOverride && isWindowsDriveLetter(buffer)) {
state = PATH;
} else if (buffer == '') {
url.host = '';
if (stateOverride) return;
state = PATH_START;
} else {
failure = parseHost(url, buffer);
if (failure) return failure;
if (url.host == 'localhost') url.host = '';
if (stateOverride) return;
buffer = '';
state = PATH_START;
}
continue;
} else buffer += char;
break;
case PATH_START:
if (isSpecial(url)) {
state = PATH;
if (char != '/' && char != '\\') continue;
} else if (!stateOverride && char == '?') {
url.query = '';
state = QUERY;
} else if (!stateOverride && char == '#') {
url.fragment = '';
state = FRAGMENT;
} else if (char != EOF) {
state = PATH;
if (char != '/') continue;
}
break;
case PATH:
if (char == EOF || char == '/' || char == '\\' && isSpecial(url) || !stateOverride && (char == '?' || char == '#')) {
if (isDoubleDot(buffer)) {
shortenURLsPath(url);
if (char != '/' && !(char == '\\' && isSpecial(url))) {
url.path.push('');
}
} else if (isSingleDot(buffer)) {
if (char != '/' && !(char == '\\' && isSpecial(url))) {
url.path.push('');
}
} else {
if (url.scheme == 'file' && !url.path.length && isWindowsDriveLetter(buffer)) {
if (url.host) url.host = '';
buffer = buffer.charAt(0) + ':'; // normalize windows drive letter
}
url.path.push(buffer);
}
buffer = '';
if (url.scheme == 'file' && (char == EOF || char == '?' || char == '#')) {
while (url.path.length > 1 && url.path[0] === '') {
url.path.shift();
}
}
if (char == '?') {
url.query = '';
state = QUERY;
} else if (char == '#') {
url.fragment = '';
state = FRAGMENT;
}
} else {
buffer += percentEncode(char, pathPercentEncodeSet);
}
break;
case CANNOT_BE_A_BASE_URL_PATH:
if (char == '?') {
url.query = '';
state = QUERY;
} else if (char == '#') {
url.fragment = '';
state = FRAGMENT;
} else if (char != EOF) {
url.path[0] += percentEncode(char, C0ControlPercentEncodeSet);
}
break;
case QUERY:
if (!stateOverride && char == '#') {
url.fragment = '';
state = FRAGMENT;
} else if (char != EOF) {
if (char == "'" && isSpecial(url)) url.query += '%27';else if (char == '#') url.query += '%23';else url.query += percentEncode(char, C0ControlPercentEncodeSet);
}
break;
case FRAGMENT:
if (char != EOF) url.fragment += percentEncode(char, fragmentPercentEncodeSet);
break;
}
pointer++;
}
}; // `URL` constructor
// https://url.spec.whatwg.org/#url-class
var URLConstructor = function URL(url
/* , base */
) {
var that = anInstance(this, URLConstructor, 'URL');
var base = arguments.length > 1 ? arguments[1] : undefined;
var urlString = String(url);
var state = setInternalState$a(that, {
type: 'URL'
});
var baseState, failure;
if (base !== undefined) {
if (base instanceof URLConstructor) baseState = getInternalURLState(base);
else {
failure = parseURL(baseState = {}, String(base));
if (failure) throw TypeError(failure);
}
}
failure = parseURL(state, urlString, null, baseState);
if (failure) throw TypeError(failure);
var searchParams = state.searchParams = new URLSearchParams$1();
var searchParamsState = getInternalSearchParamsState(searchParams);
searchParamsState.updateSearchParams(state.query);
searchParamsState.updateURL = function () {
state.query = String(searchParams) || null;
};
if (!descriptors) {
that.href = serializeURL.call(that);
that.origin = getOrigin.call(that);
that.protocol = getProtocol.call(that);
that.username = getUsername.call(that);
that.password = getPassword.call(that);
that.host = getHost.call(that);
that.hostname = getHostname.call(that);
that.port = getPort.call(that);
that.pathname = getPathname.call(that);
that.search = getSearch.call(that);
that.searchParams = getSearchParams.call(that);
that.hash = getHash.call(that);
}
};
var URLPrototype = URLConstructor.prototype;
var serializeURL = function serializeURL() {
var url = getInternalURLState(this);
var scheme = url.scheme;
var username = url.username;
var password = url.password;
var host = url.host;
var port = url.port;
var path = url.path;
var query = url.query;
var fragment = url.fragment;
var output = scheme + ':';
if (host !== null) {
output += '//';
if (includesCredentials(url)) {
output += username + (password ? ':' + password : '') + '@';
}
output += serializeHost(host);
if (port !== null) output += ':' + port;
} else if (scheme == 'file') output += '//';
output += url.cannotBeABaseURL ? path[0] : path.length ? '/' + path.join('/') : '';
if (query !== null) output += '?' + query;
if (fragment !== null) output += '#' + fragment;
return output;
};
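// For a parsed state such as { scheme: 'https', username: 'u', password: 'p',
// host: 'example.com', port: 8443, path: ['a', 'b'], query: 'x=1', fragment: 'top' }
// the serializer yields 'https://u:p@example.com:8443/a/b?x=1#top'
// (assuming serializeHost returns a plain string host unchanged).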
var getOrigin = function getOrigin() {
var url = getInternalURLState(this);
var scheme = url.scheme;
var port = url.port;
if (scheme == 'blob') try {
// A blob: URL's origin is the origin of the URL stored in its path[0].
return new URL(url.path[0]).origin;
} catch (error) {
return 'null';
}
if (scheme == 'file' || !isSpecial(url)) return 'null';
return scheme + '://' + serializeHost(url.host) + (port !== null ? ':' + port : '');
};
var getProtocol = function getProtocol() {
return getInternalURLState(this).scheme + ':';
};
var getUsername = function getUsername() {
return getInternalURLState(this).username;
};
var getPassword = function getPassword() {
return getInternalURLState(this).password;
};
var getHost = function getHost() {
var url = getInternalURLState(this);
var host = url.host;
var port = url.port;
return host === null ? '' : port === null ? serializeHost(host) : serializeHost(host) + ':' + port;
};
var getHostname = function getHostname() {
var host = getInternalURLState(this).host;
return host === null ? '' : serializeHost(host);
};
var getPort = function getPort() {
var port = getInternalURLState(this).port;
return port === null ? '' : String(port);
};
var getPathname = function getPathname() {
var url = getInternalURLState(this);
var path = url.path;
return url.cannotBeABaseURL ? path[0] : path.length ? '/' + path.join('/') : '';
};
var getSearch = function getSearch() {
var query = getInternalURLState(this).query;
return query ? '?' + query : '';
};
var getSearchParams = function getSearchParams() {
return getInternalURLState(this).searchParams;
};
var getHash = function getHash() {
var fragment = getInternalURLState(this).fragment;
return fragment ? '#' + fragment : '';
};
var accessorDescriptor = function accessorDescriptor(getter, setter) {
return {
get: getter,
set: setter,
configurable: true,
enumerable: true
};
};
if (descriptors) {
objectDefineProperties(URLPrototype, {
// `URL.prototype.href` accessors pair
// https://url.spec.whatwg.org/#dom-url-href
href: accessorDescriptor(serializeURL, function (href) {
var url = getInternalURLState(this);
var urlString = String(href);
var failure = parseURL(url, urlString);
if (failure) throw TypeError(failure);
getInternalSearchParamsState(url.searchParams).updateSearchParams(url.query);
}),
// `URL.prototype.origin` getter
// https://url.spec.whatwg.org/#dom-url-origin
origin: accessorDescriptor(getOrigin),
// `URL.prototype.protocol` accessors pair
// https://url.spec.whatwg.org/#dom-url-protocol
protocol: accessorDescriptor(getProtocol, function (protocol) {
var url = getInternalURLState(this);
parseURL(url, String(protocol) + ':', SCHEME_START);
}),
// `URL.prototype.username` accessors pair
// https://url.spec.whatwg.org/#dom-url-username
username: accessorDescriptor(getUsername, function (username) {
var url = getInternalURLState(this);
var codePoints = arrayFrom(String(username));
if (cannotHaveUsernamePasswordPort(url)) return;
url.username = '';
for (var i = 0; i < codePoints.length; i++) {
url.username += percentEncode(codePoints[i], userinfoPercentEncodeSet);
}
}),
// `URL.prototype.password` accessors pair
// https://url.spec.whatwg.org/#dom-url-password
password: accessorDescriptor(getPassword, function (password) {
var url = getInternalURLState(this);
var codePoints = arrayFrom(String(password));
if (cannotHaveUsernamePasswordPort(url)) return;
url.password = '';
for (var i = 0; i < codePoints.length; i++) {
url.password += percentEncode(codePoints[i], userinfoPercentEncodeSet);
}
}),
// `URL.prototype.host` accessors pair
// https://url.spec.whatwg.org/#dom-url-host
host: accessorDescriptor(getHost, function (host) {
var url = getInternalURLState(this);
if (url.cannotBeABaseURL) return;
parseURL(url, String(host), HOST);
}),
// `URL.prototype.hostname` accessors pair
// https://url.spec.whatwg.org/#dom-url-hostname
hostname: accessorDescriptor(getHostname, function (hostname) {
var url = getInternalURLState(this);
if (url.cannotBeABaseURL) return;
parseURL(url, String(hostname), HOSTNAME);
}),
// `URL.prototype.port` accessors pair
// https://url.spec.whatwg.org/#dom-url-port
port: accessorDescriptor(getPort, function (port) {
var url = getInternalURLState(this);
if (cannotHaveUsernamePasswordPort(url)) return;
port = String(port);
if (port == '') url.port = null;
else parseURL(url, port, PORT);
}),
// `URL.prototype.pathname` accessors pair
// https://url.spec.whatwg.org/#dom-url-pathname
pathname: accessorDescriptor(getPathname, function (pathname) {
var url = getInternalURLState(this);
if (url.cannotBeABaseURL) return;
url.path = [];
parseURL(url, pathname + '', PATH_START);
}),
// `URL.prototype.search` accessors pair
// https://url.spec.whatwg.org/#dom-url-search
search: accessorDescriptor(getSearch, function (search) {
var url = getInternalURLState(this);
search = String(search);
if (search == '') {
url.query = null;
} else {
if ('?' == search.charAt(0)) search = search.slice(1);
url.query = '';
parseURL(url, search, QUERY);
}
getInternalSearchParamsState(url.searchParams).updateSearchParams(url.query);
}),
// `URL.prototype.searchParams` getter
// https://url.spec.whatwg.org/#dom-url-searchparams
searchParams: accessorDescriptor(getSearchParams),
// `URL.prototype.hash` accessors pair
// https://url.spec.whatwg.org/#dom-url-hash
hash: accessorDescriptor(getHash, function (hash) {
var url = getInternalURLState(this);
hash = String(hash);
if (hash == '') {
url.fragment = null;
return;
}
if ('#' == hash.charAt(0)) hash = hash.slice(1);
url.fragment = '';
parseURL(url, hash, FRAGMENT);
})
});
} // `URL.prototype.toJSON` method
// https://url.spec.whatwg.org/#dom-url-tojson
redefine(URLPrototype, 'toJSON', function toJSON() {
return serializeURL.call(this);
}, {
enumerable: true
}); // `URL.prototype.toString` method
// https://url.spec.whatwg.org/#URL-stringification-behavior
redefine(URLPrototype, 'toString', function toString() {
return serializeURL.call(this);
}, {
enumerable: true
});
if (NativeURL) {
var nativeCreateObjectURL = NativeURL.createObjectURL;
var nativeRevokeObjectURL = NativeURL.revokeObjectURL; // `URL.createObjectURL` method
// https://developer.mozilla.org/en-US/docs/Web/API/URL/createObjectURL
// eslint-disable-next-line no-unused-vars -- required for `.length`
if (nativeCreateObjectURL) redefine(URLConstructor, 'createObjectURL', function createObjectURL(blob) {
return nativeCreateObjectURL.apply(NativeURL, arguments);
}); // `URL.revokeObjectURL` method
// https://developer.mozilla.org/en-US/docs/Web/API/URL/revokeObjectURL
// eslint-disable-next-line no-unused-vars -- required for `.length`
if (nativeRevokeObjectURL) redefine(URLConstructor, 'revokeObjectURL', function revokeObjectURL(url) {
return nativeRevokeObjectURL.apply(NativeURL, arguments);
});
}
setToStringTag(URLConstructor, 'URL');
_export({
global: true,
forced: !nativeUrl,
sham: !descriptors
}, {
URL: URLConstructor
});
var web_url = {};
'use strict'; // `URL.prototype.toJSON` method
// https://url.spec.whatwg.org/#dom-url-tojson
_export({
target: 'URL',
proto: true,
enumerable: true
}, {
toJSON: function toJSON() {
return URL.prototype.toString.call(this);
}
});
var web_url_toJson = {};
var web = path;
var stable = path;
var runtime_1 = createCommonjsModule(function (module) {
/**
* Copyright (c) 2014-present, Facebook, Inc.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
var runtime = function (exports) {
"use strict";
var Op = Object.prototype;
var hasOwn = Op.hasOwnProperty;
var undefined$1; // More compressible than void 0.
var $Symbol = typeof Symbol === "function" ? Symbol : {};
var iteratorSymbol = $Symbol.iterator || "@@iterator";
var asyncIteratorSymbol = $Symbol.asyncIterator || "@@asyncIterator";
var toStringTagSymbol = $Symbol.toStringTag || "@@toStringTag";
function define(obj, key, value) {
Object.defineProperty(obj, key, {
value: value,
enumerable: true,
configurable: true,
writable: true
});
return obj[key];
}
try {
// IE 8 has a broken Object.defineProperty that only works on DOM objects.
define({}, "");
} catch (err) {
define = function define(obj, key, value) {
return obj[key] = value;
};
}
function wrap(innerFn, outerFn, self, tryLocsList) {
// If outerFn provided and outerFn.prototype is a Generator, then outerFn.prototype instanceof Generator.
var protoGenerator = outerFn && outerFn.prototype instanceof Generator ? outerFn : Generator;
var generator = Object.create(protoGenerator.prototype);
var context = new Context(tryLocsList || []); // The ._invoke method unifies the implementations of the .next,
// .throw, and .return methods.
generator._invoke = makeInvokeMethod(innerFn, self, context);
return generator;
}
exports.wrap = wrap; // Try/catch helper to minimize deoptimizations. Returns a completion
// record like context.tryEntries[i].completion. This interface could
// have been (and was previously) designed to take a closure to be
// invoked without arguments, but in all the cases we care about we
// already have an existing method we want to call, so there's no need
// to create a new function object. We can even get away with assuming
// the method takes exactly one argument, since that happens to be true
// in every case, so we don't have to touch the arguments object. The
// only additional allocation required is the completion record, which
// has a stable shape and so hopefully should be cheap to allocate.
function tryCatch(fn, obj, arg) {
try {
return {
type: "normal",
arg: fn.call(obj, arg)
};
} catch (err) {
return {
type: "throw",
arg: err
};
}
}
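// For example, tryCatch(JSON.parse, JSON, '1') returns { type: "normal", arg: 1 },
// while tryCatch(JSON.parse, JSON, '{') returns { type: "throw", arg: <SyntaxError> }.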
var GenStateSuspendedStart = "suspendedStart";
var GenStateSuspendedYield = "suspendedYield";
var GenStateExecuting = "executing";
var GenStateCompleted = "completed"; // Returning this object from the innerFn has the same effect as
// breaking out of the dispatch switch statement.
var ContinueSentinel = {}; // Dummy constructor functions that we use as the .constructor and
// .constructor.prototype properties for functions that return Generator
// objects. For full spec compliance, you may wish to configure your
// minifier not to mangle the names of these two functions.
function Generator() {}
function GeneratorFunction() {}
function GeneratorFunctionPrototype() {} // This is a polyfill for %IteratorPrototype% for environments that
// don't natively support it.
var IteratorPrototype = {};
IteratorPrototype[iteratorSymbol] = function () {
return this;
};
var getProto = Object.getPrototypeOf;
var NativeIteratorPrototype = getProto && getProto(getProto(values([])));
if (NativeIteratorPrototype && NativeIteratorPrototype !== Op && hasOwn.call(NativeIteratorPrototype, iteratorSymbol)) {
// This environment has a native %IteratorPrototype%; use it instead
// of the polyfill.
IteratorPrototype = NativeIteratorPrototype;
}
var Gp = GeneratorFunctionPrototype.prototype = Generator.prototype = Object.create(IteratorPrototype);
GeneratorFunction.prototype = Gp.constructor = GeneratorFunctionPrototype;
GeneratorFunctionPrototype.constructor = GeneratorFunction;
GeneratorFunction.displayName = define(GeneratorFunctionPrototype, toStringTagSymbol, "GeneratorFunction"); // Helper for defining the .next, .throw, and .return methods of the
// Iterator interface in terms of a single ._invoke method.
function defineIteratorMethods(prototype) {
["next", "throw", "return"].forEach(function (method) {
define(prototype, method, function (arg) {
return this._invoke(method, arg);
});
});
}
exports.isGeneratorFunction = function (genFun) {
var ctor = typeof genFun === "function" && genFun.constructor;
return ctor ? ctor === GeneratorFunction || // For the native GeneratorFunction constructor, the best we can
// do is to check its .name property.
(ctor.displayName || ctor.name) === "GeneratorFunction" : false;
};
exports.mark = function (genFun) {
if (Object.setPrototypeOf) {
Object.setPrototypeOf(genFun, GeneratorFunctionPrototype);
} else {
genFun.__proto__ = GeneratorFunctionPrototype;
define(genFun, toStringTagSymbol, "GeneratorFunction");
}
genFun.prototype = Object.create(Gp);
return genFun;
}; // Within the body of any async function, `await x` is transformed to
// `yield regeneratorRuntime.awrap(x)`, so that the runtime can test
// `hasOwn.call(value, "__await")` to determine if the yielded value is
// meant to be awaited.
exports.awrap = function (arg) {
return {
__await: arg
};
};
function AsyncIterator(generator, PromiseImpl) {
function invoke(method, arg, resolve, reject) {
var record = tryCatch(generator[method], generator, arg);
if (record.type === "throw") {
reject(record.arg);
} else {
var result = record.arg;
var value = result.value;
if (value && typeof value === "object" && hasOwn.call(value, "__await")) {
return PromiseImpl.resolve(value.__await).then(function (value) {
invoke("next", value, resolve, reject);
}, function (err) {
invoke("throw", err, resolve, reject);
});
}
return PromiseImpl.resolve(value).then(function (unwrapped) {
// When a yielded Promise is resolved, its final value becomes
// the .value of the Promise<{value,done}> result for the
// current iteration.
result.value = unwrapped;
resolve(result);
}, function (error) {
// If a rejected Promise was yielded, throw the rejection back
// into the async generator function so it can be handled there.
return invoke("throw", error, resolve, reject);
});
}
}
var previousPromise;
function enqueue(method, arg) {
function callInvokeWithMethodAndArg() {
return new PromiseImpl(function (resolve, reject) {
invoke(method, arg, resolve, reject);
});
}
return previousPromise = // If enqueue has been called before, then we want to wait until
// all previous Promises have been resolved before calling invoke,
// so that results are always delivered in the correct order. If
// enqueue has not been called before, then it is important to
// call invoke immediately, without waiting on a callback to fire,
// so that the async generator function has the opportunity to do
// any necessary setup in a predictable way. This predictability
// is why the Promise constructor synchronously invokes its
// executor callback, and why async functions synchronously
// execute code before the first await. Since we implement simple
// async functions in terms of async generators, it is especially
// important to get this right, even though it requires care.
previousPromise ? previousPromise.then(callInvokeWithMethodAndArg, // Avoid propagating failures to Promises returned by later
// invocations of the iterator.
callInvokeWithMethodAndArg) : callInvokeWithMethodAndArg();
} // Define the unified helper method that is used to implement .next,
// .throw, and .return (see defineIteratorMethods).
this._invoke = enqueue;
}
defineIteratorMethods(AsyncIterator.prototype);
AsyncIterator.prototype[asyncIteratorSymbol] = function () {
return this;
};
exports.AsyncIterator = AsyncIterator; // Note that simple async functions are implemented on top of
// AsyncIterator objects; they just return a Promise for the value of
// the final result produced by the iterator.
exports.async = function (innerFn, outerFn, self, tryLocsList, PromiseImpl) {
if (PromiseImpl === void 0) PromiseImpl = Promise;
var iter = new AsyncIterator(wrap(innerFn, outerFn, self, tryLocsList), PromiseImpl);
return exports.isGeneratorFunction(outerFn) ? iter // If outerFn is a generator, return the full iterator.
: iter.next().then(function (result) {
return result.done ? result.value : iter.next();
});
};
function makeInvokeMethod(innerFn, self, context) {
var state = GenStateSuspendedStart;
return function invoke(method, arg) {
if (state === GenStateExecuting) {
throw new Error("Generator is already running");
}
if (state === GenStateCompleted) {
if (method === "throw") {
throw arg;
} // Be forgiving, per 25.3.3.3.3 of the spec:
// https://people.mozilla.org/~jorendorff/es6-draft.html#sec-generatorresume
return doneResult();
}
context.method = method;
context.arg = arg;
while (true) {
var delegate = context.delegate;
if (delegate) {
var delegateResult = maybeInvokeDelegate(delegate, context);
if (delegateResult) {
if (delegateResult === ContinueSentinel) continue;
return delegateResult;
}
}
if (context.method === "next") {
// Setting context._sent for legacy support of Babel's
// function.sent implementation.
context.sent = context._sent = context.arg;
} else if (context.method === "throw") {
if (state === GenStateSuspendedStart) {
state = GenStateCompleted;
throw context.arg;
}
context.dispatchException(context.arg);
} else if (context.method === "return") {
context.abrupt("return", context.arg);
}
state = GenStateExecuting;
var record = tryCatch(innerFn, self, context);
if (record.type === "normal") {
// If an exception is thrown from innerFn, we leave state ===
// GenStateExecuting and loop back for another invocation.
state = context.done ? GenStateCompleted : GenStateSuspendedYield;
if (record.arg === ContinueSentinel) {
continue;
}
return {
value: record.arg,
done: context.done
};
} else if (record.type === "throw") {
state = GenStateCompleted; // Dispatch the exception by looping back around to the
// context.dispatchException(context.arg) call above.
context.method = "throw";
context.arg = record.arg;
}
}
};
} // Call delegate.iterator[context.method](context.arg) and handle the
// result, either by returning a { value, done } result from the
// delegate iterator, or by modifying context.method and context.arg,
// setting context.delegate to null, and returning the ContinueSentinel.
function maybeInvokeDelegate(delegate, context) {
var method = delegate.iterator[context.method];
if (method === undefined$1) {
// A .throw or .return when the delegate iterator has no .throw
// method always terminates the yield* loop.
context.delegate = null;
if (context.method === "throw") {
// Note: ["return"] must be used for ES3 parsing compatibility.
if (delegate.iterator["return"]) {
// If the delegate iterator has a return method, give it a
// chance to clean up.
context.method = "return";
context.arg = undefined$1;
maybeInvokeDelegate(delegate, context);
if (context.method === "throw") {
// If maybeInvokeDelegate(context) changed context.method from
// "return" to "throw", let that override the TypeError below.
return ContinueSentinel;
}
}
context.method = "throw";
context.arg = new TypeError("The iterator does not provide a 'throw' method");
}
return ContinueSentinel;
}
var record = tryCatch(method, delegate.iterator, context.arg);
if (record.type === "throw") {
context.method = "throw";
context.arg = record.arg;
context.delegate = null;
return ContinueSentinel;
}
var info = record.arg;
if (!info) {
context.method = "throw";
context.arg = new TypeError("iterator result is not an object");
context.delegate = null;
return ContinueSentinel;
}
if (info.done) {
// Assign the result of the finished delegate to the temporary
// variable specified by delegate.resultName (see delegateYield).
context[delegate.resultName] = info.value; // Resume execution at the desired location (see delegateYield).
context.next = delegate.nextLoc; // If context.method was "throw" but the delegate handled the
// exception, let the outer generator proceed normally. If
// context.method was "next", forget context.arg since it has been
// "consumed" by the delegate iterator. If context.method was
// "return", allow the original .return call to continue in the
// outer generator.
if (context.method !== "return") {
context.method = "next";
context.arg = undefined$1;
}
} else {
// Re-yield the result returned by the delegate method.
return info;
} // The delegate iterator is finished, so forget it and continue with
// the outer generator.
context.delegate = null;
return ContinueSentinel;
} // Define Generator.prototype.{next,throw,return} in terms of the
// unified ._invoke helper method.
defineIteratorMethods(Gp);
define(Gp, toStringTagSymbol, "Generator"); // A Generator should always return itself as the iterator object when the
// @@iterator function is called on it. Some browsers' implementations of the
// iterator prototype chain incorrectly implement this, causing the Generator
// object to not be returned from this call. This ensures that doesn't happen.
// See https://github.com/facebook/regenerator/issues/274 for more details.
Gp[iteratorSymbol] = function () {
return this;
};
Gp.toString = function () {
return "[object Generator]";
};
function pushTryEntry(locs) {
var entry = {
tryLoc: locs[0]
};
if (1 in locs) {
entry.catchLoc = locs[1];
}
if (2 in locs) {
entry.finallyLoc = locs[2];
entry.afterLoc = locs[3];
}
this.tryEntries.push(entry);
}
function resetTryEntry(entry) {
var record = entry.completion || {};
record.type = "normal";
delete record.arg;
entry.completion = record;
}
function Context(tryLocsList) {
// The root entry object (effectively a try statement without a catch
// or a finally block) gives us a place to store values thrown from
// locations where there is no enclosing try statement.
this.tryEntries = [{
tryLoc: "root"
}];
tryLocsList.forEach(pushTryEntry, this);
this.reset(true);
}
exports.keys = function (object) {
var keys = [];
for (var key in object) {
keys.push(key);
}
keys.reverse(); // Rather than returning an object with a next method, we keep
// things simple and return the next function itself.
return function next() {
while (keys.length) {
var key = keys.pop();
if (key in object) {
next.value = key;
next.done = false;
return next;
}
} // To avoid creating an additional object, we just hang the .value
// and .done properties off the next function object itself. This
// also ensures that the minifier will not anonymize the function.
next.done = true;
return next;
};
};
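// e.g. exports.keys({ a: 1, b: 2 }) returns a next function that reports value 'a',
// then 'b', then done: true; this is roughly what the regenerator transform of a
// for-in loop inside a generator body consumes.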
function values(iterable) {
if (iterable) {
var iteratorMethod = iterable[iteratorSymbol];
if (iteratorMethod) {
return iteratorMethod.call(iterable);
}
if (typeof iterable.next === "function") {
return iterable;
}
if (!isNaN(iterable.length)) {
var i = -1,
next = function next() {
while (++i < iterable.length) {
if (hasOwn.call(iterable, i)) {
next.value = iterable[i];
next.done = false;
return next;
}
}
next.value = undefined$1;
next.done = true;
return next;
};
return next.next = next;
}
} // Return an iterator with no values.
return {
next: doneResult
};
}
exports.values = values;
function doneResult() {
return {
value: undefined$1,
done: true
};
}
Context.prototype = {
constructor: Context,
reset: function reset(skipTempReset) {
this.prev = 0;
this.next = 0; // Resetting context._sent for legacy support of Babel's
// function.sent implementation.
this.sent = this._sent = undefined$1;
this.done = false;
this.delegate = null;
this.method = "next";
this.arg = undefined$1;
this.tryEntries.forEach(resetTryEntry);
if (!skipTempReset) {
for (var name in this) {
// Not sure about the optimal order of these conditions:
if (name.charAt(0) === "t" && hasOwn.call(this, name) && !isNaN(+name.slice(1))) {
this[name] = undefined$1;
}
}
}
},
stop: function stop() {
this.done = true;
var rootEntry = this.tryEntries[0];
var rootRecord = rootEntry.completion;
if (rootRecord.type === "throw") {
throw rootRecord.arg;
}
return this.rval;
},
dispatchException: function dispatchException(exception) {
if (this.done) {
throw exception;
}
var context = this;
function handle(loc, caught) {
record.type = "throw";
record.arg = exception;
context.next = loc;
if (caught) {
// If the dispatched exception was caught by a catch block,
// then let that catch block handle the exception normally.
context.method = "next";
context.arg = undefined$1;
}
return !!caught;
}
for (var i = this.tryEntries.length - 1; i >= 0; --i) {
var entry = this.tryEntries[i];
var record = entry.completion;
if (entry.tryLoc === "root") {
// Exception thrown outside of any try block that could handle
// it, so set the completion value of the entire function to
// throw the exception.
return handle("end");
}
if (entry.tryLoc <= this.prev) {
var hasCatch = hasOwn.call(entry, "catchLoc");
var hasFinally = hasOwn.call(entry, "finallyLoc");
if (hasCatch && hasFinally) {
if (this.prev < entry.catchLoc) {
return handle(entry.catchLoc, true);
} else if (this.prev < entry.finallyLoc) {
return handle(entry.finallyLoc);
}
} else if (hasCatch) {
if (this.prev < entry.catchLoc) {
return handle(entry.catchLoc, true);
}
} else if (hasFinally) {
if (this.prev < entry.finallyLoc) {
return handle(entry.finallyLoc);
}
} else {
throw new Error("try statement without catch or finally");
}
}
}
},
abrupt: function abrupt(type, arg) {
for (var i = this.tryEntries.length - 1; i >= 0; --i) {
var entry = this.tryEntries[i];
if (entry.tryLoc <= this.prev && hasOwn.call(entry, "finallyLoc") && this.prev < entry.finallyLoc) {
var finallyEntry = entry;
break;
}
}
if (finallyEntry && (type === "break" || type === "continue") && finallyEntry.tryLoc <= arg && arg <= finallyEntry.finallyLoc) {
// Ignore the finally entry if control is not jumping to a
// location outside the try/catch block.
finallyEntry = null;
}
var record = finallyEntry ? finallyEntry.completion : {};
record.type = type;
record.arg = arg;
if (finallyEntry) {
this.method = "next";
this.next = finallyEntry.finallyLoc;
return ContinueSentinel;
}
return this.complete(record);
},
complete: function complete(record, afterLoc) {
if (record.type === "throw") {
throw record.arg;
}
if (record.type === "break" || record.type === "continue") {
this.next = record.arg;
} else if (record.type === "return") {
this.rval = this.arg = record.arg;
this.method = "return";
this.next = "end";
} else if (record.type === "normal" && afterLoc) {
this.next = afterLoc;
}
return ContinueSentinel;
},
finish: function finish(finallyLoc) {
for (var i = this.tryEntries.length - 1; i >= 0; --i) {
var entry = this.tryEntries[i];
if (entry.finallyLoc === finallyLoc) {
this.complete(entry.completion, entry.afterLoc);
resetTryEntry(entry);
return ContinueSentinel;
}
}
},
"catch": function _catch(tryLoc) {
for (var i = this.tryEntries.length - 1; i >= 0; --i) {
var entry = this.tryEntries[i];
if (entry.tryLoc === tryLoc) {
var record = entry.completion;
if (record.type === "throw") {
var thrown = record.arg;
resetTryEntry(entry);
}
return thrown;
}
} // The context.catch method must only be called with a location
// argument that corresponds to a known catch block.
throw new Error("illegal catch attempt");
},
delegateYield: function delegateYield(iterable, resultName, nextLoc) {
this.delegate = {
iterator: values(iterable),
resultName: resultName,
nextLoc: nextLoc
};
if (this.method === "next") {
// Deliberately forget the last sent value so that we don't
// accidentally pass it on to the delegate.
this.arg = undefined$1;
}
return ContinueSentinel;
}
}; // Regardless of whether this script is executing as a CommonJS module
// or not, return the runtime object so that we can declare the variable
// regeneratorRuntime in the outer scope, which allows this module to be
// injected easily by `bin/regenerator --include-runtime script.js`.
return exports;
}( // If this script is executing as a CommonJS module, use module.exports
// as the regeneratorRuntime namespace. Otherwise create a new empty
// object. Either way, the resulting object will be used to initialize
// the regeneratorRuntime variable at the top of this file.
'object' === "object" ? module.exports : {});
try {
regeneratorRuntime = runtime;
} catch (accidentalStrictMode) {
// This module should not be running in strict mode, so the above
// assignment should always work unless something is misconfigured. Just
// in case runtime.js accidentally runs in strict mode, we can escape
// strict mode using a global Function call. This could conceivably fail
// if a Content Security Policy forbids using Function, but in that case
// the proper solution is to fix the accidental strict mode problem. If
// you've misconfigured your bundler to force strict mode and applied a
// CSP to forbid Function, and you're not willing to fix either of those
// problems, please detail your unique predicament in a GitHub issue.
Function("r", "regeneratorRuntime = r")(runtime);
}
});
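// The runtime registered above backs the Babel helpers further down (for
// example _asyncToGenerator and _wrapAsyncGenerator) whenever transpiled
// async or generator code in this bundle runs.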
function _typeof(obj) {
"@babel/helpers - typeof";
if (typeof Symbol === "function" && typeof Symbol.iterator === "symbol") {
_typeof = function (obj) {
return typeof obj;
};
} else {
_typeof = function (obj) {
return obj && typeof Symbol === "function" && obj.constructor === Symbol && obj !== Symbol.prototype ? "symbol" : typeof obj;
};
}
return _typeof(obj);
}
var REACT_ELEMENT_TYPE;
function _jsx(type, props, key, children) {
if (!REACT_ELEMENT_TYPE) {
REACT_ELEMENT_TYPE = typeof Symbol === "function" && Symbol["for"] && Symbol["for"]("react.element") || 0xeac7;
}
var defaultProps = type && type.defaultProps;
var childrenLength = arguments.length - 3;
if (!props && childrenLength !== 0) {
props = {
children: void 0
};
}
if (childrenLength === 1) {
props.children = children;
} else if (childrenLength > 1) {
var childArray = new Array(childrenLength);
for (var i = 0; i < childrenLength; i++) {
childArray[i] = arguments[i + 3];
}
props.children = childArray;
}
if (props && defaultProps) {
for (var propName in defaultProps) {
if (props[propName] === void 0) {
props[propName] = defaultProps[propName];
}
}
} else if (!props) {
props = defaultProps || {};
}
return {
$$typeof: REACT_ELEMENT_TYPE,
type: type,
key: key === undefined ? null : '' + key,
ref: null,
props: props,
_owner: null
};
}
function _asyncIterator(iterable) {
var method;
if (typeof Symbol !== "undefined") {
if (Symbol.asyncIterator) {
method = iterable[Symbol.asyncIterator];
if (method != null) return method.call(iterable);
}
if (Symbol.iterator) {
method = iterable[Symbol.iterator];
if (method != null) return method.call(iterable);
}
}
throw new TypeError("Object is not async iterable");
}
function _AwaitValue(value) {
this.wrapped = value;
}
function _AsyncGenerator(gen) {
var front, back;
function send(key, arg) {
return new Promise(function (resolve, reject) {
var request = {
key: key,
arg: arg,
resolve: resolve,
reject: reject,
next: null
};
if (back) {
back = back.next = request;
} else {
front = back = request;
resume(key, arg);
}
});
}
function resume(key, arg) {
try {
var result = gen[key](arg);
var value = result.value;
var wrappedAwait = value instanceof _AwaitValue;
Promise.resolve(wrappedAwait ? value.wrapped : value).then(function (arg) {
if (wrappedAwait) {
resume(key === "return" ? "return" : "next", arg);
return;
}
settle(result.done ? "return" : "normal", arg);
}, function (err) {
resume("throw", err);
});
} catch (err) {
settle("throw", err);
}
}
function settle(type, value) {
switch (type) {
case "return":
front.resolve({
value: value,
done: true
});
break;
case "throw":
front.reject(value);
break;
default:
front.resolve({
value: value,
done: false
});
break;
}
front = front.next;
if (front) {
resume(front.key, front.arg);
} else {
back = null;
}
}
this._invoke = send;
if (typeof gen.return !== "function") {
this.return = undefined;
}
}
if (typeof Symbol === "function" && Symbol.asyncIterator) {
_AsyncGenerator.prototype[Symbol.asyncIterator] = function () {
return this;
};
}
_AsyncGenerator.prototype.next = function (arg) {
return this._invoke("next", arg);
};
_AsyncGenerator.prototype.throw = function (arg) {
return this._invoke("throw", arg);
};
_AsyncGenerator.prototype.return = function (arg) {
return this._invoke("return", arg);
};
function _wrapAsyncGenerator(fn) {
return function () {
return new _AsyncGenerator(fn.apply(this, arguments));
};
}
function _awaitAsyncGenerator(value) {
return new _AwaitValue(value);
}
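// Roughly, Babel compiles `async function* f() { yield await x; }` into
// _wrapAsyncGenerator(regeneratorRuntime.mark(function () { ... })), with each
// `await x` becoming `yield _awaitAsyncGenerator(x)` inside the generator body.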
function _asyncGeneratorDelegate(inner, awaitWrap) {
var iter = {},
waiting = false;
function pump(key, value) {
waiting = true;
value = new Promise(function (resolve) {
resolve(inner[key](value));
});
return {
done: false,
value: awaitWrap(value)
};
}
if (typeof Symbol === "function" && Symbol.iterator) {
iter[Symbol.iterator] = function () {
return this;
};
}
iter.next = function (value) {
if (waiting) {
waiting = false;
return value;
}
return pump("next", value);
};
if (typeof inner.throw === "function") {
iter.throw = function (value) {
if (waiting) {
waiting = false;
throw value;
}
return pump("throw", value);
};
}
if (typeof inner.return === "function") {
iter.return = function (value) {
if (waiting) {
waiting = false;
return value;
}
return pump("return", value);
};
}
return iter;
}
function asyncGeneratorStep(gen, resolve, reject, _next, _throw, key, arg) {
try {
var info = gen[key](arg);
var value = info.value;
} catch (error) {
reject(error);
return;
}
if (info.done) {
resolve(value);
} else {
Promise.resolve(value).then(_next, _throw);
}
}
function _asyncToGenerator(fn) {
return function () {
var self = this,
args = arguments;
return new Promise(function (resolve, reject) {
var gen = fn.apply(self, args);
function _next(value) {
asyncGeneratorStep(gen, resolve, reject, _next, _throw, "next", value);
}
function _throw(err) {
asyncGeneratorStep(gen, resolve, reject, _next, _throw, "throw", err);
}
_next(undefined);
});
};
}
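// Roughly, Babel rewrites `async function f() { return await g(); }` as
// _asyncToGenerator(regeneratorRuntime.mark(function _callee() { ... })), so
// every resumed step funnels through asyncGeneratorStep above.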
function _classCallCheck(instance, Constructor) {
if (!(instance instanceof Constructor)) {
throw new TypeError("Cannot call a class as a function");
}
}
function _defineProperties(target, props) {
for (var i = 0; i < props.length; i++) {
var descriptor = props[i];
descriptor.enumerable = descriptor.enumerable || false;
descriptor.configurable = true;
if ("value" in descriptor) descriptor.writable = true;
Object.defineProperty(target, descriptor.key, descriptor);
}
}
function _createClass(Constructor, protoProps, staticProps) {
if (protoProps) _defineProperties(Constructor.prototype, protoProps);
if (staticProps) _defineProperties(Constructor, staticProps);
return Constructor;
}
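// Calling convention (hypothetical Point class with one prototype method and
// one static method): Babel emits
// _createClass(Point, [{ key: 'norm', value: function norm() { ... } }],
//                     [{ key: 'origin', value: function origin() { ... } }]);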
function _defineEnumerableProperties(obj, descs) {
for (var key in descs) {
var desc = descs[key];
desc.configurable = desc.enumerable = true;
if ("value" in desc) desc.writable = true;
Object.defineProperty(obj, key, desc);
}
if (Object.getOwnPropertySymbols) {
var objectSymbols = Object.getOwnPropertySymbols(descs);
for (var i = 0; i < objectSymbols.length; i++) {
var sym = objectSymbols[i];
var desc = descs[sym];
desc.configurable = desc.enumerable = true;
if ("value" in desc) desc.writable = true;
Object.defineProperty(obj, sym, desc);
}
}
return obj;
}
function _defaults(obj, defaults) {
var keys = Object.getOwnPropertyNames(defaults);
for (var i = 0; i < keys.length; i++) {
var key = keys[i];
var value = Object.getOwnPropertyDescriptor(defaults, key);
if (value && value.configurable && obj[key] === undefined) {
Object.defineProperty(obj, key, value);
}
}
return obj;
}
function _defineProperty(obj, key, value) {
if (key in obj) {
Object.defineProperty(obj, key, {
value: value,
enumerable: true,
configurable: true,
writable: true
});
} else {
obj[key] = value;
}
return obj;
}
function _extends() {
_extends = Object.assign || function (target) {
for (var i = 1; i < arguments.length; i++) {
var source = arguments[i];
for (var key in source) {
if (Object.prototype.hasOwnProperty.call(source, key)) {
target[key] = source[key];
}
}
}
return target;
};
return _extends.apply(this, arguments);
}
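// _extends({}, a, b) copies the own enumerable string-keyed properties of a
// and then b onto a fresh object, later sources winning; it is the helper
// Babel emits for JSX spread attributes and similar shallow merges.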
function _objectSpread(target) {
for (var i = 1; i < arguments.length; i++) {
var source = arguments[i] != null ? Object(arguments[i]) : {};
var ownKeys = Object.keys(source);
if (typeof Object.getOwnPropertySymbols === 'function') {
ownKeys = ownKeys.concat(Object.getOwnPropertySymbols(source).filter(function (sym) {
return Object.getOwnPropertyDescriptor(source, sym).enumerable;
}));
}
ownKeys.forEach(function (key) {
_defineProperty(target, key, source[key]);
});
}
return target;
}
function ownKeys$1(object, enumerableOnly) {
var keys = Object.keys(object);
if (Object.getOwnPropertySymbols) {
var symbols = Object.getOwnPropertySymbols(object);
if (enumerableOnly) symbols = symbols.filter(function (sym) {
return Object.getOwnPropertyDescriptor(object, sym).enumerable;
});
keys.push.apply(keys, symbols);
}
return keys;
}
function _objectSpread2(target) {
for (var i = 1; i < arguments.length; i++) {
var source = arguments[i] != null ? arguments[i] : {};
if (i % 2) {
ownKeys$1(Object(source), true).forEach(function (key) {
_defineProperty(target, key, source[key]);
});
} else if (Object.getOwnPropertyDescriptors) {
Object.defineProperties(target, Object.getOwnPropertyDescriptors(source));
} else {
ownKeys$1(Object(source)).forEach(function (key) {
Object.defineProperty(target, key, Object.getOwnPropertyDescriptor(source, key));
});
}
}
return target;
}
function _inherits(subClass, superClass) {
if (typeof superClass !== "function" && superClass !== null) {
throw new TypeError("Super expression must either be null or a function");
}
subClass.prototype = Object.create(superClass && superClass.prototype, {
constructor: {
value: subClass,
writable: true,
configurable: true
}
});
if (superClass) _setPrototypeOf(subClass, superClass);
}
function _inheritsLoose(subClass, superClass) {
subClass.prototype = Object.create(superClass.prototype);
subClass.prototype.constructor = subClass;
_setPrototypeOf(subClass, superClass);
}
function _getPrototypeOf(o) {
_getPrototypeOf = Object.setPrototypeOf ? Object.getPrototypeOf : function _getPrototypeOf(o) {
return o.__proto__ || Object.getPrototypeOf(o);
};
return _getPrototypeOf(o);
}
function _setPrototypeOf(o, p) {
_setPrototypeOf = Object.setPrototypeOf || function _setPrototypeOf(o, p) {
o.__proto__ = p;
return o;
};
return _setPrototypeOf(o, p);
}
function _isNativeReflectConstruct() {
if (typeof Reflect === "undefined" || !Reflect.construct) return false;
if (Reflect.construct.sham) return false;
if (typeof Proxy === "function") return true;
try {
Boolean.prototype.valueOf.call(Reflect.construct(Boolean, [], function () {}));
return true;
} catch (e) {
return false;
}
}
function _construct(Parent, args, Class) {
if (_isNativeReflectConstruct()) {
_construct = Reflect.construct;
} else {
_construct = function _construct(Parent, args, Class) {
var a = [null];
a.push.apply(a, args);
var Constructor = Function.bind.apply(Parent, a);
var instance = new Constructor();
if (Class) _setPrototypeOf(instance, Class.prototype);
return instance;
};
}
return _construct.apply(null, arguments);
}
function _isNativeFunction(fn) {
return Function.toString.call(fn).indexOf("[native code]") !== -1;
}
function _wrapNativeSuper(Class) {
var _cache = typeof Map === "function" ? new Map() : undefined;
_wrapNativeSuper = function _wrapNativeSuper(Class) {
if (Class === null || !_isNativeFunction(Class)) return Class;
if (typeof Class !== "function") {
throw new TypeError("Super expression must either be null or a function");
}
if (typeof _cache !== "undefined") {
if (_cache.has(Class)) return _cache.get(Class);
_cache.set(Class, Wrapper);
}
function Wrapper() {
return _construct(Class, arguments, _getPrototypeOf(this).constructor);
}
Wrapper.prototype = Object.create(Class.prototype, {
constructor: {
value: Wrapper,
enumerable: false,
writable: true,
configurable: true
}
});
return _setPrototypeOf(Wrapper, Class);
};
return _wrapNativeSuper(Class);
}
function _instanceof(left, right) {
if (right != null && typeof Symbol !== "undefined" && right[Symbol.hasInstance]) {
return !!right[Symbol.hasInstance](left);
} else {
return left instanceof right;
}
}
function _interopRequireDefault(obj) {
return obj && obj.__esModule ? obj : {
default: obj
};
}
function _getRequireWildcardCache() {
if (typeof WeakMap !== "function") return null;
var cache = new WeakMap();
_getRequireWildcardCache = function () {
return cache;
};
return cache;
}
function _interopRequireWildcard(obj) {
if (obj && obj.__esModule) {
return obj;
}
if (obj === null || typeof obj !== "object" && typeof obj !== "function") {
return {
default: obj
};
}
var cache = _getRequireWildcardCache();
if (cache && cache.has(obj)) {
return cache.get(obj);
}
var newObj = {};
var hasPropertyDescriptor = Object.defineProperty && Object.getOwnPropertyDescriptor;
for (var key in obj) {
if (Object.prototype.hasOwnProperty.call(obj, key)) {
var desc = hasPropertyDescriptor ? Object.getOwnPropertyDescriptor(obj, key) : null;
if (desc && (desc.get || desc.set)) {
Object.defineProperty(newObj, key, desc);
} else {
newObj[key] = obj[key];
}
}
}
newObj.default = obj;
if (cache) {
cache.set(obj, newObj);
}
return newObj;
}
function _newArrowCheck(innerThis, boundThis) {
if (innerThis !== boundThis) {
throw new TypeError("Cannot instantiate an arrow function");
}
}
function _objectDestructuringEmpty(obj) {
if (obj == null) throw new TypeError("Cannot destructure undefined");
}
function _objectWithoutPropertiesLoose(source, excluded) {
if (source == null) return {};
var target = {};
var sourceKeys = Object.keys(source);
var key, i;
for (i = 0; i < sourceKeys.length; i++) {
key = sourceKeys[i];
if (excluded.indexOf(key) >= 0) continue;
target[key] = source[key];
}
return target;
}
function _objectWithoutProperties(source, excluded) {
if (source == null) return {};
var target = _objectWithoutPropertiesLoose(source, excluded);
var key, i;
if (Object.getOwnPropertySymbols) {
var sourceSymbolKeys = Object.getOwnPropertySymbols(source);
for (i = 0; i < sourceSymbolKeys.length; i++) {
key = sourceSymbolKeys[i];
if (excluded.indexOf(key) >= 0) continue;
if (!Object.prototype.propertyIsEnumerable.call(source, key)) continue;
target[key] = source[key];
}
}
return target;
}
function _assertThisInitialized(self) {
if (self === void 0) {
throw new ReferenceError("this hasn't been initialised - super() hasn't been called");
}
return self;
}
function _possibleConstructorReturn(self, call) {
if (call && (typeof call === "object" || typeof call === "function")) {
return call;
}
return _assertThisInitialized(self);
}
function _createSuper(Derived) {
var hasNativeReflectConstruct = _isNativeReflectConstruct();
return function _createSuperInternal() {
var Super = _getPrototypeOf(Derived),
result;
if (hasNativeReflectConstruct) {
var NewTarget = _getPrototypeOf(this).constructor;
result = Reflect.construct(Super, arguments, NewTarget);
} else {
result = Super.apply(this, arguments);
}
return _possibleConstructorReturn(this, result);
};
}
function _superPropBase(object, property) {
while (!Object.prototype.hasOwnProperty.call(object, property)) {
object = _getPrototypeOf(object);
if (object === null) break;
}
return object;
}
function _get(target, property, receiver) {
if (typeof Reflect !== "undefined" && Reflect.get) {
_get = Reflect.get;
} else {
_get = function _get(target, property, receiver) {
var base = _superPropBase(target, property);
if (!base) return;
var desc = Object.getOwnPropertyDescriptor(base, property);
if (desc.get) {
return desc.get.call(receiver);
}
return desc.value;
};
}
return _get(target, property, receiver || target);
}
function set$4(target, property, value, receiver) {
if (typeof Reflect !== "undefined" && Reflect.set) {
set$4 = Reflect.set;
} else {
set$4 = function set(target, property, value, receiver) {
var base = _superPropBase(target, property);
var desc;
if (base) {
desc = Object.getOwnPropertyDescriptor(base, property);
if (desc.set) {
desc.set.call(receiver, value);
return true;
} else if (!desc.writable) {
return false;
}
}
desc = Object.getOwnPropertyDescriptor(receiver, property);
if (desc) {
if (!desc.writable) {
return false;
}
desc.value = value;
Object.defineProperty(receiver, property, desc);
} else {
_defineProperty(receiver, property, value);
}
return true;
};
}
return set$4(target, property, value, receiver);
}
function _set(target, property, value, receiver, isStrict) {
var s = set$4(target, property, value, receiver || target);
if (!s && isStrict) {
throw new Error('failed to set property');
}
return value;
}
function _taggedTemplateLiteral(strings, raw) {
if (!raw) {
raw = strings.slice(0);
}
return Object.freeze(Object.defineProperties(strings, {
raw: {
value: Object.freeze(raw)
}
}));
}
function _taggedTemplateLiteralLoose(strings, raw) {
if (!raw) {
raw = strings.slice(0);
}
strings.raw = raw;
return strings;
}
function _readOnlyError(name) {
throw new TypeError("\"" + name + "\" is read-only");
}
function _writeOnlyError(name) {
throw new TypeError("\"" + name + "\" is write-only");
}
function _classNameTDZError(name) {
throw new Error("Class \"" + name + "\" cannot be referenced in computed property keys.");
}
function _temporalUndefined() {}
function _tdz(name) {
throw new ReferenceError(name + " is not defined - temporal dead zone");
}
function _temporalRef(val, name) {
return val === _temporalUndefined ? _tdz(name) : val;
}
function _slicedToArray(arr, i) {
return _arrayWithHoles(arr) || _iterableToArrayLimit(arr, i) || _unsupportedIterableToArray(arr, i) || _nonIterableRest();
}
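// Roughly, `var [a, b] = pair;` compiles to
// var _pair = _slicedToArray(pair, 2), a = _pair[0], b = _pair[1];
// the second argument caps how many elements _iterableToArrayLimit pulls.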
function _slicedToArrayLoose(arr, i) {
return _arrayWithHoles(arr) || _iterableToArrayLimitLoose(arr, i) || _unsupportedIterableToArray(arr, i) || _nonIterableRest();
}
function _toArray(arr) {
return _arrayWithHoles(arr) || _iterableToArray(arr) || _unsupportedIterableToArray(arr) || _nonIterableRest();
}
function _toConsumableArray(arr) {
return _arrayWithoutHoles(arr) || _iterableToArray(arr) || _unsupportedIterableToArray(arr) || _nonIterableSpread();
}
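// Roughly, array spread such as `[...xs, 1]` compiles to
// [].concat(_toConsumableArray(xs), [1]).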
function _arrayWithoutHoles(arr) {
if (Array.isArray(arr)) return _arrayLikeToArray(arr);
}
function _arrayWithHoles(arr) {
if (Array.isArray(arr)) return arr;
}
function _maybeArrayLike(next, arr, i) {
if (arr && !Array.isArray(arr) && typeof arr.length === "number") {
var len = arr.length;
return _arrayLikeToArray(arr, i !== void 0 && i < len ? i : len);
}
return next(arr, i);
}
function _iterableToArray(iter) {
if (typeof Symbol !== "undefined" && Symbol.iterator in Object(iter)) return Array.from(iter);
}
function _iterableToArrayLimit(arr, i) {
if (typeof Symbol === "undefined" || !(Symbol.iterator in Object(arr))) return;
var _arr = [];
var _n = true;
var _d = false;
var _e = undefined;
try {
for (var _i = arr[Symbol.iterator](), _s; !(_n = (_s = _i.next()).done); _n = true) {
_arr.push(_s.value);
if (i && _arr.length === i) break;
}
} catch (err) {
_d = true;
_e = err;
} finally {
try {
if (!_n && _i["return"] != null) _i["return"]();
} finally {
if (_d) throw _e;
}
}
return _arr;
}
function _iterableToArrayLimitLoose(arr, i) {
if (typeof Symbol === "undefined" || !(Symbol.iterator in Object(arr))) return;
var _arr = [];
for (var _iterator = arr[Symbol.iterator](), _step; !(_step = _iterator.next()).done;) {
_arr.push(_step.value);
if (i && _arr.length === i) break;
}
return _arr;
}
function _unsupportedIterableToArray(o, minLen) {
if (!o) return;
if (typeof o === "string") return _arrayLikeToArray(o, minLen);
var n = Object.prototype.toString.call(o).slice(8, -1);
if (n === "Object" && o.constructor) n = o.constructor.name;
if (n === "Map" || n === "Set") return Array.from(o);
if (n === "Arguments" || /^(?:Ui|I)nt(?:8|16|32)(?:Clamped)?Array$/.test(n)) return _arrayLikeToArray(o, minLen);
}
function _arrayLikeToArray(arr, len) {
if (len == null || len > arr.length) len = arr.length;
for (var i = 0, arr2 = new Array(len); i < len; i++) arr2[i] = arr[i];
return arr2;
}
function _nonIterableSpread() {
throw new TypeError("Invalid attempt to spread non-iterable instance.\nIn order to be iterable, non-array objects must have a [Symbol.iterator]() method.");
}
function _nonIterableRest() {
throw new TypeError("Invalid attempt to destructure non-iterable instance.\nIn order to be iterable, non-array objects must have a [Symbol.iterator]() method.");
}
function _createForOfIteratorHelper(o, allowArrayLike) {
var it;
if (typeof Symbol === "undefined" || o[Symbol.iterator] == null) {
if (Array.isArray(o) || (it = _unsupportedIterableToArray(o)) || allowArrayLike && o && typeof o.length === "number") {
if (it) o = it;
var i = 0;
var F = function () {};
return {
s: F,
n: function () {
if (i >= o.length) return {
done: true
};
return {
done: false,
value: o[i++]
};
},
e: function (e) {
throw e;
},
f: F
};
}
throw new TypeError("Invalid attempt to iterate non-iterable instance.\nIn order to be iterable, non-array objects must have a [Symbol.iterator]() method.");
}
var normalCompletion = true,
didErr = false,
err;
return {
s: function () {
it = o[Symbol.iterator]();
},
n: function () {
var step = it.next();
normalCompletion = step.done;
return step;
},
e: function (e) {
didErr = true;
err = e;
},
f: function () {
try {
if (!normalCompletion && it.return != null) it.return();
} finally {
if (didErr) throw err;
}
}
};
}
function _createForOfIteratorHelperLoose(o, allowArrayLike) {
var it;
if (typeof Symbol === "undefined" || o[Symbol.iterator] == null) {
if (Array.isArray(o) || (it = _unsupportedIterableToArray(o)) || allowArrayLike && o && typeof o.length === "number") {
if (it) o = it;
var i = 0;
return function () {
if (i >= o.length) return {
done: true
};
return {
done: false,
value: o[i++]
};
};
}
throw new TypeError("Invalid attempt to iterate non-iterable instance.\nIn order to be iterable, non-array objects must have a [Symbol.iterator]() method.");
}
it = o[Symbol.iterator]();
return it.next.bind(it);
}
function _skipFirstGeneratorNext(fn) {
return function () {
var it = fn.apply(this, arguments);
it.next();
return it;
};
}
function _toPrimitive(input, hint) {
if (typeof input !== "object" || input === null) return input;
var prim = input[Symbol.toPrimitive];
if (prim !== undefined) {
var res = prim.call(input, hint || "default");
if (typeof res !== "object") return res;
throw new TypeError("@@toPrimitive must return a primitive value.");
}
return (hint === "string" ? String : Number)(input);
}
function _toPropertyKey(arg) {
var key = _toPrimitive(arg, "string");
return typeof key === "symbol" ? key : String(key);
}
function _initializerWarningHelper(descriptor, context) {
throw new Error('Decorating class property failed. Please ensure that ' + 'proposal-class-properties is enabled and runs after the decorators transform.');
}
function _initializerDefineProperty(target, property, descriptor, context) {
if (!descriptor) return;
Object.defineProperty(target, property, {
enumerable: descriptor.enumerable,
configurable: descriptor.configurable,
writable: descriptor.writable,
value: descriptor.initializer ? descriptor.initializer.call(context) : void 0
});
}
function _applyDecoratedDescriptor(target, property, decorators, descriptor, context) {
var desc = {};
Object.keys(descriptor).forEach(function (key) {
desc[key] = descriptor[key];
});
desc.enumerable = !!desc.enumerable;
desc.configurable = !!desc.configurable;
if ('value' in desc || desc.initializer) {
desc.writable = true;
}
desc = decorators.slice().reverse().reduce(function (desc, decorator) {
return decorator(target, property, desc) || desc;
}, desc);
if (context && desc.initializer !== void 0) {
desc.value = desc.initializer ? desc.initializer.call(context) : void 0;
desc.initializer = undefined;
}
if (desc.initializer === void 0) {
Object.defineProperty(target, property, desc);
desc = null;
}
return desc;
}
var id$2 = 0;
function _classPrivateFieldLooseKey(name) {
return "__private_" + id$2++ + "_" + name;
}
function _classPrivateFieldLooseBase(receiver, privateKey) {
if (!Object.prototype.hasOwnProperty.call(receiver, privateKey)) {
throw new TypeError("attempted to use private field on non-instance");
}
return receiver;
}
function _classPrivateFieldGet(receiver, privateMap) {
var descriptor = _classExtractFieldDescriptor(receiver, privateMap, "get");
return _classApplyDescriptorGet(receiver, descriptor);
}
function _classPrivateFieldSet(receiver, privateMap, value) {
var descriptor = _classExtractFieldDescriptor(receiver, privateMap, "set");
_classApplyDescriptorSet(receiver, descriptor, value);
return value;
}
function _classPrivateFieldDestructureSet(receiver, privateMap) {
var descriptor = _classExtractFieldDescriptor(receiver, privateMap, "set");
return _classApplyDescriptorDestructureSet(receiver, descriptor);
}
function _classExtractFieldDescriptor(receiver, privateMap, action) {
if (!privateMap.has(receiver)) {
throw new TypeError("attempted to " + action + " private field on non-instance");
}
return privateMap.get(receiver);
}
function _classStaticPrivateFieldSpecGet(receiver, classConstructor, descriptor) {
_classCheckPrivateStaticAccess(receiver, classConstructor);
_classCheckPrivateStaticFieldDescriptor(descriptor, "get");
return _classApplyDescriptorGet(receiver, descriptor);
}
function _classStaticPrivateFieldSpecSet(receiver, classConstructor, descriptor, value) {
_classCheckPrivateStaticAccess(receiver, classConstructor);
_classCheckPrivateStaticFieldDescriptor(descriptor, "set");
_classApplyDescriptorSet(receiver, descriptor, value);
return value;
}
function _classStaticPrivateMethodGet(receiver, classConstructor, method) {
_classCheckPrivateStaticAccess(receiver, classConstructor);
return method;
}
function _classStaticPrivateMethodSet() {
throw new TypeError("attempted to set read only static private field");
}
function _classApplyDescriptorGet(receiver, descriptor) {
if (descriptor.get) {
return descriptor.get.call(receiver);
}
return descriptor.value;
}
function _classApplyDescriptorSet(receiver, descriptor, value) {
if (descriptor.set) {
descriptor.set.call(receiver, value);
} else {
if (!descriptor.writable) {
throw new TypeError("attempted to set read only private field");
}
descriptor.value = value;
}
}
function _classApplyDescriptorDestructureSet(receiver, descriptor) {
if (descriptor.set) {
if (!("__destrObj" in descriptor)) {
descriptor.__destrObj = {
set value(v) {
descriptor.set.call(receiver, v);
}
};
}
return descriptor.__destrObj;
} else {
if (!descriptor.writable) {
throw new TypeError("attempted to set read only private field");
}
return descriptor;
}
}
function _classStaticPrivateFieldDestructureSet(receiver, classConstructor, descriptor) {
_classCheckPrivateStaticAccess(receiver, classConstructor);
_classCheckPrivateStaticFieldDescriptor(descriptor, "set");
return _classApplyDescriptorDestructureSet(receiver, descriptor);
}
function _classCheckPrivateStaticAccess(receiver, classConstructor) {
if (receiver !== classConstructor) {
throw new TypeError("Private static access of wrong provenance");
}
}
function _classCheckPrivateStaticFieldDescriptor(descriptor, action) {
if (descriptor === undefined) {
throw new TypeError("attempted to " + action + " private static field before its declaration");
}
}
function _decorate(decorators, factory, superClass, mixins) {
var api = _getDecoratorsApi();
if (mixins) {
for (var i = 0; i < mixins.length; i++) {
api = mixins[i](api);
}
}
var r = factory(function initialize(O) {
api.initializeInstanceElements(O, decorated.elements);
}, superClass);
var decorated = api.decorateClass(_coalesceClassElements(r.d.map(_createElementDescriptor)), decorators);
api.initializeClassElements(r.F, decorated.elements);
return api.runClassFinishers(r.F, decorated.finishers);
}
function _getDecoratorsApi() {
_getDecoratorsApi = function () {
return api;
};
var api = {
elementsDefinitionOrder: [["method"], ["field"]],
initializeInstanceElements: function (O, elements) {
["method", "field"].forEach(function (kind) {
elements.forEach(function (element) {
if (element.kind === kind && element.placement === "own") {
this.defineClassElement(O, element);
}
}, this);
}, this);
},
initializeClassElements: function (F, elements) {
var proto = F.prototype;
["method", "field"].forEach(function (kind) {
elements.forEach(function (element) {
var placement = element.placement;
if (element.kind === kind && (placement === "static" || placement === "prototype")) {
var receiver = placement === "static" ? F : proto;
this.defineClassElement(receiver, element);
}
}, this);
}, this);
},
defineClassElement: function (receiver, element) {
var descriptor = element.descriptor;
if (element.kind === "field") {
var initializer = element.initializer;
descriptor = {
enumerable: descriptor.enumerable,
writable: descriptor.writable,
configurable: descriptor.configurable,
value: initializer === void 0 ? void 0 : initializer.call(receiver)
};
}
Object.defineProperty(receiver, element.key, descriptor);
},
decorateClass: function (elements, decorators) {
var newElements = [];
var finishers = [];
var placements = {
static: [],
prototype: [],
own: []
};
elements.forEach(function (element) {
this.addElementPlacement(element, placements);
}, this);
elements.forEach(function (element) {
if (!_hasDecorators(element)) return newElements.push(element);
var elementFinishersExtras = this.decorateElement(element, placements);
newElements.push(elementFinishersExtras.element);
newElements.push.apply(newElements, elementFinishersExtras.extras);
finishers.push.apply(finishers, elementFinishersExtras.finishers);
}, this);
if (!decorators) {
return {
elements: newElements,
finishers: finishers
};
}
var result = this.decorateConstructor(newElements, decorators);
finishers.push.apply(finishers, result.finishers);
result.finishers = finishers;
return result;
},
addElementPlacement: function (element, placements, silent) {
var keys = placements[element.placement];
if (!silent && keys.indexOf(element.key) !== -1) {
throw new TypeError("Duplicated element (" + element.key + ")");
}
keys.push(element.key);
},
decorateElement: function (element, placements) {
var extras = [];
var finishers = [];
for (var decorators = element.decorators, i = decorators.length - 1; i >= 0; i--) {
var keys = placements[element.placement];
keys.splice(keys.indexOf(element.key), 1);
var elementObject = this.fromElementDescriptor(element);
var elementFinisherExtras = this.toElementFinisherExtras((0, decorators[i])(elementObject) || elementObject);
element = elementFinisherExtras.element;
this.addElementPlacement(element, placements);
if (elementFinisherExtras.finisher) {
finishers.push(elementFinisherExtras.finisher);
}
var newExtras = elementFinisherExtras.extras;
if (newExtras) {
for (var j = 0; j < newExtras.length; j++) {
this.addElementPlacement(newExtras[j], placements);
}
extras.push.apply(extras, newExtras);
}
}
return {
element: element,
finishers: finishers,
extras: extras
};
},
decorateConstructor: function (elements, decorators) {
var finishers = [];
for (var i = decorators.length - 1; i >= 0; i--) {
var obj = this.fromClassDescriptor(elements);
var elementsAndFinisher = this.toClassDescriptor((0, decorators[i])(obj) || obj);
if (elementsAndFinisher.finisher !== undefined) {
finishers.push(elementsAndFinisher.finisher);
}
if (elementsAndFinisher.elements !== undefined) {
elements = elementsAndFinisher.elements;
for (var j = 0; j < elements.length - 1; j++) {
for (var k = j + 1; k < elements.length; k++) {
if (elements[j].key === elements[k].key && elements[j].placement === elements[k].placement) {
throw new TypeError("Duplicated element (" + elements[j].key + ")");
}
}
}
}
}
return {
elements: elements,
finishers: finishers
};
},
fromElementDescriptor: function (element) {
var obj = {
kind: element.kind,
key: element.key,
placement: element.placement,
descriptor: element.descriptor
};
var desc = {
value: "Descriptor",
configurable: true
};
Object.defineProperty(obj, Symbol.toStringTag, desc);
if (element.kind === "field") obj.initializer = element.initializer;
return obj;
},
toElementDescriptors: function (elementObjects) {
if (elementObjects === undefined) return;
return _toArray(elementObjects).map(function (elementObject) {
var element = this.toElementDescriptor(elementObject);
this.disallowProperty(elementObject, "finisher", "An element descriptor");
this.disallowProperty(elementObject, "extras", "An element descriptor");
return element;
}, this);
},
toElementDescriptor: function (elementObject) {
var kind = String(elementObject.kind);
if (kind !== "method" && kind !== "field") {
throw new TypeError('An element descriptor\'s .kind property must be either "method" or' + ' "field", but a decorator created an element descriptor with' + ' .kind "' + kind + '"');
}
var key = _toPropertyKey(elementObject.key);
var placement = String(elementObject.placement);
if (placement !== "static" && placement !== "prototype" && placement !== "own") {
throw new TypeError('An element descriptor\'s .placement property must be one of "static",' + ' "prototype" or "own", but a decorator created an element descriptor' + ' with .placement "' + placement + '"');
}
var descriptor = elementObject.descriptor;
this.disallowProperty(elementObject, "elements", "An element descriptor");
var element = {
kind: kind,
key: key,
placement: placement,
descriptor: Object.assign({}, descriptor)
};
if (kind !== "field") {
this.disallowProperty(elementObject, "initializer", "A method descriptor");
} else {
this.disallowProperty(descriptor, "get", "The property descriptor of a field descriptor");
this.disallowProperty(descriptor, "set", "The property descriptor of a field descriptor");
this.disallowProperty(descriptor, "value", "The property descriptor of a field descriptor");
element.initializer = elementObject.initializer;
}
return element;
},
toElementFinisherExtras: function (elementObject) {
var element = this.toElementDescriptor(elementObject);
var finisher = _optionalCallableProperty(elementObject, "finisher");
var extras = this.toElementDescriptors(elementObject.extras);
return {
element: element,
finisher: finisher,
extras: extras
};
},
fromClassDescriptor: function (elements) {
var obj = {
kind: "class",
elements: elements.map(this.fromElementDescriptor, this)
};
var desc = {
value: "Descriptor",
configurable: true
};
Object.defineProperty(obj, Symbol.toStringTag, desc);
return obj;
},
toClassDescriptor: function (obj) {
var kind = String(obj.kind);
if (kind !== "class") {
throw new TypeError('A class descriptor\'s .kind property must be "class", but a decorator' + ' created a class descriptor with .kind "' + kind + '"');
}
this.disallowProperty(obj, "key", "A class descriptor");
this.disallowProperty(obj, "placement", "A class descriptor");
this.disallowProperty(obj, "descriptor", "A class descriptor");
this.disallowProperty(obj, "initializer", "A class descriptor");
this.disallowProperty(obj, "extras", "A class descriptor");
var finisher = _optionalCallableProperty(obj, "finisher");
var elements = this.toElementDescriptors(obj.elements);
return {
elements: elements,
finisher: finisher
};
},
runClassFinishers: function (constructor, finishers) {
for (var i = 0; i < finishers.length; i++) {
var newConstructor = (0, finishers[i])(constructor);
if (newConstructor !== undefined) {
if (typeof newConstructor !== "function") {
throw new TypeError("Finishers must return a constructor.");
}
constructor = newConstructor;
}
}
return constructor;
},
disallowProperty: function (obj, name, objectType) {
if (obj[name] !== undefined) {
throw new TypeError(objectType + " can't have a ." + name + " property.");
}
}
};
return api;
}
function _createElementDescriptor(def) {
var key = _toPropertyKey(def.key);
var descriptor;
if (def.kind === "method") {
descriptor = {
value: def.value,
writable: true,
configurable: true,
enumerable: false
};
} else if (def.kind === "get") {
descriptor = {
get: def.value,
configurable: true,
enumerable: false
};
} else if (def.kind === "set") {
descriptor = {
set: def.value,
configurable: true,
enumerable: false
};
} else if (def.kind === "field") {
descriptor = {
configurable: true,
writable: true,
enumerable: true
};
}
var element = {
kind: def.kind === "field" ? "field" : "method",
key: key,
placement: def.static ? "static" : def.kind === "field" ? "own" : "prototype",
descriptor: descriptor
};
if (def.decorators) element.decorators = def.decorators;
if (def.kind === "field") element.initializer = def.value;
return element;
}
function _coalesceGetterSetter(element, other) {
if (element.descriptor.get !== undefined) {
other.descriptor.get = element.descriptor.get;
} else {
other.descriptor.set = element.descriptor.set;
}
}
function _coalesceClassElements(elements) {
var newElements = [];
var isSameElement = function (other) {
return other.kind === "method" && other.key === element.key && other.placement === element.placement;
};
for (var i = 0; i < elements.length; i++) {
var element = elements[i];
var other;
if (element.kind === "method" && (other = newElements.find(isSameElement))) {
if (_isDataDescriptor(element.descriptor) || _isDataDescriptor(other.descriptor)) {
if (_hasDecorators(element) || _hasDecorators(other)) {
throw new ReferenceError("Duplicated methods (" + element.key + ") can't be decorated.");
}
other.descriptor = element.descriptor;
} else {
if (_hasDecorators(element)) {
if (_hasDecorators(other)) {
throw new ReferenceError("Decorators can't be placed on different accessors with for " + "the same property (" + element.key + ").");
}
other.decorators = element.decorators;
}
_coalesceGetterSetter(element, other);
}
} else {
newElements.push(element);
}
}
return newElements;
}
function _hasDecorators(element) {
return element.decorators && element.decorators.length;
}
function _isDataDescriptor(desc) {
return desc !== undefined && !(desc.value === undefined && desc.writable === undefined);
}
function _optionalCallableProperty(obj, name) {
var value = obj[name];
if (value !== undefined && typeof value !== "function") {
throw new TypeError("Expected '" + name + "' to be a function");
}
return value;
}
function _classPrivateMethodGet(receiver, privateSet, fn) {
if (!privateSet.has(receiver)) {
throw new TypeError("attempted to get private field on non-instance");
}
return fn;
}
function _classPrivateMethodSet() {
throw new TypeError("attempted to reassign private method");
}
function _wrapRegExp(re, groups) {
_wrapRegExp = function (re, groups) {
return new BabelRegExp(re, undefined, groups);
};
var _RegExp = _wrapNativeSuper(RegExp);
var _super = RegExp.prototype;
var _groups = new WeakMap();
function BabelRegExp(re, flags, groups) {
var _this = _RegExp.call(this, re, flags);
_groups.set(_this, groups || _groups.get(re));
return _this;
}
_inherits(BabelRegExp, _RegExp);
BabelRegExp.prototype.exec = function (str) {
var result = _super.exec.call(this, str);
if (result) result.groups = buildGroups(result, this);
return result;
};
BabelRegExp.prototype[Symbol.replace] = function (str, substitution) {
if (typeof substitution === "string") {
var groups = _groups.get(this);
return _super[Symbol.replace].call(this, str, substitution.replace(/\$<([^>]+)>/g, function (_, name) {
return "$" + groups[name];
}));
} else if (typeof substitution === "function") {
var _this = this;
return _super[Symbol.replace].call(this, str, function () {
var args = [];
args.push.apply(args, arguments);
if (typeof args[args.length - 1] !== "object") {
args.push(buildGroups(args, _this));
}
return substitution.apply(this, args);
});
} else {
return _super[Symbol.replace].call(this, str, substitution);
}
};
function buildGroups(result, re) {
var g = _groups.get(re);
return Object.keys(g).reduce(function (groups, name) {
groups[name] = result[g[name]];
return groups;
}, Object.create(null));
}
return _wrapRegExp.apply(this, arguments);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EPSILON_FLOAT32 = 1e-7;
var EPSILON_FLOAT16 = 1e-4;
/** Convenient class for storing tensor-related data. */
var DataStorage = /*#__PURE__*/function () {
function DataStorage(backend, dataMover) {
this.backend = backend;
this.dataMover = dataMover;
this.data = new WeakMap();
this.dataIdsCount = 0;
}
var _proto = DataStorage.prototype;
_proto.get = function get(dataId) {
if (!this.data.has(dataId)) {
this.dataMover.moveData(this.backend, dataId);
}
return this.data.get(dataId);
};
_proto.set = function set(dataId, value) {
this.dataIdsCount++;
this.data.set(dataId, value);
};
_proto.has = function has(dataId) {
return this.data.has(dataId);
};
_proto.delete = function _delete(dataId) {
this.dataIdsCount--;
return this.data.delete(dataId);
};
_proto.numDataIds = function numDataIds() {
return this.dataIdsCount;
};
return DataStorage;
}();
/**
* The interface that defines the kernels that should be implemented when
* adding a new backend. New backends don't need to implement every one of the
 * methods; this can be done gradually (throw an error for unimplemented
* methods).
*/
var KernelBackend = /*#__PURE__*/function () {
function KernelBackend() {}
var _proto2 = KernelBackend.prototype;
_proto2.refCount = function refCount(dataId) {
return notYetImplemented('refCount');
};
_proto2.incRef = function incRef(dataId) {
return notYetImplemented('incRef');
};
_proto2.timerAvailable = function timerAvailable() {
return true;
};
_proto2.time = function time(f) {
return notYetImplemented('time');
};
_proto2.read = function read(dataId) {
return notYetImplemented('read');
};
_proto2.readSync = function readSync(dataId) {
return notYetImplemented('readSync');
};
_proto2.numDataIds = function numDataIds() {
return notYetImplemented('numDataIds');
};
_proto2.disposeData = function disposeData(dataId, force) {
return notYetImplemented('disposeData');
};
_proto2.write = function write(values, shape, dtype) {
return notYetImplemented('write');
};
_proto2.move = function move(dataId, values, shape, dtype, refCount) {
return notYetImplemented('move');
};
_proto2.memory = function memory() {
return notYetImplemented('memory');
}
/** Returns the highest precision for floats in bits (e.g. 16 or 32) */
;
_proto2.floatPrecision = function floatPrecision() {
return notYetImplemented('floatPrecision');
}
/** Returns the epsilon for the backend's float precision (1e-7 for float32, 1e-4 for float16). */
;
_proto2.epsilon = function epsilon() {
return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
};
_proto2.dispose = function dispose() {
return notYetImplemented('dispose');
};
return KernelBackend;
}();
function notYetImplemented(kernelName) {
throw new Error("'" + kernelName + "' not yet implemented or not found in the registry. " + "This kernel may not be supported by the tfjs backend you have chosen");
}
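/**
 * A minimal sketch of how a backend can be grown gradually, assuming the base
 * class is reachable as `tf.KernelBackend`; the class name and method bodies
 * below are hypothetical.
 *
 * ```js
 * class MinimalBackend extends tf.KernelBackend {
 *   floatPrecision() { return 32; } // report float32 precision
 *   numDataIds() { return 0; }      // no tensors tracked yet
 *   // read(), write(), dispose(), ... can be added one at a time; until
 *   // then they throw via notYetImplemented().
 * }
 * ```
 */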
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Shuffles the array in-place using Fisher-Yates algorithm.
*
* ```js
* const a = [1, 2, 3, 4, 5];
* tf.util.shuffle(a);
* console.log(a);
* ```
*
* @param array The array to shuffle in-place.
*
* @doc {heading: 'Util', namespace: 'util'}
*/
// tslint:disable-next-line:no-any
function shuffle(array) {
var counter = array.length;
var index = 0; // While there are elements in the array
while (counter > 0) {
// Pick a random index
index = Math.random() * counter | 0; // Decrease counter by 1
counter--; // And swap the last element with it
swap(array, counter, index);
}
}
/**
* Shuffles two arrays in-place the same way using Fisher-Yates algorithm.
*
* ```js
* const a = [1,2,3,4,5];
* const b = [11,22,33,44,55];
* tf.util.shuffleCombo(a, b);
* console.log(a, b);
* ```
*
* @param array The first array to shuffle in-place.
* @param array2 The second array to shuffle in-place with the same permutation
* as the first array.
*
* @doc {heading: 'Util', namespace: 'util'}
*/
function shuffleCombo( // tslint:disable-next-line:no-any
array, // tslint:disable-next-line:no-any
array2) {
if (array.length !== array2.length) {
throw new Error("Array sizes must match to be shuffled together " + ("First array length was " + array.length) + ("Second array length was " + array2.length));
}
var counter = array.length;
var index = 0; // While there are elements in the array
while (counter > 0) {
// Pick a random index
index = Math.random() * counter | 0; // Decrease counter by 1
counter--; // And swap the last element of each array with it
swap(array, counter, index);
swap(array2, counter, index);
}
}
/** Clamps a value to a specified range. */
function clamp(min, x, max) {
return Math.max(min, Math.min(x, max));
}
function nearestLargerEven(val) {
return val % 2 === 0 ? val : val + 1;
}
function swap(object, left, right) {
var temp = object[left];
object[left] = object[right];
object[right] = temp;
}
function sum(arr) {
var sum = 0;
for (var i = 0; i < arr.length; i++) {
sum += arr[i];
}
return sum;
}
/**
* Returns a sample from a uniform [a, b) distribution.
*
* @param a The minimum support (inclusive).
* @param b The maximum support (exclusive).
* @return A pseudorandom number on the half-open interval [a,b).
*/
function randUniform(a, b) {
var r = Math.random();
return b * r + (1 - r) * a;
}
/** Returns the squared Euclidean distance between two vectors. */
function distSquared(a, b) {
var result = 0;
for (var i = 0; i < a.length; i++) {
var diff = Number(a[i]) - Number(b[i]);
result += diff * diff;
}
return result;
}
/**
* Asserts that the expression is true. Otherwise throws an error with the
* provided message.
*
* ```js
* const x = 2;
* tf.util.assert(x === 2, 'x is not 2');
* ```
*
* @param expr The expression to assert (as a boolean).
* @param msg A function that returns the message to report when throwing an
* error. We use a function for performance reasons.
*
* @doc {heading: 'Util', namespace: 'util'}
*/
function assert(expr, msg) {
if (!expr) {
throw new Error(typeof msg === 'string' ? msg : msg());
}
}
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) {
if (errorMessagePrefix === void 0) {
errorMessagePrefix = '';
}
assert(arraysEqual(shapeA, shapeB), function () {
return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match");
});
}
function assertNonNull(a) {
assert(a != null, function () {
return "The input to the tensor constructor must be a non-null value.";
});
} // NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
* Flattens an arbitrarily nested array.
*
* ```js
* const a = [[1, 2], [3, 4], [5, [6, [7]]]];
* const flat = tf.util.flatten(a);
* console.log(flat);
* ```
*
* @param arr The nested array to flatten.
* @param result The destination array which holds the elements.
* @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
* to false.
*
* @doc {heading: 'Util', namespace: 'util'}
*/
function flatten(arr, result, skipTypedArray) {
if (result === void 0) {
result = [];
}
if (skipTypedArray === void 0) {
skipTypedArray = false;
}
if (result == null) {
result = [];
}
if (Array.isArray(arr) || isTypedArray$1(arr) && !skipTypedArray) {
for (var i = 0; i < arr.length; ++i) {
flatten(arr[i], result, skipTypedArray);
}
} else {
result.push(arr);
}
return result;
}
/**
* Returns the size (number of elements) of the tensor given its shape.
*
* ```js
* const shape = [3, 4, 2];
* const size = tf.util.sizeFromShape(shape);
* console.log(size);
* ```
*
* @doc {heading: 'Util', namespace: 'util'}
*/
function sizeFromShape(shape) {
if (shape.length === 0) {
// Scalar.
return 1;
}
var size = shape[0];
for (var i = 1; i < shape.length; i++) {
size *= shape[i];
}
return size;
}
function isScalarShape(shape) {
return shape.length === 0;
}
function arraysEqual(n1, n2) {
if (n1 === n2) {
return true;
}
if (n1 == null || n2 == null) {
return false;
}
if (n1.length !== n2.length) {
return false;
}
for (var i = 0; i < n1.length; i++) {
if (n1[i] !== n2[i]) {
return false;
}
}
return true;
}
function isInt(a) {
return a % 1 === 0;
}
function tanh(x) {
// tslint:disable-next-line:no-any
if (Math.tanh != null) {
// tslint:disable-next-line:no-any
return Math.tanh(x);
}
if (x === Infinity) {
return 1;
} else if (x === -Infinity) {
return -1;
} else {
var e2x = Math.exp(2 * x);
return (e2x - 1) / (e2x + 1);
}
}
function sizeToSquarishShape(size) {
var width = Math.ceil(Math.sqrt(size));
return [width, Math.ceil(size / width)];
}
/**
 * Creates a new array with randomized indices up to a given quantity.
*
* ```js
* const randomTen = tf.util.createShuffledIndices(10);
* console.log(randomTen);
* ```
*
 * @param n Quantity of shuffled indices to create.
*
* @doc {heading: 'Util', namespace: 'util'}
*/
function createShuffledIndices(n) {
var shuffledIndices = new Uint32Array(n);
for (var i = 0; i < n; ++i) {
shuffledIndices[i] = i;
}
shuffle(shuffledIndices);
return shuffledIndices;
}
function rightPad(a, size) {
if (size <= a.length) {
return a;
}
return a + ' '.repeat(size - a.length);
}
function repeatedTry(checkFn, delayFn, maxCounter) {
if (delayFn === void 0) {
delayFn = function delayFn(counter) {
return 0;
};
}
return new Promise(function (resolve, reject) {
var tryCount = 0;
var tryFn = function tryFn() {
if (checkFn()) {
resolve();
return;
}
tryCount++;
var nextBackoff = delayFn(tryCount);
if (maxCounter != null && tryCount >= maxCounter) {
reject();
return;
}
setTimeout(tryFn, nextBackoff);
};
tryFn();
});
}
/**
* Given the full size of the array and a shape that may contain -1 as the
* implicit dimension, returns the inferred shape where -1 is replaced.
* E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3].
*
* @param shape The shape, which may contain -1 in some dimension.
* @param size The full size (number of elements) of the array.
* @return The inferred shape where -1 is replaced with the inferred size.
*/
function inferFromImplicitShape(shape, size) {
var shapeProd = 1;
var implicitIdx = -1;
for (var i = 0; i < shape.length; ++i) {
if (shape[i] >= 0) {
shapeProd *= shape[i];
} else if (shape[i] === -1) {
if (implicitIdx !== -1) {
throw Error("Shapes can only have 1 implicit size. " + ("Found -1 at dim " + implicitIdx + " and dim " + i));
}
implicitIdx = i;
} else if (shape[i] < 0) {
throw Error("Shapes can not be < 0. Found " + shape[i] + " at dim " + i);
}
}
if (implicitIdx === -1) {
if (size > 0 && size !== shapeProd) {
throw Error("Size(" + size + ") must match the product of shape " + shape);
}
return shape;
}
if (shapeProd === 0) {
throw Error("Cannot infer the missing size in [" + shape + "] when " + "there are 0 elements");
}
if (size % shapeProd !== 0) {
throw Error("The implicit shape can't be a fractional number. " + ("Got " + size + " / " + shapeProd));
}
var newShape = shape.slice();
newShape[implicitIdx] = size / shapeProd;
return newShape;
}
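/**
 * Example of the implicit-shape inference above: with 24 elements, the -1
 * dimension of [2, -1, 3] resolves to 24 / (2 * 3) = 4.
 *
 * ```js
 * const inferred = inferFromImplicitShape([2, -1, 3], 24);
 * console.log(inferred); // [2, 4, 3]
 * ```
 */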
function parseAxisParam(axis, shape) {
var rank = shape.length; // Normalize input
axis = axis == null ? shape.map(function (s, i) {
return i;
}) : [].concat(axis); // Check for valid range
assert(axis.every(function (ax) {
return ax >= -rank && ax < rank;
}), function () {
return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " + ("got axis " + axis);
}); // Check for only integers
assert(axis.every(function (ax) {
return isInt(ax);
}), function () {
return "All values in axis param must be integers but " + ("got axis " + axis);
}); // Handle negative axis.
return axis.map(function (a) {
return a < 0 ? rank + a : a;
});
}
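/**
 * Example: negative axes are wrapped by the rank, and a null axis expands to
 * all dimensions.
 *
 * ```js
 * console.log(parseAxisParam(-1, [2, 3, 4])); // [2]
 * console.log(parseAxisParam(null, [2, 3])); // [0, 1]
 * ```
 */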
/** Reduces the shape by removing all dimensions of shape 1. */
function squeezeShape(shape, axis) {
var newShape = [];
var keptDims = [];
var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
var axes = axis == null || isEmptyArray ? null : parseAxisParam(axis, shape).sort();
var j = 0;
for (var i = 0; i < shape.length; ++i) {
if (axes != null) {
if (axes[j] === i && shape[i] !== 1) {
throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1");
}
if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
newShape.push(shape[i]);
keptDims.push(i);
}
if (axes[j] <= i) {
j++;
}
}
if (shape[i] !== 1) {
newShape.push(shape[i]);
keptDims.push(i);
}
}
return {
newShape: newShape,
keptDims: keptDims
};
}
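/**
 * Example: squeezing removes size-1 dimensions and reports which original
 * dimensions were kept.
 *
 * ```js
 * const {newShape, keptDims} = squeezeShape([1, 3, 1, 4]);
 * console.log(newShape); // [3, 4]
 * console.log(keptDims); // [1, 3]
 * ```
 */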
function getTypedArrayFromDType(dtype, size) {
var values = null;
if (dtype == null || dtype === 'float32') {
values = new Float32Array(size);
} else if (dtype === 'int32') {
values = new Int32Array(size);
} else if (dtype === 'bool') {
values = new Uint8Array(size);
} else {
throw new Error("Unknown data type " + dtype);
}
return values;
}
function getArrayFromDType(dtype, size) {
var values = null;
if (dtype == null || dtype === 'float32') {
values = new Float32Array(size);
} else if (dtype === 'int32') {
values = new Int32Array(size);
} else if (dtype === 'bool') {
values = new Uint8Array(size);
} else if (dtype === 'string') {
values = new Array(size);
} else {
throw new Error("Unknown data type " + dtype);
}
return values;
}
function checkConversionForErrors(vals, dtype) {
for (var i = 0; i < vals.length; i++) {
var num = vals[i];
if (isNaN(num) || !isFinite(num)) {
throw Error("A tensor of type " + dtype + " being uploaded contains " + num + ".");
}
}
}
/** Returns true if the dtype is valid. */
function isValidDtype(dtype) {
return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || dtype === 'int32' || dtype === 'string';
}
/**
* Returns true if the new type can't encode the old type without loss of
* precision.
*/
function hasEncodingLoss(oldType, newType) {
if (newType === 'complex64') {
return false;
}
if (newType === 'float32' && oldType !== 'complex64') {
return false;
}
if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') {
return false;
}
if (newType === 'bool' && oldType === 'bool') {
return false;
}
return true;
}
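/**
 * Example: casting float32 down to int32 loses precision, while casting
 * int32 up to float32 does not.
 *
 * ```js
 * console.log(hasEncodingLoss('float32', 'int32')); // true
 * console.log(hasEncodingLoss('int32', 'float32')); // false
 * ```
 */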
function isTypedArray$1(a) {
return a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array;
}
function bytesPerElement(dtype) {
if (dtype === 'float32' || dtype === 'int32') {
return 4;
} else if (dtype === 'complex64') {
return 8;
} else if (dtype === 'bool') {
return 1;
} else {
throw new Error("Unknown dtype " + dtype);
}
}
/**
* Returns the approximate number of bytes allocated in the string array - 2
* bytes per character. Computing the exact bytes for a native string in JS is
* not possible since it depends on the encoding of the html page that serves
* the website.
*/
function bytesFromStringArray(arr) {
if (arr == null) {
return 0;
}
var bytes = 0;
arr.forEach(function (x) {
return bytes += x.length;
});
return bytes;
}
/** Returns true if the value is a string. */
function isString(value) {
return typeof value === 'string' || value instanceof String;
}
function isBoolean(value) {
return typeof value === 'boolean';
}
function isNumber(value) {
return typeof value === 'number';
}
function inferDtype(values) {
if (Array.isArray(values)) {
return inferDtype(values[0]);
}
if (values instanceof Float32Array) {
return 'float32';
} else if (values instanceof Int32Array || values instanceof Uint8Array) {
return 'int32';
} else if (isNumber(values)) {
return 'float32';
} else if (isString(values)) {
return 'string';
} else if (isBoolean(values)) {
return 'bool';
}
return 'float32';
}
function isFunction(f) {
return !!(f && f.constructor && f.call && f.apply);
}
function nearestDivisor(size, start) {
for (var i = start; i < size; ++i) {
if (size % i === 0) {
return i;
}
}
return size;
}
function computeStrides(shape) {
var rank = shape.length;
if (rank < 2) {
return [];
} // Last dimension has implicit stride of 1, thus having D-1 (instead of D)
// strides.
var strides = new Array(rank - 1);
strides[rank - 2] = shape[rank - 1];
for (var i = rank - 3; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
}
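/**
 * Example: strides are row-major and omit the implicit stride of 1 for the
 * last dimension, so a [2, 3, 4] shape yields two strides.
 *
 * ```js
 * console.log(computeStrides([2, 3, 4])); // [12, 4]
 * ```
 */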
function createNestedArray(offset, shape, a, isComplex) {
if (isComplex === void 0) {
isComplex = false;
}
var ret = new Array();
if (shape.length === 1) {
var d = shape[0] * (isComplex ? 2 : 1);
for (var i = 0; i < d; i++) {
ret[i] = a[offset + i];
}
} else {
var _d = shape[0];
var rest = shape.slice(1);
var len = rest.reduce(function (acc, c) {
return acc * c;
}) * (isComplex ? 2 : 1);
for (var _i = 0; _i < _d; _i++) {
ret[_i] = createNestedArray(offset + _i * len, rest, a, isComplex);
}
}
return ret;
} // Provide a nested array of TypedArray in given shape.
function toNestedArray(shape, a, isComplex) {
if (isComplex === void 0) {
isComplex = false;
}
if (shape.length === 0) {
// Scalar type should return a single number.
return a[0];
}
var size = shape.reduce(function (acc, c) {
return acc * c;
}) * (isComplex ? 2 : 1);
if (size === 0) {
// A tensor with shape zero should be turned into empty list.
return [];
}
if (size !== a.length) {
throw new Error("[" + shape + "] does not match the input size " + a.length + (isComplex ? ' for a complex tensor' : '') + ".");
}
return createNestedArray(0, shape, a, isComplex);
}
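/**
 * Example: a flat TypedArray is regrouped according to the requested shape.
 *
 * ```js
 * const nested = toNestedArray([2, 2], new Float32Array([1, 2, 3, 4]));
 * console.log(nested); // [[1, 2], [3, 4]]
 * ```
 */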
function makeOnesTypedArray(size, dtype) {
var array = makeZerosTypedArray(size, dtype);
for (var i = 0; i < array.length; i++) {
array[i] = 1;
}
return array;
}
function makeZerosTypedArray(size, dtype) {
if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
return new Float32Array(size);
} else if (dtype === 'int32') {
return new Int32Array(size);
} else if (dtype === 'bool') {
return new Uint8Array(size);
} else {
throw new Error("Unknown data type " + dtype);
}
}
/**
* Make nested `TypedArray` filled with zeros.
* @param shape The shape information for the nested array.
* @param dtype dtype of the array element.
*/
function makeZerosNestedTypedArray(shape, dtype) {
var size = shape.reduce(function (prev, curr) {
return prev * curr;
}, 1);
if (dtype == null || dtype === 'float32') {
return toNestedArray(shape, new Float32Array(size));
} else if (dtype === 'int32') {
return toNestedArray(shape, new Int32Array(size));
} else if (dtype === 'bool') {
return toNestedArray(shape, new Uint8Array(size));
} else {
throw new Error("Unknown data type " + dtype);
}
}
function assertNonNegativeIntegerDimensions(shape) {
shape.forEach(function (dimSize) {
assert(Number.isInteger(dimSize) && dimSize >= 0, function () {
return "Tensor must have a shape comprised of positive integers but got " + ("shape [" + shape + "].");
});
});
}
/**
 * Computes the flat index for a given location (multidimensional index) in a
* Tensor/multidimensional array.
*
* @param locs Location in the tensor.
* @param rank Rank of the tensor.
* @param strides Tensor strides.
*/
function locToIndex(locs, rank, strides) {
if (rank === 0) {
return 0;
} else if (rank === 1) {
return locs[0];
}
var index = locs[locs.length - 1];
for (var i = 0; i < locs.length - 1; ++i) {
index += strides[i] * locs[i];
}
return index;
}
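/**
 * Example: with strides [12, 4] (shape [2, 3, 4]), the location [1, 2, 3]
 * maps to 1 * 12 + 2 * 4 + 3 = 23.
 *
 * ```js
 * console.log(locToIndex([1, 2, 3], 3, [12, 4])); // 23
 * ```
 */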
/**
 * Computes the location (multidimensional index) in a tensor/multidimensional
* array for a given flat index.
*
* @param index Index in flat array.
* @param rank Rank of tensor.
* @param strides Strides of tensor.
*/
function indexToLoc(index, rank, strides) {
if (rank === 0) {
return [];
} else if (rank === 1) {
return [index];
}
var locs = new Array(rank);
for (var i = 0; i < locs.length - 1; ++i) {
locs[i] = Math.floor(index / strides[i]);
index -= locs[i] * strides[i];
}
locs[locs.length - 1] = index;
return locs;
}
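/**
 * Example: the inverse of `locToIndex` above, for the same shape and strides.
 *
 * ```js
 * console.log(indexToLoc(23, 3, [12, 4])); // [1, 2, 3]
 * ```
 */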
/**
 * This method checks whether an object is a Promise instance.
* @param object
*/
// tslint:disable-next-line: no-any
function isPromise(object) {
  // We chose not to use 'obj instanceof Promise' for two reasons:
// 1. It only reliably works for es6 Promise, not other Promise
// implementations.
// 2. It doesn't work with framework that uses zone.js. zone.js monkey patch
// the async calls, so it is possible the obj (patched) is comparing to a
// pre-patched Promise.
return object && object.then && typeof object.then === 'function';
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function warn() {
if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
var _console;
(_console = console).warn.apply(_console, arguments);
}
}
function log$9() {
if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
var _console2;
(_console2 = console).log.apply(_console2, arguments);
}
}
var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
* The environment contains evaluated flags as well as the registered platform.
* This is always used as a global singleton and can be retrieved with
* `tf.env()`.
*
* @doc {heading: 'Environment'}
*/
var Environment = /*#__PURE__*/function () {
// tslint:disable-next-line: no-any
function Environment(global) {
this.global = global;
this.flags = {};
this.flagRegistry = {};
this.urlFlags = {}; // Jasmine spies on this in 'environment_test.ts'
this.getQueryParams = getQueryParams;
this.populateURLFlags();
}
var _proto = Environment.prototype;
_proto.setPlatform = function setPlatform(platformName, platform) {
if (this.platform != null) {
warn("Platform " + this.platformName + " has already been set. " + ("Overwriting the platform with " + platform + "."));
}
this.platformName = platformName;
this.platform = platform;
};
_proto.registerFlag = function registerFlag(flagName, evaluationFn, setHook) {
this.flagRegistry[flagName] = {
evaluationFn: evaluationFn,
setHook: setHook
}; // Override the flag value from the URL. This has to happen here because the
// environment is initialized before flags get registered.
if (this.urlFlags[flagName] != null) {
var flagValue = this.urlFlags[flagName];
warn("Setting feature override from URL " + flagName + ": " + flagValue + ".");
this.set(flagName, flagValue);
}
};
_proto.getAsync = /*#__PURE__*/function () {
var _getAsync = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(flagName) {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!(flagName in this.flags)) {
_context.next = 2;
break;
}
return _context.abrupt("return", this.flags[flagName]);
case 2:
_context.next = 4;
return this.evaluateFlag(flagName);
case 4:
this.flags[flagName] = _context.sent;
return _context.abrupt("return", this.flags[flagName]);
case 6:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function getAsync(_x) {
return _getAsync.apply(this, arguments);
}
return getAsync;
}();
_proto.get = function get(flagName) {
if (flagName in this.flags) {
return this.flags[flagName];
}
var flagValue = this.evaluateFlag(flagName);
if (isPromise(flagValue)) {
throw new Error("Flag " + flagName + " cannot be synchronously evaluated. " + "Please use getAsync() instead.");
}
this.flags[flagName] = flagValue;
return this.flags[flagName];
};
_proto.getNumber = function getNumber(flagName) {
return this.get(flagName);
};
_proto.getBool = function getBool(flagName) {
return this.get(flagName);
};
_proto.getFlags = function getFlags() {
return this.flags;
} // For backwards compatibility.
;
_proto.set = function set(flagName, value) {
if (this.flagRegistry[flagName] == null) {
throw new Error("Cannot set flag " + flagName + " as it has not been registered.");
}
this.flags[flagName] = value;
if (this.flagRegistry[flagName].setHook != null) {
this.flagRegistry[flagName].setHook(value);
}
};
_proto.evaluateFlag = function evaluateFlag(flagName) {
if (this.flagRegistry[flagName] == null) {
throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found.");
}
return this.flagRegistry[flagName].evaluationFn();
};
_proto.setFlags = function setFlags(flags) {
this.flags = Object.assign({}, flags);
};
_proto.reset = function reset() {
this.flags = {};
this.urlFlags = {};
this.populateURLFlags();
};
_proto.populateURLFlags = function populateURLFlags() {
var _this = this;
if (typeof this.global === 'undefined' || typeof this.global.location === 'undefined' || typeof this.global.location.search === 'undefined') {
return;
}
var urlParams = this.getQueryParams(this.global.location.search);
if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
keyValues.forEach(function (keyValue) {
var _keyValue$split = keyValue.split(':'),
key = _keyValue$split[0],
value = _keyValue$split[1];
_this.urlFlags[key] = parseValue(key, value);
});
}
};
_createClass(Environment, [{
key: "features",
get: function get() {
return this.flags;
}
}]);
return Environment;
}();
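/**
 * A minimal usage sketch of the environment, assuming the singleton is
 * retrieved via `tf.env()`; the flag name below is hypothetical.
 *
 * ```js
 * tf.env().registerFlag('MY_CUSTOM_FLAG', () => false);
 * console.log(tf.env().getBool('MY_CUSTOM_FLAG')); // false
 * tf.env().set('MY_CUSTOM_FLAG', true);
 * console.log(tf.env().getBool('MY_CUSTOM_FLAG')); // true
 * ```
 */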
function getQueryParams(queryString) {
var params = {};
queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (s) {
for (var _len = arguments.length, t = new Array(_len > 1 ? _len - 1 : 0), _key = 1; _key < _len; _key++) {
t[_key - 1] = arguments[_key];
}
decodeParam(params, t[0], t[1]);
return t.join('=');
});
return params;
}
function decodeParam(params, name, value) {
params[decodeURIComponent(name)] = decodeURIComponent(value || '');
}
function parseValue(flagName, value) {
value = value.toLowerCase();
if (value === 'true' || value === 'false') {
return value === 'true';
} else if ("" + +value === value) {
return +value;
}
throw new Error("Could not parse value flag value " + value + " for flag " + flagName + ".");
}
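/**
 * Example of the URL override mechanism handled by `populateURLFlags` above:
 * flags arrive in the `tfjsflags` query parameter as comma-separated
 * `name:value` pairs, extracted by `getQueryParams` and converted by
 * `parseValue`. The URL below is illustrative.
 *
 * ```js
 * // e.g. https://example.com/?tfjsflags=DEBUG:true
 * const params = getQueryParams('?tfjsflags=DEBUG:true');
 * console.log(params[TENSORFLOWJS_FLAGS_PREFIX]); // 'DEBUG:true'
 * console.log(parseValue('DEBUG', 'true')); // true
 * ```
 */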
/**
* Returns the current environment (a global singleton).
*
* The environment object contains the evaluated feature values as well as the
* active platform.
*
* @doc {heading: 'Environment'}
*/
function env() {
return exports.ENV;
}
exports.ENV = null;
function setEnvironmentGlobal(environment) {
exports.ENV = environment;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Note that the identifier globalNameSpace is scoped to this module, but will
// always resolve to the same global object regardless of how the module is
// resolved.
// tslint:disable-next-line:no-any
var globalNameSpace; // tslint:disable-next-line:no-any
function getGlobalNamespace() {
if (globalNameSpace == null) {
// tslint:disable-next-line:no-any
var ns;
if (typeof window !== 'undefined') {
ns = window;
} else if (typeof global !== 'undefined') {
ns = global;
} else if (typeof process !== 'undefined') {
ns = process;
} else if (typeof self !== 'undefined') {
ns = self;
} else {
throw new Error('Could not find a global object');
}
globalNameSpace = ns;
}
return globalNameSpace;
} // tslint:disable-next-line:no-any
function getGlobalMap() {
var ns = getGlobalNamespace();
if (ns._tfGlobals == null) {
ns._tfGlobals = new Map();
}
return ns._tfGlobals;
}
/**
* Returns a globally accessible 'singleton' object.
*
* @param key the name of the object
 * @param init a function to initialize this object
* the first time it is fetched.
*/
function getGlobal(key, init) {
var globalMap = getGlobalMap();
if (globalMap.has(key)) {
return globalMap.get(key);
} else {
var singleton = init();
globalMap.set(key, singleton);
return globalMap.get(key);
}
}
var Abs = 'Abs';
var Acos = 'Acos';
var Acosh = 'Acosh';
var Add = 'Add';
var AddN = 'AddN';
var All = 'All';
var Any = 'Any';
var ArgMax = 'ArgMax';
var ArgMin = 'ArgMin';
var Asin = 'Asin';
var Asinh = 'Asinh';
var Atan = 'Atan';
var Atanh = 'Atanh';
var Atan2 = 'Atan2';
var AvgPool = 'AvgPool';
var AvgPoolGrad = 'AvgPoolGrad';
var AvgPool3D = 'AvgPool3D';
var AvgPool3DGrad = 'AvgPool3DGrad';
var BatchMatMul = 'BatchMatMul';
var BatchToSpaceND = 'BatchToSpaceND';
var Bincount = 'Bincount';
var BroadcastTo = 'BroadcastTo';
var BroadcastArgs = 'BroadcastArgs';
var Cast = 'Cast';
var Ceil = 'Ceil';
var ClipByValue = 'ClipByValue';
var Complex = 'Complex';
var ComplexAbs = 'ComplexAbs';
var Concat = 'Concat';
var Conv2D = 'Conv2D';
var Conv2DBackpropFilter = 'Conv2DBackpropFilter';
var Conv2DBackpropInput = 'Conv2DBackpropInput';
var Conv3D = 'Conv3D';
var Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
var Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
var Cos = 'Cos';
var Cosh = 'Cosh';
var Cumsum = 'Cumsum';
var CropAndResize = 'CropAndResize';
var DenseBincount = 'DenseBincount';
var DepthToSpace = 'DepthToSpace';
var DepthwiseConv2dNative = 'DepthwiseConv2dNative';
var DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
var DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
var Diag = 'Diag';
var Dilation2D = 'Dilation2D';
var Dilation2DBackpropInput = 'Dilation2DBackpropInput';
var Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
var RealDiv = 'RealDiv';
var Einsum = 'Einsum';
var Elu = 'Elu';
var EluGrad = 'EluGrad';
var Erf = 'Erf';
var Equal = 'Equal';
var Exp = 'Exp';
var ExpandDims = 'ExpandDims';
var Expm1 = 'Expm1';
var FFT = 'FFT';
var Fill = 'Fill';
var FlipLeftRight = 'FlipLeftRight';
var Floor = 'Floor';
var FloorDiv = 'FloorDiv';
var FusedBatchNorm = 'FusedBatchNorm';
var GatherV2 = 'GatherV2';
var GatherNd = 'GatherNd';
var Greater = 'Greater';
var GreaterEqual = 'GreaterEqual';
var Identity = 'Identity';
var IFFT = 'IFFT';
var Imag = 'Imag';
var IsFinite = 'IsFinite';
var IsInf = 'IsInf';
var IsNan = 'IsNan';
var LeakyRelu = 'LeakyRelu';
var Less = 'Less';
var LessEqual = 'LessEqual';
var LinSpace = 'LinSpace';
var Log = 'Log';
var Log1p = 'Log1p';
var LogicalAnd = 'LogicalAnd';
var LogicalNot = 'LogicalNot';
var LogicalOr = 'LogicalOr';
var LogSoftmax = 'LogSoftmax';
var LRN = 'LRN';
var LRNGrad = 'LRNGrad';
var Max = 'Max';
var Maximum = 'Maximum';
var MaxPool = 'MaxPool';
var MaxPoolGrad = 'MaxPoolGrad';
var MaxPool3D = 'MaxPool3D';
var MaxPool3DGrad = 'MaxPool3DGrad';
var MaxPoolWithArgmax = 'MaxPoolWithArgmax';
var Mean = 'Mean';
var Min = 'Min';
var Minimum = 'Minimum';
var MirrorPad = 'MirrorPad';
var Mod = 'Mod';
var Multinomial = 'Multinomial';
var Multiply = 'Multiply';
var Neg = 'Neg';
var NotEqual = 'NotEqual';
var NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
var NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
var NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
var OnesLike = 'OnesLike';
var OneHot = 'OneHot';
var Pack = 'Pack';
var PadV2 = 'PadV2';
var Pool = 'Pool';
var Pow = 'Pow';
var Prelu = 'Prelu';
var Prod = 'Prod';
var Range = 'Range';
var Real = 'Real';
var Reciprocal = 'Reciprocal';
var Relu = 'Relu';
var Reshape = 'Reshape';
var ResizeNearestNeighbor = 'ResizeNearestNeighbor';
var ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
var ResizeBilinear = 'ResizeBilinear';
var ResizeBilinearGrad = 'ResizeBilinearGrad';
var Relu6 = 'Relu6';
var Reverse = 'Reverse';
var Round = 'Round';
var Rsqrt = 'Rsqrt';
var ScatterNd = 'ScatterNd';
var Select = 'Select';
var Selu = 'Selu';
var Slice = 'Slice';
var Sin = 'Sin';
var Sinh = 'Sinh';
var Sign = 'Sign';
var Sigmoid = 'Sigmoid';
var Softplus = 'Softplus';
var Sqrt = 'Sqrt';
var Sum = 'Sum';
var SpaceToBatchND = 'SpaceToBatchND';
var SplitV = 'SplitV';
var Softmax = 'Softmax';
var SparseFillEmptyRows = 'SparseFillEmptyRows';
var SparseReshape = 'SparseReshape';
var SparseSegmentMean = 'SparseSegmentMean';
var SparseSegmentSum = 'SparseSegmentSum';
var SparseToDense = 'SparseToDense';
var SquaredDifference = 'SquaredDifference';
var Square = 'Square';
var StridedSlice = 'StridedSlice';
var StringNGrams = 'StringNGrams';
var StringSplit = 'StringSplit';
var StringToHashBucketFast = 'StringToHashBucketFast';
var Sub = 'Sub';
var Tan = 'Tan';
var Tanh = 'Tanh';
var Tile = 'Tile';
var TopK = 'TopK';
var Transform = 'Transform';
var Transpose = 'Transpose';
var Unique = 'Unique';
var Unpack = 'Unpack';
var UnsortedSegmentSum = 'UnsortedSegmentSum';
var ZerosLike = 'ZerosLike';
/**
* TensorFlow.js-only kernels
*/
var Step = 'Step';
var FromPixels = 'FromPixels';
var RotateWithOffset = 'RotateWithOffset';
var _FusedMatMul = '_FusedMatMul';
var FusedConv2D = 'FusedConv2D';
var FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var kernelRegistry = getGlobal('kernelRegistry', function () {
return new Map();
});
var gradRegistry = getGlobal('gradRegistry', function () {
return new Map();
});
/**
* Returns the kernel function (code) associated with the provided names.
*
* @param kernelName The official name of the kernel.
* @param backendName The official name of the backend.
*/
function getKernel(kernelName, backendName) {
var key = makeKey(kernelName, backendName);
return kernelRegistry.get(key);
}
/**
* Returns the registered gradient info associated with the provided kernel.
* @param kernelName The official TF kernel name.
*/
function getGradient(kernelName) {
return gradRegistry.get(kernelName);
}
function getKernelsForBackend(backendName) {
var it = kernelRegistry.entries();
var result = [];
while (true) {
var _it$next = it.next(),
done = _it$next.done,
value = _it$next.value;
if (done) {
break;
}
var key = value[0],
config = value[1];
var _key$split = key.split('_'),
backend = _key$split[0];
if (backend === backendName) {
result.push(config);
}
}
return result;
}
/**
* Registers the function (forward pass) for the kernel in a global registry.
*
* @param config A config object with the following properties:
* - `kernelName` The official name of the kernel.
* - `backendName` The official name of the backend.
* - `kernelFunc` The function to run during the forward pass of the kernel.
* - `setupFunc` Optional. Gets called once, after the backend initializes.
* - `disposeFunc` Optional. Gets called once, right before the backend is
* disposed.
*/
function registerKernel(config) {
var kernelName = config.kernelName,
backendName = config.backendName;
var key = makeKey(kernelName, backendName);
if (kernelRegistry.has(key)) {
warn("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is already registered"));
}
kernelRegistry.set(key, config);
}
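/**
 * A minimal registration sketch. The kernel name ('MyKernel'), backend name
 * ('my-backend') and the `kernelFunc` signature shown here are hypothetical;
 * only the config shape follows the documentation above.
 *
 * ```js
 * registerKernel({
 *   kernelName: 'MyKernel',      // hypothetical kernel name
 *   backendName: 'my-backend',   // hypothetical backend name
 *   kernelFunc: function (args) {
 *     // the forward pass for this kernel would run here
 *     return null;
 *   }
 * });
 * getKernel('MyKernel', 'my-backend'); // => the config registered above
 * ```
 */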
/**
* Registers a gradient function for a given kernel in the global registry,
* to be used during the back-propagation of that kernel.
*
* @param config An object with the following properties:
* - `kernelName` The name of the kernel that the gradient function is for.
* - `gradFunc` The function to run during back-propagation.
*/
function registerGradient(config) {
var kernelName = config.kernelName;
if (gradRegistry.has(kernelName)) {
// TODO (yassogba) after 3.0 assess whether we need to keep this gated
// to debug mode.
if (env().getBool('DEBUG')) {
warn("Overriding the gradient for '" + kernelName + "'");
}
}
gradRegistry.set(kernelName, config);
}
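/**
 * A matching gradient-registration sketch for the hypothetical 'MyKernel'
 * above. The arguments the engine passes to `gradFunc` are not defined in this
 * file, so the signature below is an assumption; only the config shape
 * (`kernelName`, `gradFunc`) follows the documentation above.
 *
 * ```js
 * registerGradient({
 *   kernelName: 'MyKernel',                    // hypothetical
 *   gradFunc: function (dy) {
 *     return {x: function () { return dy; }};  // assumed pass-through gradient
 *   }
 * });
 * getGradient('MyKernel'); // => the config registered above
 * ```
 */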
/**
* Removes the kernel function from the registry.
*
* @param kernelName The official name of the kernel.
* @param backendName The official name of the backend.
*
*/
function unregisterKernel(kernelName, backendName) {
var key = makeKey(kernelName, backendName);
if (!kernelRegistry.has(key)) {
throw new Error("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is not registered"));
}
kernelRegistry.delete(key);
}
/** Removes the registered gradient from the global registry. */
function unregisterGradient(kernelName) {
if (!gradRegistry.has(kernelName)) {
throw new Error("The gradient '" + kernelName + "' for backend is not registered");
}
gradRegistry.delete(kernelName);
}
/**
* Finds kernels that have already been registered to a backend and re-registers
* them for a new backend. Useful for registering custom backends.
* @param registeredBackendName Already registered backend.
* @param newBackendName New backend.
*/
function copyRegisteredKernels(registeredBackendName, newBackendName) {
var kernels = getKernelsForBackend(registeredBackendName);
kernels.forEach(function (kernelConfig) {
var newKernelConfig = Object.assign({}, kernelConfig, {
backendName: newBackendName
});
registerKernel(newKernelConfig);
});
}
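/**
 * Custom-backend sketch (hypothetical backend names): copy every kernel that
 * is already registered for one backend so it also serves a new backend name.
 *
 * ```js
 * copyRegisteredKernels('my-backend', 'my-custom-backend');
 * getKernelsForBackend('my-custom-backend').length ===
 *     getKernelsForBackend('my-backend').length; // true
 * ```
 */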
function makeKey(kernelName, backendName) {
return backendName + "_" + kernelName;
}
var long_1 = Long;
/**
* wasm optimizations, to do native i64 multiplication and divide
*/
var wasm = null;
try {
wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11])), {}).exports;
} catch (e) {// no wasm support :(
}
/**
* Constructs a 64 bit two's-complement integer, given its low and high 32 bit values as *signed* integers.
* See the from* functions below for more convenient ways of constructing Longs.
* @exports Long
* @class A Long class for representing a 64 bit two's-complement integer value.
* @param {number} low The low (signed) 32 bits of the long
* @param {number} high The high (signed) 32 bits of the long
* @param {boolean=} unsigned Whether unsigned or not, defaults to signed
* @constructor
*/
function Long(low, high, unsigned) {
/**
* The low 32 bits as a signed value.
* @type {number}
*/
this.low = low | 0;
/**
* The high 32 bits as a signed value.
* @type {number}
*/
this.high = high | 0;
/**
* Whether unsigned or not.
* @type {boolean}
*/
this.unsigned = !!unsigned;
} // The internal representation of a long is the two given signed, 32-bit values.
// We use 32-bit pieces because these are the size of integers on which
// Javascript performs bit-operations. For operations like addition and
// multiplication, we split each number into 16 bit pieces, which can easily be
// multiplied within Javascript's floating-point representation without overflow
// or change in sign.
//
// In the algorithms below, we frequently reduce the negative case to the
// positive case by negating the input(s) and then post-processing the result.
// Note that we must ALWAYS check specially whether those values are MIN_VALUE
// (-2^63) because -MIN_VALUE == MIN_VALUE (since 2^63 cannot be represented as
// a positive number, it overflows back into a negative). Not handling this
// case would often result in infinite recursion.
//
// Common constant values ZERO, ONE, NEG_ONE, etc. are defined below the from*
// methods on which they depend.
/**
* An indicator used to reliably determine if an object is a Long or not.
* @type {boolean}
* @const
* @private
*/
Long.prototype.__isLong__;
Object.defineProperty(Long.prototype, "__isLong__", {
value: true
});
/**
* @function
* @param {*} obj Object
* @returns {boolean}
* @inner
*/
function isLong(obj) {
return (obj && obj["__isLong__"]) === true;
}
/**
* Tests if the specified object is a Long.
* @function
* @param {*} obj Object
* @returns {boolean}
*/
Long.isLong = isLong;
/**
* A cache of the Long representations of small integer values.
* @type {!Object}
* @inner
*/
var INT_CACHE = {};
/**
* A cache of the Long representations of small unsigned integer values.
* @type {!Object}
* @inner
*/
var UINT_CACHE = {};
/**
* @param {number} value
* @param {boolean=} unsigned
* @returns {!Long}
* @inner
*/
function fromInt(value, unsigned) {
var obj, cachedObj, cache;
if (unsigned) {
value >>>= 0;
if (cache = 0 <= value && value < 256) {
cachedObj = UINT_CACHE[value];
if (cachedObj) return cachedObj;
}
obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true);
if (cache) UINT_CACHE[value] = obj;
return obj;
} else {
value |= 0;
if (cache = -128 <= value && value < 128) {
cachedObj = INT_CACHE[value];
if (cachedObj) return cachedObj;
}
obj = fromBits(value, value < 0 ? -1 : 0, false);
if (cache) INT_CACHE[value] = obj;
return obj;
}
}
/**
* Returns a Long representing the given 32 bit integer value.
* @function
* @param {number} value The 32 bit integer in question
* @param {boolean=} unsigned Whether unsigned or not, defaults to signed
* @returns {!Long} The corresponding Long value
*/
Long.fromInt = fromInt;
/**
* @param {number} value
* @param {boolean=} unsigned
* @returns {!Long}
* @inner
*/
function fromNumber(value, unsigned) {
if (isNaN(value)) return unsigned ? UZERO : ZERO;
if (unsigned) {
if (value < 0) return UZERO;
if (value >= TWO_PWR_64_DBL) return MAX_UNSIGNED_VALUE;
} else {
if (value <= -TWO_PWR_63_DBL) return MIN_VALUE;
if (value + 1 >= TWO_PWR_63_DBL) return MAX_VALUE;
}
if (value < 0) return fromNumber(-value, unsigned).neg();
return fromBits(value % TWO_PWR_32_DBL | 0, value / TWO_PWR_32_DBL | 0, unsigned);
}
/**
* Returns a Long representing the given value, provided that it is a finite number. Otherwise, zero is returned.
* @function
* @param {number} value The number in question
* @param {boolean=} unsigned Whether unsigned or not, defaults to signed
* @returns {!Long} The corresponding Long value
*/
Long.fromNumber = fromNumber;
/**
* @param {number} lowBits
* @param {number} highBits
* @param {boolean=} unsigned
* @returns {!Long}
* @inner
*/
function fromBits(lowBits, highBits, unsigned) {
return new Long(lowBits, highBits, unsigned);
}
/**
* Returns a Long representing the 64 bit integer that comes by concatenating the given low and high bits. Each is
* assumed to use 32 bits.
* @function
* @param {number} lowBits The low 32 bits
* @param {number} highBits The high 32 bits
* @param {boolean=} unsigned Whether unsigned or not, defaults to signed
* @returns {!Long} The corresponding Long value
*/
Long.fromBits = fromBits;
/**
* @function
* @param {number} base
* @param {number} exponent
* @returns {number}
* @inner
*/
var pow_dbl = Math.pow; // Used 4 times (4*8 to 15+4)
/**
* @param {string} str
* @param {(boolean|number)=} unsigned
* @param {number=} radix
* @returns {!Long}
* @inner
*/
function fromString(str, unsigned, radix) {
if (str.length === 0) throw Error('empty string');
if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity") return ZERO;
if (typeof unsigned === 'number') {
// For goog.math.long compatibility
radix = unsigned, unsigned = false;
} else {
unsigned = !!unsigned;
}
radix = radix || 10;
if (radix < 2 || 36 < radix) throw RangeError('radix');
var p;
  if ((p = str.indexOf('-')) > 0) {
    throw Error('interior hyphen');
  } else if (p === 0) {
return fromString(str.substring(1), unsigned, radix).neg();
} // Do several (8) digits each time through the loop, so as to
// minimize the calls to the very expensive emulated div.
var radixToPower = fromNumber(pow_dbl(radix, 8));
var result = ZERO;
for (var i = 0; i < str.length; i += 8) {
var size = Math.min(8, str.length - i),
value = parseInt(str.substring(i, i + size), radix);
if (size < 8) {
var power = fromNumber(pow_dbl(radix, size));
result = result.mul(power).add(fromNumber(value));
} else {
result = result.mul(radixToPower);
result = result.add(fromNumber(value));
}
}
result.unsigned = unsigned;
return result;
}
/**
* Returns a Long representation of the given string, written using the specified radix.
* @function
* @param {string} str The textual representation of the Long
* @param {(boolean|number)=} unsigned Whether unsigned or not, defaults to signed
* @param {number=} radix The radix in which the text is written (2-36), defaults to 10
* @returns {!Long} The corresponding Long value
*/
Long.fromString = fromString;
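/**
 * A small construction sketch (arbitrary values); `toString` is defined
 * further below on the prototype.
 *
 * ```js
 * var a = Long.fromInt(255);                // low = 0xFF, high = 0, signed
 * var b = Long.fromString('ff', true, 16);  // same bits, unsigned
 * var c = Long.fromNumber(4294967296);      // 2^32, beyond 32-bit range
 * a.toString(16); // 'ff'
 * c.toString();   // '4294967296'
 * ```
 */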
/**
* @function
* @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val
* @param {boolean=} unsigned
* @returns {!Long}
* @inner
*/
function fromValue(val, unsigned) {
if (typeof val === 'number') return fromNumber(val, unsigned);
if (typeof val === 'string') return fromString(val, unsigned); // Throws for non-objects, converts non-instanceof Long:
return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned);
}
/**
* Converts the specified value to a Long using the appropriate from* function for its type.
* @function
* @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val Value
* @param {boolean=} unsigned Whether unsigned or not, defaults to signed
* @returns {!Long}
*/
Long.fromValue = fromValue; // NOTE: the compiler should inline these constant values below and then remove these variables, so there should be
// no runtime penalty for these.
/**
* @type {number}
* @const
* @inner
*/
var TWO_PWR_16_DBL = 1 << 16;
/**
* @type {number}
* @const
* @inner
*/
var TWO_PWR_24_DBL = 1 << 24;
/**
* @type {number}
* @const
* @inner
*/
var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL;
/**
* @type {number}
* @const
* @inner
*/
var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL;
/**
* @type {number}
* @const
* @inner
*/
var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2;
/**
* @type {!Long}
* @const
* @inner
*/
var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL);
/**
* @type {!Long}
* @inner
*/
var ZERO = fromInt(0);
/**
* Signed zero.
* @type {!Long}
*/
Long.ZERO = ZERO;
/**
* @type {!Long}
* @inner
*/
var UZERO = fromInt(0, true);
/**
* Unsigned zero.
* @type {!Long}
*/
Long.UZERO = UZERO;
/**
* @type {!Long}
* @inner
*/
var ONE = fromInt(1);
/**
* Signed one.
* @type {!Long}
*/
Long.ONE = ONE;
/**
* @type {!Long}
* @inner
*/
var UONE = fromInt(1, true);
/**
* Unsigned one.
* @type {!Long}
*/
Long.UONE = UONE;
/**
* @type {!Long}
* @inner
*/
var NEG_ONE = fromInt(-1);
/**
* Signed negative one.
* @type {!Long}
*/
Long.NEG_ONE = NEG_ONE;
/**
* @type {!Long}
* @inner
*/
var MAX_VALUE = fromBits(0xFFFFFFFF | 0, 0x7FFFFFFF | 0, false);
/**
* Maximum signed value.
* @type {!Long}
*/
Long.MAX_VALUE = MAX_VALUE;
/**
* @type {!Long}
* @inner
*/
var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF | 0, 0xFFFFFFFF | 0, true);
/**
* Maximum unsigned value.
* @type {!Long}
*/
Long.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE;
/**
* @type {!Long}
* @inner
*/
var MIN_VALUE = fromBits(0, 0x80000000 | 0, false);
/**
* Minimum signed value.
* @type {!Long}
*/
Long.MIN_VALUE = MIN_VALUE;
/**
* @alias Long.prototype
* @inner
*/
var LongPrototype = Long.prototype;
/**
* Converts the Long to a 32 bit integer, assuming it is a 32 bit integer.
* @returns {number}
*/
LongPrototype.toInt = function toInt() {
return this.unsigned ? this.low >>> 0 : this.low;
};
/**
 * Converts the Long to the nearest floating-point representation of this value (double, 53 bit mantissa).
* @returns {number}
*/
LongPrototype.toNumber = function toNumber() {
if (this.unsigned) return (this.high >>> 0) * TWO_PWR_32_DBL + (this.low >>> 0);
return this.high * TWO_PWR_32_DBL + (this.low >>> 0);
};
/**
* Converts the Long to a string written in the specified radix.
* @param {number=} radix Radix (2-36), defaults to 10
* @returns {string}
* @override
* @throws {RangeError} If `radix` is out of range
*/
LongPrototype.toString = function toString(radix) {
radix = radix || 10;
if (radix < 2 || 36 < radix) throw RangeError('radix');
if (this.isZero()) return '0';
if (this.isNegative()) {
// Unsigned Longs are never negative
if (this.eq(MIN_VALUE)) {
// We need to change the Long value before it can be negated, so we remove
// the bottom-most digit in this base and then recurse to do the rest.
var radixLong = fromNumber(radix),
div = this.div(radixLong),
rem1 = div.mul(radixLong).sub(this);
return div.toString(radix) + rem1.toInt().toString(radix);
} else return '-' + this.neg().toString(radix);
} // Do several (6) digits each time through the loop, so as to
// minimize the calls to the very expensive emulated div.
var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned),
rem = this;
var result = '';
while (true) {
var remDiv = rem.div(radixToPower),
intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0,
digits = intval.toString(radix);
rem = remDiv;
      if (rem.isZero()) {
        return digits + result;
      } else {
while (digits.length < 6) {
digits = '0' + digits;
}
result = '' + digits + result;
}
}
};
/**
* Gets the high 32 bits as a signed integer.
* @returns {number} Signed high bits
*/
LongPrototype.getHighBits = function getHighBits() {
return this.high;
};
/**
* Gets the high 32 bits as an unsigned integer.
* @returns {number} Unsigned high bits
*/
LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() {
return this.high >>> 0;
};
/**
* Gets the low 32 bits as a signed integer.
* @returns {number} Signed low bits
*/
LongPrototype.getLowBits = function getLowBits() {
return this.low;
};
/**
* Gets the low 32 bits as an unsigned integer.
* @returns {number} Unsigned low bits
*/
LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() {
return this.low >>> 0;
};
/**
* Gets the number of bits needed to represent the absolute value of this Long.
* @returns {number}
*/
LongPrototype.getNumBitsAbs = function getNumBitsAbs() {
if (this.isNegative()) // Unsigned Longs are never negative
return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs();
var val = this.high != 0 ? this.high : this.low;
for (var bit = 31; bit > 0; bit--) {
if ((val & 1 << bit) != 0) break;
}
return this.high != 0 ? bit + 33 : bit + 1;
};
/**
* Tests if this Long's value equals zero.
* @returns {boolean}
*/
LongPrototype.isZero = function isZero() {
return this.high === 0 && this.low === 0;
};
/**
* Tests if this Long's value equals zero. This is an alias of {@link Long#isZero}.
* @returns {boolean}
*/
LongPrototype.eqz = LongPrototype.isZero;
/**
* Tests if this Long's value is negative.
* @returns {boolean}
*/
LongPrototype.isNegative = function isNegative() {
return !this.unsigned && this.high < 0;
};
/**
* Tests if this Long's value is positive.
* @returns {boolean}
*/
LongPrototype.isPositive = function isPositive() {
return this.unsigned || this.high >= 0;
};
/**
* Tests if this Long's value is odd.
* @returns {boolean}
*/
LongPrototype.isOdd = function isOdd() {
return (this.low & 1) === 1;
};
/**
* Tests if this Long's value is even.
* @returns {boolean}
*/
LongPrototype.isEven = function isEven() {
return (this.low & 1) === 0;
};
/**
* Tests if this Long's value equals the specified's.
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.equals = function equals(other) {
if (!isLong(other)) other = fromValue(other);
if (this.unsigned !== other.unsigned && this.high >>> 31 === 1 && other.high >>> 31 === 1) return false;
return this.high === other.high && this.low === other.low;
};
/**
* Tests if this Long's value equals the specified's. This is an alias of {@link Long#equals}.
* @function
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.eq = LongPrototype.equals;
/**
* Tests if this Long's value differs from the specified's.
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.notEquals = function notEquals(other) {
return !this.eq(
/* validates */
other);
};
/**
* Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
* @function
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.neq = LongPrototype.notEquals;
/**
* Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}.
* @function
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.ne = LongPrototype.notEquals;
/**
* Tests if this Long's value is less than the specified's.
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.lessThan = function lessThan(other) {
return this.comp(
/* validates */
other) < 0;
};
/**
* Tests if this Long's value is less than the specified's. This is an alias of {@link Long#lessThan}.
* @function
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.lt = LongPrototype.lessThan;
/**
 * Tests if this Long's value is less than or equal to the specified's.
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) {
return this.comp(
/* validates */
other) <= 0;
};
/**
 * Tests if this Long's value is less than or equal to the specified's. This is an alias of {@link Long#lessThanOrEqual}.
* @function
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.lte = LongPrototype.lessThanOrEqual;
/**
 * Tests if this Long's value is less than or equal to the specified's. This is an alias of {@link Long#lessThanOrEqual}.
* @function
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.le = LongPrototype.lessThanOrEqual;
/**
* Tests if this Long's value is greater than the specified's.
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.greaterThan = function greaterThan(other) {
return this.comp(
/* validates */
other) > 0;
};
/**
* Tests if this Long's value is greater than the specified's. This is an alias of {@link Long#greaterThan}.
* @function
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.gt = LongPrototype.greaterThan;
/**
 * Tests if this Long's value is greater than or equal to the specified's.
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) {
return this.comp(
/* validates */
other) >= 0;
};
/**
 * Tests if this Long's value is greater than or equal to the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
* @function
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.gte = LongPrototype.greaterThanOrEqual;
/**
 * Tests if this Long's value is greater than or equal to the specified's. This is an alias of {@link Long#greaterThanOrEqual}.
* @function
* @param {!Long|number|string} other Other value
* @returns {boolean}
*/
LongPrototype.ge = LongPrototype.greaterThanOrEqual;
/**
* Compares this Long's value with the specified's.
* @param {!Long|number|string} other Other value
 * @returns {number} 0 if they are the same, 1 if this one is greater and -1
 *  if the given one is greater
*/
LongPrototype.compare = function compare(other) {
if (!isLong(other)) other = fromValue(other);
if (this.eq(other)) return 0;
var thisNeg = this.isNegative(),
otherNeg = other.isNegative();
if (thisNeg && !otherNeg) return -1;
if (!thisNeg && otherNeg) return 1; // At this point the sign bits are the same
if (!this.unsigned) return this.sub(other).isNegative() ? -1 : 1; // Both are positive if at least one is unsigned
return other.high >>> 0 > this.high >>> 0 || other.high === this.high && other.low >>> 0 > this.low >>> 0 ? -1 : 1;
};
/**
* Compares this Long's value with the specified's. This is an alias of {@link Long#compare}.
* @function
* @param {!Long|number|string} other Other value
 * @returns {number} 0 if they are the same, 1 if this one is greater and -1
 *  if the given one is greater
*/
LongPrototype.comp = LongPrototype.compare;
/**
* Negates this Long's value.
* @returns {!Long} Negated Long
*/
LongPrototype.negate = function negate() {
if (!this.unsigned && this.eq(MIN_VALUE)) return MIN_VALUE;
return this.not().add(ONE);
};
/**
* Negates this Long's value. This is an alias of {@link Long#negate}.
* @function
* @returns {!Long} Negated Long
*/
LongPrototype.neg = LongPrototype.negate;
/**
* Returns the sum of this and the specified Long.
* @param {!Long|number|string} addend Addend
* @returns {!Long} Sum
*/
LongPrototype.add = function add(addend) {
if (!isLong(addend)) addend = fromValue(addend); // Divide each number into 4 chunks of 16 bits, and then sum the chunks.
var a48 = this.high >>> 16;
var a32 = this.high & 0xFFFF;
var a16 = this.low >>> 16;
var a00 = this.low & 0xFFFF;
var b48 = addend.high >>> 16;
var b32 = addend.high & 0xFFFF;
var b16 = addend.low >>> 16;
var b00 = addend.low & 0xFFFF;
var c48 = 0,
c32 = 0,
c16 = 0,
c00 = 0;
c00 += a00 + b00;
c16 += c00 >>> 16;
c00 &= 0xFFFF;
c16 += a16 + b16;
c32 += c16 >>> 16;
c16 &= 0xFFFF;
c32 += a32 + b32;
c48 += c32 >>> 16;
c32 &= 0xFFFF;
c48 += a48 + b48;
c48 &= 0xFFFF;
return fromBits(c16 << 16 | c00, c48 << 16 | c32, this.unsigned);
};
/**
* Returns the difference of this and the specified Long.
* @param {!Long|number|string} subtrahend Subtrahend
* @returns {!Long} Difference
*/
LongPrototype.subtract = function subtract(subtrahend) {
if (!isLong(subtrahend)) subtrahend = fromValue(subtrahend);
return this.add(subtrahend.neg());
};
/**
* Returns the difference of this and the specified Long. This is an alias of {@link Long#subtract}.
* @function
* @param {!Long|number|string} subtrahend Subtrahend
* @returns {!Long} Difference
*/
LongPrototype.sub = LongPrototype.subtract;
/**
* Returns the product of this and the specified Long.
* @param {!Long|number|string} multiplier Multiplier
* @returns {!Long} Product
*/
LongPrototype.multiply = function multiply(multiplier) {
if (this.isZero()) return ZERO;
if (!isLong(multiplier)) multiplier = fromValue(multiplier); // use wasm support if present
if (wasm) {
var low = wasm.mul(this.low, this.high, multiplier.low, multiplier.high);
return fromBits(low, wasm.get_high(), this.unsigned);
}
if (multiplier.isZero()) return ZERO;
if (this.eq(MIN_VALUE)) return multiplier.isOdd() ? MIN_VALUE : ZERO;
if (multiplier.eq(MIN_VALUE)) return this.isOdd() ? MIN_VALUE : ZERO;
if (this.isNegative()) {
    if (multiplier.isNegative()) return this.neg().mul(multiplier.neg());
    return this.neg().mul(multiplier).neg();
  } else if (multiplier.isNegative()) {
    return this.mul(multiplier.neg()).neg();
  } // If both longs are small, use float multiplication
if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24)) return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned); // Divide each long into 4 chunks of 16 bits, and then add up 4x4 products.
// We can skip products that would overflow.
var a48 = this.high >>> 16;
var a32 = this.high & 0xFFFF;
var a16 = this.low >>> 16;
var a00 = this.low & 0xFFFF;
var b48 = multiplier.high >>> 16;
var b32 = multiplier.high & 0xFFFF;
var b16 = multiplier.low >>> 16;
var b00 = multiplier.low & 0xFFFF;
var c48 = 0,
c32 = 0,
c16 = 0,
c00 = 0;
c00 += a00 * b00;
c16 += c00 >>> 16;
c00 &= 0xFFFF;
c16 += a16 * b00;
c32 += c16 >>> 16;
c16 &= 0xFFFF;
c16 += a00 * b16;
c32 += c16 >>> 16;
c16 &= 0xFFFF;
c32 += a32 * b00;
c48 += c32 >>> 16;
c32 &= 0xFFFF;
c32 += a16 * b16;
c48 += c32 >>> 16;
c32 &= 0xFFFF;
c32 += a00 * b32;
c48 += c32 >>> 16;
c32 &= 0xFFFF;
c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48;
c48 &= 0xFFFF;
return fromBits(c16 << 16 | c00, c48 << 16 | c32, this.unsigned);
};
/**
* Returns the product of this and the specified Long. This is an alias of {@link Long#multiply}.
* @function
* @param {!Long|number|string} multiplier Multiplier
* @returns {!Long} Product
*/
LongPrototype.mul = LongPrototype.multiply;
/**
* Returns this Long divided by the specified. The result is signed if this Long is signed or
* unsigned if this Long is unsigned.
* @param {!Long|number|string} divisor Divisor
* @returns {!Long} Quotient
*/
LongPrototype.divide = function divide(divisor) {
if (!isLong(divisor)) divisor = fromValue(divisor);
if (divisor.isZero()) throw Error('division by zero'); // use wasm support if present
if (wasm) {
// guard against signed division overflow: the largest
// negative number / -1 would be 1 larger than the largest
// positive number, due to two's complement.
if (!this.unsigned && this.high === -0x80000000 && divisor.low === -1 && divisor.high === -1) {
// be consistent with non-wasm code path
return this;
}
var low = (this.unsigned ? wasm.div_u : wasm.div_s)(this.low, this.high, divisor.low, divisor.high);
return fromBits(low, wasm.get_high(), this.unsigned);
}
if (this.isZero()) return this.unsigned ? UZERO : ZERO;
var approx, rem, res;
if (!this.unsigned) {
// This section is only relevant for signed longs and is derived from the
// closure library as a whole.
if (this.eq(MIN_VALUE)) {
      if (divisor.eq(ONE) || divisor.eq(NEG_ONE)) {
        return MIN_VALUE; // recall that -MIN_VALUE == MIN_VALUE
      } else if (divisor.eq(MIN_VALUE)) {
        return ONE;
      } else {
// At this point, we have |other| >= 2, so |this/other| < |MIN_VALUE|.
var halfThis = this.shr(1);
approx = halfThis.div(divisor).shl(1);
if (approx.eq(ZERO)) {
return divisor.isNegative() ? ONE : NEG_ONE;
} else {
rem = this.sub(divisor.mul(approx));
res = approx.add(rem.div(divisor));
return res;
}
}
} else if (divisor.eq(MIN_VALUE)) return this.unsigned ? UZERO : ZERO;
if (this.isNegative()) {
if (divisor.isNegative()) return this.neg().div(divisor.neg());
return this.neg().div(divisor).neg();
} else if (divisor.isNegative()) return this.div(divisor.neg()).neg();
res = ZERO;
} else {
// The algorithm below has not been made for unsigned longs. It's therefore
// required to take special care of the MSB prior to running it.
if (!divisor.unsigned) divisor = divisor.toUnsigned();
if (divisor.gt(this)) return UZERO;
if (divisor.gt(this.shru(1))) // 15 >>> 1 = 7 ; with divisor = 8 ; true
return UONE;
res = UZERO;
} // Repeat the following until the remainder is less than other: find a
// floating-point that approximates remainder / other *from below*, add this
// into the result, and subtract it from the remainder. It is critical that
// the approximate value is less than or equal to the real value so that the
// remainder never becomes negative.
rem = this;
while (rem.gte(divisor)) {
// Approximate the result of division. This may be a little greater or
// smaller than the actual value.
approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber())); // We will tweak the approximate result by changing it in the 48-th digit or
// the smallest non-fractional digit, whichever is larger.
var log2 = Math.ceil(Math.log(approx) / Math.LN2),
delta = log2 <= 48 ? 1 : pow_dbl(2, log2 - 48),
// Decrease the approximation until it is smaller than the remainder. Note
// that if it is too large, the product overflows and is negative.
approxRes = fromNumber(approx),
approxRem = approxRes.mul(divisor);
while (approxRem.isNegative() || approxRem.gt(rem)) {
approx -= delta;
approxRes = fromNumber(approx, this.unsigned);
approxRem = approxRes.mul(divisor);
} // We know the answer can't be zero... and actually, zero would cause
// infinite recursion since we would make no progress.
if (approxRes.isZero()) approxRes = ONE;
res = res.add(approxRes);
rem = rem.sub(approxRem);
}
return res;
};
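/**
 * Arithmetic sketch (arbitrary values): Long keeps full 64-bit precision where
 * plain JavaScript numbers would lose bits above 2^53.
 *
 * ```js
 * var big = Long.MAX_VALUE;    // 9223372036854775807
 * big.div(2).toString();       // '4611686018427387903'
 * big.mod(10).toString();      // '7'
 * Long.fromString('123456789012345678').mul(2).toString();
 * // '246913578024691356'
 * ```
 */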
/**
* Returns this Long divided by the specified. This is an alias of {@link Long#divide}.
* @function
* @param {!Long|number|string} divisor Divisor
* @returns {!Long} Quotient
*/
LongPrototype.div = LongPrototype.divide;
/**
* Returns this Long modulo the specified.
* @param {!Long|number|string} divisor Divisor
* @returns {!Long} Remainder
*/
LongPrototype.modulo = function modulo(divisor) {
if (!isLong(divisor)) divisor = fromValue(divisor); // use wasm support if present
if (wasm) {
var low = (this.unsigned ? wasm.rem_u : wasm.rem_s)(this.low, this.high, divisor.low, divisor.high);
return fromBits(low, wasm.get_high(), this.unsigned);
}
return this.sub(this.div(divisor).mul(divisor));
};
/**
* Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
* @function
* @param {!Long|number|string} divisor Divisor
* @returns {!Long} Remainder
*/
LongPrototype.mod = LongPrototype.modulo;
/**
* Returns this Long modulo the specified. This is an alias of {@link Long#modulo}.
* @function
* @param {!Long|number|string} divisor Divisor
* @returns {!Long} Remainder
*/
LongPrototype.rem = LongPrototype.modulo;
/**
* Returns the bitwise NOT of this Long.
* @returns {!Long}
*/
LongPrototype.not = function not() {
return fromBits(~this.low, ~this.high, this.unsigned);
};
/**
* Returns the bitwise AND of this Long and the specified.
* @param {!Long|number|string} other Other Long
* @returns {!Long}
*/
LongPrototype.and = function and(other) {
if (!isLong(other)) other = fromValue(other);
return fromBits(this.low & other.low, this.high & other.high, this.unsigned);
};
/**
* Returns the bitwise OR of this Long and the specified.
* @param {!Long|number|string} other Other Long
* @returns {!Long}
*/
LongPrototype.or = function or(other) {
if (!isLong(other)) other = fromValue(other);
return fromBits(this.low | other.low, this.high | other.high, this.unsigned);
};
/**
* Returns the bitwise XOR of this Long and the given one.
* @param {!Long|number|string} other Other Long
* @returns {!Long}
*/
LongPrototype.xor = function xor(other) {
if (!isLong(other)) other = fromValue(other);
return fromBits(this.low ^ other.low, this.high ^ other.high, this.unsigned);
};
/**
* Returns this Long with bits shifted to the left by the given amount.
* @param {number|!Long} numBits Number of bits
* @returns {!Long} Shifted Long
*/
LongPrototype.shiftLeft = function shiftLeft(numBits) {
if (isLong(numBits)) numBits = numBits.toInt();
  numBits &= 63;
  if (numBits === 0) return this;
  if (numBits < 32) return fromBits(this.low << numBits, this.high << numBits | this.low >>> 32 - numBits, this.unsigned);
  return fromBits(0, this.low << numBits - 32, this.unsigned);
};
/**
* Returns this Long with bits shifted to the left by the given amount. This is an alias of {@link Long#shiftLeft}.
* @function
* @param {number|!Long} numBits Number of bits
* @returns {!Long} Shifted Long
*/
LongPrototype.shl = LongPrototype.shiftLeft;
/**
* Returns this Long with bits arithmetically shifted to the right by the given amount.
* @param {number|!Long} numBits Number of bits
* @returns {!Long} Shifted Long
*/
LongPrototype.shiftRight = function shiftRight(numBits) {
if (isLong(numBits)) numBits = numBits.toInt();
  numBits &= 63;
  if (numBits === 0) return this;
  if (numBits < 32) return fromBits(this.low >>> numBits | this.high << 32 - numBits, this.high >> numBits, this.unsigned);
  return fromBits(this.high >> numBits - 32, this.high >= 0 ? 0 : -1, this.unsigned);
};
/**
* Returns this Long with bits arithmetically shifted to the right by the given amount. This is an alias of {@link Long#shiftRight}.
* @function
* @param {number|!Long} numBits Number of bits
* @returns {!Long} Shifted Long
*/
LongPrototype.shr = LongPrototype.shiftRight;
/**
* Returns this Long with bits logically shifted to the right by the given amount.
* @param {number|!Long} numBits Number of bits
* @returns {!Long} Shifted Long
*/
LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) {
if (isLong(numBits)) numBits = numBits.toInt();
numBits &= 63;
  if (numBits === 0) {
    return this;
  } else {
var high = this.high;
if (numBits < 32) {
var low = this.low;
return fromBits(low >>> numBits | high << 32 - numBits, high >>> numBits, this.unsigned);
    } else if (numBits === 32) {
      return fromBits(high, 0, this.unsigned);
    } else {
      return fromBits(high >>> numBits - 32, 0, this.unsigned);
    }
}
};
/**
* Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}.
* @function
* @param {number|!Long} numBits Number of bits
* @returns {!Long} Shifted Long
*/
LongPrototype.shru = LongPrototype.shiftRightUnsigned;
/**
* Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}.
* @function
* @param {number|!Long} numBits Number of bits
* @returns {!Long} Shifted Long
*/
LongPrototype.shr_u = LongPrototype.shiftRightUnsigned;
/**
* Converts this Long to signed.
* @returns {!Long} Signed long
*/
LongPrototype.toSigned = function toSigned() {
if (!this.unsigned) return this;
return fromBits(this.low, this.high, false);
};
/**
* Converts this Long to unsigned.
* @returns {!Long} Unsigned long
*/
LongPrototype.toUnsigned = function toUnsigned() {
if (this.unsigned) return this;
return fromBits(this.low, this.high, true);
};
/**
* Converts this Long to its byte representation.
* @param {boolean=} le Whether little or big endian, defaults to big endian
* @returns {!Array.<number>} Byte representation
*/
LongPrototype.toBytes = function toBytes(le) {
return le ? this.toBytesLE() : this.toBytesBE();
};
/**
* Converts this Long to its little endian byte representation.
* @returns {!Array.<number>} Little endian byte representation
*/
LongPrototype.toBytesLE = function toBytesLE() {
var hi = this.high,
lo = this.low;
return [lo & 0xff, lo >>> 8 & 0xff, lo >>> 16 & 0xff, lo >>> 24, hi & 0xff, hi >>> 8 & 0xff, hi >>> 16 & 0xff, hi >>> 24];
};
/**
* Converts this Long to its big endian byte representation.
* @returns {!Array.<number>} Big endian byte representation
*/
LongPrototype.toBytesBE = function toBytesBE() {
var hi = this.high,
lo = this.low;
return [hi >>> 24, hi >>> 16 & 0xff, hi >>> 8 & 0xff, hi & 0xff, lo >>> 24, lo >>> 16 & 0xff, lo >>> 8 & 0xff, lo & 0xff];
};
/**
* Creates a Long from its byte representation.
* @param {!Array.<number>} bytes Byte representation
* @param {boolean=} unsigned Whether unsigned or not, defaults to signed
* @param {boolean=} le Whether little or big endian, defaults to big endian
* @returns {Long} The corresponding Long value
*/
Long.fromBytes = function fromBytes(bytes, unsigned, le) {
return le ? Long.fromBytesLE(bytes, unsigned) : Long.fromBytesBE(bytes, unsigned);
};
/**
* Creates a Long from its little endian byte representation.
* @param {!Array.<number>} bytes Little endian byte representation
* @param {boolean=} unsigned Whether unsigned or not, defaults to signed
* @returns {Long} The corresponding Long value
*/
Long.fromBytesLE = function fromBytesLE(bytes, unsigned) {
return new Long(bytes[0] | bytes[1] << 8 | bytes[2] << 16 | bytes[3] << 24, bytes[4] | bytes[5] << 8 | bytes[6] << 16 | bytes[7] << 24, unsigned);
};
/**
* Creates a Long from its big endian byte representation.
* @param {!Array.<number>} bytes Big endian byte representation
* @param {boolean=} unsigned Whether unsigned or not, defaults to signed
* @returns {Long} The corresponding Long value
*/
Long.fromBytesBE = function fromBytesBE(bytes, unsigned) {
return new Long(bytes[4] << 24 | bytes[5] << 16 | bytes[6] << 8 | bytes[7], bytes[0] << 24 | bytes[1] << 16 | bytes[2] << 8 | bytes[3], unsigned);
};
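/**
 * Byte round-trip sketch (arbitrary value): serialize to the 8-byte
 * little-endian form and reconstruct the same value.
 *
 * ```js
 * var v = Long.fromInt(0x112233);
 * v.toBytesLE();                                // [0x33, 0x22, 0x11, 0, 0, 0, 0, 0]
 * Long.fromBytesLE(v.toBytesLE()).toString(16); // '112233'
 * ```
 */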
var LongExports = {
__proto__: null,
'default': long_1,
__moduleExports: long_1
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Long$1 = // tslint:disable-next-line
long_1 || LongExports;
function hexToLong(hex) {
return Long$1.fromString(hex, true, 16);
} // Some primes between 2^63 and 2^64 for various uses.
// Hex 0xc3a5c85c97cb3127
var k0 = hexToLong('c3a5c85c97cb3127'); // Hex 0xb492b66fbe98f273
var k1 = hexToLong('b492b66fbe98f273'); // Hex 0x9ae16a3b2f90404f
var k2 = hexToLong('9ae16a3b2f90404f');
function shiftMix(val) {
return val.xor(val.shru(47));
}
function fetch$1(s, offset, numBytes) {
var bytes = s.slice(offset, offset + numBytes);
return Long$1.fromBytes(Array.from(bytes), true, true);
}
function fetch64(s, offset) {
return fetch$1(s, offset, 8);
}
function fetch32(s, offset) {
return fetch$1(s, offset, 4);
}
function rotate64(val, shift) {
// Avoid shifting by 64: doing so yields an undefined result.
return shift === 0 ? val : val.shru(shift).or(val.shl(64 - shift));
}
function hashLen16(u, v, mul) {
if (mul === void 0) {
mul = hexToLong('9ddfea08eb382d69');
}
// Murmur-inspired hashing.
var a = u.xor(v).mul(mul);
a = a.xor(a.shru(47));
var b = v.xor(a).mul(mul);
b = b.xor(b.shru(47));
b = b.mul(mul);
return b;
} // Return a 16-byte hash for 48 bytes. Quick and dirty.
// Callers do best to use "random-looking" values for a and b.
function weakHashLen32WithSeeds(w, x, y, z, a, b) {
a = a.add(w);
b = rotate64(b.add(a).add(z), 21);
var c = a;
a = a.add(x);
a = a.add(y);
b = b.add(rotate64(a, 44));
return [a.add(z), b.add(c)];
}
function weakHashLen32WithSeedsStr(s, offset, a, b) {
return weakHashLen32WithSeeds(fetch64(s, offset), fetch64(s, offset + 8), fetch64(s, offset + 16), fetch64(s, offset + 24), a, b);
}
function hashLen0to16(s, len) {
if (len === void 0) {
len = s.length;
}
if (len >= 8) {
var mul = k2.add(len * 2);
var a = fetch64(s, 0).add(k2);
var b = fetch64(s, len - 8);
var c = rotate64(b, 37).mul(mul).add(a);
var d = rotate64(a, 25).add(b).mul(mul);
return hashLen16(c, d, mul);
}
if (len >= 4) {
var _mul = k2.add(len * 2);
var _a = fetch32(s, 0);
return hashLen16(_a.shl(3).add(len), fetch32(s, len - 4), _mul);
}
if (len > 0) {
var _a2 = s[0];
var _b = s[len >> 1];
var _c = s[len - 1];
var y = _a2 + (_b << 8);
var z = len + (_c << 2);
return shiftMix(k2.mul(y).xor(k0.mul(z))).mul(k2);
}
return k2;
}
function hashLen17to32(s, len) {
if (len === void 0) {
len = s.length;
}
var mul = k2.add(len * 2);
var a = fetch64(s, 0).mul(k1);
var b = fetch64(s, 8);
var c = fetch64(s, len - 8).mul(mul);
var d = fetch64(s, len - 16).mul(k2);
return hashLen16(rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d), a.add(rotate64(b.add(k2), 18)).add(c), mul);
}
function hashLen33to64(s, len) {
if (len === void 0) {
len = s.length;
}
var mul = k2.add(len * 2);
var a = fetch64(s, 0).mul(k2);
var b = fetch64(s, 8);
var c = fetch64(s, len - 8).mul(mul);
var d = fetch64(s, len - 16).mul(k2);
var y = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
var z = hashLen16(y, a.add(rotate64(b.add(k2), 18)).add(c), mul);
var e = fetch64(s, 16).mul(mul);
var f = fetch64(s, 24);
var g = y.add(fetch64(s, len - 32)).mul(mul);
var h = z.add(fetch64(s, len - 24)).mul(mul);
return hashLen16(rotate64(e.add(f), 43).add(rotate64(g, 30)).add(h), e.add(rotate64(f.add(a), 18)).add(g), mul);
}
function fingerPrint64(s, len) {
if (len === void 0) {
len = s.length;
}
var seed = Long$1.fromNumber(81, true);
if (len <= 32) {
if (len <= 16) {
return hashLen0to16(s, len);
} else {
return hashLen17to32(s, len);
}
} else if (len <= 64) {
return hashLen33to64(s, len);
} // For strings over 64 bytes we loop. Internal state consists of
// 56 bytes: v, w, x, y, and z.
var x = seed;
var y = seed.mul(k1).add(113);
var z = shiftMix(y.mul(k2).add(113)).mul(k2);
var v = [Long$1.UZERO, Long$1.UZERO];
var w = [Long$1.UZERO, Long$1.UZERO];
x = x.mul(k2).add(fetch64(s, 0));
var offset = 0; // Set end so that after the loop we have 1 to 64 bytes left to process.
var end = (len - 1 >> 6) * 64;
var last64 = end + (len - 1 & 63) - 63;
do {
x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(k1);
y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(k1);
x = x.xor(w[1]);
y = y.add(v[0]).add(fetch64(s, offset + 40));
z = rotate64(z.add(w[0]), 33).mul(k1);
v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(k1), x.add(w[0]));
w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
var _ref = [x, z];
z = _ref[0];
x = _ref[1];
offset += 64;
} while (offset !== end);
var mul = k1.add(z.and(0xff).shl(1)); // Point to the last 64 bytes of input.
offset = last64;
w[0] = w[0].add(len - 1 & 63);
v[0] = v[0].add(w[0]);
w[0] = w[0].add(v[0]);
x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(mul);
y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(mul);
x = x.xor(w[1].mul(9));
y = y.add(v[0].mul(9).add(fetch64(s, offset + 40)));
z = rotate64(z.add(w[0]), 33).mul(mul);
v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(mul), x.add(w[0]));
w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
var _ref2 = [x, z];
z = _ref2[0];
x = _ref2[1];
return hashLen16(hashLen16(v[0], w[0], mul).add(shiftMix(y).mul(k0)).add(z), hashLen16(v[1], w[1], mul).add(x), mul);
}
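/**
 * Fingerprint sketch: `fingerPrint64` maps a byte sequence to an unsigned
 * 64-bit Long (a FarmHash-style fingerprint). The concrete value depends on
 * the constants above, so no output is asserted here; the bytes are simply the
 * UTF-8 encoding of 'abc'.
 *
 * ```js
 * var fp = fingerPrint64(new Uint8Array([97, 98, 99]));
 * fp.toString(); // stable decimal string for the same input bytes
 * ```
 */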
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Create typed array for scalar value. Used for storing in `DataStorage`.
*/
function createScalarValue(value, dtype) {
if (dtype === 'string') {
return encodeString(value);
}
return toTypedArray([value], dtype);
}
function noConversionNeeded(a, dtype) {
return a instanceof Float32Array && dtype === 'float32' || a instanceof Int32Array && dtype === 'int32' || a instanceof Uint8Array && dtype === 'bool';
}
function toTypedArray(a, dtype) {
if (dtype === 'string') {
throw new Error('Cannot convert a string[] to a TypedArray');
}
if (Array.isArray(a)) {
a = flatten(a);
}
if (env().getBool('DEBUG')) {
checkConversionForErrors(a, dtype);
}
if (noConversionNeeded(a, dtype)) {
return a;
}
if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
return new Float32Array(a);
} else if (dtype === 'int32') {
return new Int32Array(a);
} else if (dtype === 'bool') {
var bool = new Uint8Array(a.length);
for (var i = 0; i < bool.length; ++i) {
if (Math.round(a[i]) !== 0) {
bool[i] = 1;
}
}
return bool;
} else {
throw new Error("Unknown data type " + dtype);
}
}
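/**
 * Conversion sketch for the helpers above (arbitrary values):
 *
 * ```js
 * toTypedArray([[1, 2], [3, 4]], 'int32'); // Int32Array [1, 2, 3, 4]
 * toTypedArray([0, 0.4, 2], 'bool');       // Uint8Array [0, 0, 1]
 * createScalarValue(3.5, 'float32');       // Float32Array [3.5]
 * ```
 */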
/**
* Returns the current high-resolution time in milliseconds relative to an
* arbitrary time in the past. It works across different platforms (node.js,
* browsers).
*
* ```js
* console.log(tf.util.now());
* ```
*
* @doc {heading: 'Util', namespace: 'util'}
*/
function now() {
return env().platform.now();
}
/**
* Returns a platform-specific implementation of
* [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
*
* If `fetch` is defined on the global object (`window`, `process`, etc.),
* `tf.util.fetch` returns that function.
*
* If not, `tf.util.fetch` returns a platform-specific solution.
*
* ```js
* const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs');
* // handle response
* ```
*
* @doc {heading: 'Util'}
*/
function fetch$2(path, requestInits) {
return env().platform.fetch(path, requestInits);
}
/**
* Encodes the provided string into bytes using the provided encoding scheme.
*
* @param s The string to encode.
* @param encoding The encoding scheme. Defaults to utf-8.
*
* @doc {heading: 'Util'}
*/
function encodeString(s, encoding) {
if (encoding === void 0) {
encoding = 'utf-8';
}
encoding = encoding || 'utf-8';
return env().platform.encode(s, encoding);
}
/**
* Decodes the provided bytes into a string using the provided encoding scheme.
* @param bytes The bytes to decode.
*
* @param encoding The encoding scheme. Defaults to utf-8.
*
* @doc {heading: 'Util'}
*/
function decodeString(bytes, encoding) {
if (encoding === void 0) {
encoding = 'utf-8';
}
encoding = encoding || 'utf-8';
return env().platform.decode(bytes, encoding);
}
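/**
 * Round-trip sketch for the string helpers above: utf-8 encode then decode
 * returns the original string.
 *
 * ```js
 * const bytes = tf.util.encodeString('hello'); // utf-8 bytes
 * tf.util.decodeString(bytes);                 // 'hello'
 * ```
 */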
var util = {
__proto__: null,
createScalarValue: createScalarValue,
toTypedArray: toTypedArray,
now: now,
fetch: fetch$2,
encodeString: encodeString,
decodeString: decodeString,
shuffle: shuffle,
shuffleCombo: shuffleCombo,
clamp: clamp,
nearestLargerEven: nearestLargerEven,
swap: swap,
sum: sum,
randUniform: randUniform,
distSquared: distSquared,
assert: assert,
assertShapesMatch: assertShapesMatch,
assertNonNull: assertNonNull,
flatten: flatten,
sizeFromShape: sizeFromShape,
isScalarShape: isScalarShape,
arraysEqual: arraysEqual,
isInt: isInt,
tanh: tanh,
sizeToSquarishShape: sizeToSquarishShape,
createShuffledIndices: createShuffledIndices,
rightPad: rightPad,
repeatedTry: repeatedTry,
inferFromImplicitShape: inferFromImplicitShape,
parseAxisParam: parseAxisParam,
squeezeShape: squeezeShape,
getTypedArrayFromDType: getTypedArrayFromDType,
getArrayFromDType: getArrayFromDType,
checkConversionForErrors: checkConversionForErrors,
isValidDtype: isValidDtype,
hasEncodingLoss: hasEncodingLoss,
isTypedArray: isTypedArray$1,
bytesPerElement: bytesPerElement,
bytesFromStringArray: bytesFromStringArray,
isString: isString,
isBoolean: isBoolean,
isNumber: isNumber,
inferDtype: inferDtype,
isFunction: isFunction,
nearestDivisor: nearestDivisor,
computeStrides: computeStrides,
toNestedArray: toNestedArray,
makeOnesTypedArray: makeOnesTypedArray,
makeZerosTypedArray: makeZerosTypedArray,
makeZerosNestedTypedArray: makeZerosNestedTypedArray,
assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions,
locToIndex: locToIndex,
indexToLoc: indexToLoc,
isPromise: isPromise,
hexToLong: hexToLong,
fingerPrint64: fingerPrint64
};
var Profiler = /*#__PURE__*/function () {
function Profiler(backendTimer, logger) {
this.backendTimer = backendTimer;
this.logger = logger;
if (logger == null) {
this.logger = new Logger();
}
}
var _proto = Profiler.prototype;
_proto.profileKernel = function profileKernel(kernelName, inputs, f) {
var outputs;
var holdResultWrapperFn = function holdResultWrapperFn() {
outputs = f();
};
var timer;
var start = now();
if (this.backendTimer.timerAvailable()) {
timer = this.backendTimer.time(holdResultWrapperFn);
} else {
holdResultWrapperFn();
for (var _iterator = _createForOfIteratorHelperLoose(outputs), _step; !(_step = _iterator()).done;) {
var output = _step.value;
output.dataSync();
}
timer = Promise.resolve({
kernelMs: now() - start
});
}
if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) {
var _loop = function _loop(i) {
var output = outputs[i]; // Dangling promise here because we don't want to propagate up
// asynchronicity.
output.data().then(function (tensorVals) {
checkComputationForErrors(tensorVals, output.dtype, kernelName);
});
};
for (var i = 0; i < outputs.length; i++) {
_loop(i);
}
}
var kernelProfile = {
kernelName: kernelName,
outputs: outputs,
inputs: inputs,
timeMs: timer.then(function (timing) {
return timing.kernelMs;
}),
extraInfo: timer.then(function (timing) {
return timing.getExtraProfileInfo != null ? timing.getExtraProfileInfo() : '';
})
};
return kernelProfile;
};
_proto.logKernelProfile = function logKernelProfile(kernelProfile) {
var _this = this;
var kernelName = kernelProfile.kernelName,
outputs = kernelProfile.outputs,
timeMs = kernelProfile.timeMs,
inputs = kernelProfile.inputs,
extraInfo = kernelProfile.extraInfo;
outputs.forEach(function (result) {
Promise.all([result.data(), timeMs, extraInfo]).then(function (valueContainer) {
_this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]);
});
});
};
return Profiler;
}();
function checkComputationForErrors(vals, dtype, kernelName) {
if (dtype !== 'float32') {
// Only floating point computations will generate NaN values
return false;
}
for (var i = 0; i < vals.length; i++) {
var num = vals[i];
if (isNaN(num) || !isFinite(num)) {
      // Warn rather than throw so execution continues; the behavior stays testable.
console.warn("Found " + num + " in the result of '" + kernelName + "'");
return true;
}
}
return false;
}
var Logger = /*#__PURE__*/function () {
function Logger() {}
var _proto2 = Logger.prototype;
_proto2.logKernelProfile = function logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) {
var time = typeof timeMs === 'number' ? rightPad(timeMs + "ms", 9) : timeMs['error'];
var paddedName = rightPad(name, 25);
var rank = result.rank;
var size = result.size;
var shape = rightPad(result.shape.toString(), 14);
var inputShapesDescription = '';
for (var _name in inputs) {
var input = inputs[_name];
if (input != null) {
        // The input might be a non-tensor (e.g. HTMLImageElement), in which case
        // we fall back to the output shape as the input shape.
var inputShape = input.shape || result.shape;
var inputRank = inputShape.length;
inputShapesDescription += _name + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : '') + " ";
}
}
console.log("%c" + paddedName + "\t%c" + time + "\t%c" + rank + "D " + shape + "\t%c" + size + "\t%c" + inputShapesDescription + "\t%c" + extraInfo, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
};
return Logger;
}();
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes a list of TapeNodes that connect x to y, filtering everything else
* out and preserving the order of the original tape elements.
*
* @param tape The tape elements to filter.
* @param xs The input Tensors.
* @param y The output Tensor.
*/
function getFilteredNodesXToY(tape, xs, y) {
// Forward pass to compute all the nodes and Tensors that are transitively a
// function of x.
var tensorsFromX = {};
var nodesFromX = {};
for (var i = 0; i < xs.length; i++) {
tensorsFromX[xs[i].id] = true;
}
for (var _i = 0; _i < tape.length; _i++) {
var node = tape[_i];
var nodeInputs = node.inputs;
for (var inputName in nodeInputs) {
var input = nodeInputs[inputName];
var anyInputFromX = false;
for (var j = 0; j < xs.length; j++) {
if (tensorsFromX[input.id]) {
node.outputs.forEach(function (output) {
return tensorsFromX[output.id] = true;
});
anyInputFromX = true;
nodesFromX[node.id] = true;
break;
}
}
if (anyInputFromX) {
break;
}
}
} // Backward pass to find all of the nodes and Tensors that lead to y.
var tensorsLeadToY = {};
tensorsLeadToY[y.id] = true;
var nodesToY = {};
for (var _i2 = tape.length - 1; _i2 >= 0; _i2--) {
var _node = tape[_i2];
var _nodeInputs = _node.inputs; // If any of the outputs lead to y, mark all of the inputs as leading to y.
for (var _j = 0; _j < _node.outputs.length; _j++) {
if (tensorsLeadToY[_node.outputs[_j].id]) {
for (var _inputName in _nodeInputs) {
tensorsLeadToY[_nodeInputs[_inputName].id] = true;
nodesToY[_node.id] = true;
}
break;
}
}
} // Return the paths that come from x and lead to y.
var filteredTape = [];
for (var _i3 = 0; _i3 < tape.length; _i3++) {
var _node2 = tape[_i3];
if (nodesFromX[_node2.id] && nodesToY[_node2.id]) {
// Prune the inputs from the node that aren't a function of x.
var prunedInputs = {};
for (var _inputName2 in _node2.inputs) {
var nodeInput = _node2.inputs[_inputName2];
if (tensorsFromX[nodeInput.id]) {
prunedInputs[_inputName2] = nodeInput;
}
} // Copy the node and overwrite inputsAndArgs to the pruned version.
var prunedNode = Object.assign({}, _node2);
prunedNode.inputs = prunedInputs;
prunedNode.outputs = _node2.outputs;
filteredTape.push(prunedNode);
}
}
return filteredTape;
}
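/**
 * A minimal illustration with a hypothetical tape: only nodes that are
 * reachable from an input x and that lead to the output y survive filtering.
 * Only the `id`, `inputs` and `outputs` fields are consulted here.
 *
 * ```js
 * const x = {id: 1}, y = {id: 2}, unrelated = {id: 3};
 * const tape = [
 *   {id: 10, kernelName: 'Square', inputs: {x: x}, outputs: [y]},
 *   {id: 11, kernelName: 'Abs', inputs: {a: unrelated}, outputs: [{id: 4}]}
 * ];
 * getFilteredNodesXToY(tape, [x], y); // => [the 'Square' node only]
 * ```
 */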
/**
* Backpropagate gradients through the filtered TapeNodes.
*
* @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map
* is mutated by this method.
* @param filteredTape The filtered TapeNodes to backprop through.
*/
function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) {
var _loop = function _loop(i) {
var node = filteredTape[i];
var dys = [];
node.outputs.forEach(function (o) {
var gradTensor = tensorAccumulatedGradientMap[o.id];
if (gradTensor != null) {
dys.push(gradTensor);
} else {
// This particular output is not in the back-propagation subgraph, so it
// does not affect the final output, thus we put null for its dy.
dys.push(null);
}
});
if (node.gradient == null) {
throw new Error("Cannot compute gradient: gradient function not found " + ("for " + node.kernelName + "."));
} // Backprop dy through this node and accumulate gradients over the inputs.
var inputGradients = node.gradient(dys);
var _loop2 = function _loop2(inputName) {
if (!(inputName in inputGradients)) {
throw new Error("Cannot backprop through input " + inputName + ". " + ("Available gradients found: " + Object.keys(inputGradients) + "."));
} // Call the gradient function.
var dx = tidy(function () {
return inputGradients[inputName]();
});
if (dx.dtype !== 'float32') {
throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + (inputName + " must have 'float32' dtype, but has '" + dx.dtype + "'"));
}
var x = node.inputs[inputName];
if (!arraysEqual(dx.shape, x.shape)) {
throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") + ("the shape of the input '" + x.shape + "'"));
}
if (tensorAccumulatedGradientMap[x.id] == null) {
tensorAccumulatedGradientMap[x.id] = dx;
} else {
var curGradient = tensorAccumulatedGradientMap[x.id];
tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);
curGradient.dispose();
}
};
for (var inputName in node.inputs) {
_loop2(inputName);
}
};
// Walk the tape backward and keep a map of Tensor to its gradient.
for (var i = filteredTape.length - 1; i >= 0; i--) {
_loop(i);
}
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Maximum number of values to display before summarizing with an ellipsis.
var FORMAT_LIMIT_NUM_VALS = 20;
// Number of first and last values to show when displaying a, b,...,y, z.
var FORMAT_NUM_FIRST_LAST_VALS = 3;
// Number of significant digits to show.
var FORMAT_NUM_SIG_DIGITS = 7;
function tensorToString(vals, shape, dtype, verbose) {
var strides = computeStrides(shape);
var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
var rank = shape.length;
var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
var lines = ['Tensor'];
if (verbose) {
lines.push(" dtype: " + dtype);
lines.push(" rank: " + rank);
lines.push(" shape: [" + shape + "]");
lines.push(" values:");
}
lines.push(valsLines.map(function (l) {
return ' ' + l;
}).join('\n'));
return lines.join('\n');
}
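// A short sketch of the formatting above as seen through the public API,
// assuming the bundle is loaded as the global `tf` (print and toString call
// tensorToString under the hood):
//
//   const t = tf.tensor2d([[1, 2], [3, 4]]);
//   t.print(true);               // verbose: also lists dtype, rank and shape
//   console.log(t.toString());   // same formatted output, returned as a string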
function computeMaxSizePerColumn(vals, shape, dtype, strides) {
var n = sizeFromShape(shape);
var numCols = strides[strides.length - 1];
var padPerCol = new Array(numCols).fill(0);
var rank = shape.length;
var valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
if (rank > 1) {
for (var row = 0; row < n / numCols; row++) {
var offset = row * numCols;
for (var j = 0; j < numCols; j++) {
padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
}
}
}
return padPerCol;
}
function valToString(val, pad, dtype) {
var valStr;
if (Array.isArray(val)) {
valStr = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)) + " + " + (parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)) + "j");
} else if (isString(val)) {
valStr = "'" + val + "'";
} else if (dtype === 'bool') {
valStr = boolNumToString(val);
} else {
valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
}
return rightPad(valStr, pad);
}
function boolNumToString(v) {
return v === 0 ? 'false' : 'true';
}
function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) {
if (isLast === void 0) {
isLast = true;
}
var storagePerElement = dtype === 'complex64' ? 2 : 1;
var size = shape[0];
var rank = shape.length;
if (rank === 0) {
if (dtype === 'complex64') {
var complexTuple = createComplexTuples(vals);
return [valToString(complexTuple[0], 0, dtype)];
}
if (dtype === 'bool') {
return [boolNumToString(vals[0])];
}
return [vals[0].toString()];
}
if (rank === 1) {
if (size > FORMAT_LIMIT_NUM_VALS) {
var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
var firstVals = Array.from(vals.slice(0, firstValsSize));
var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
if (dtype === 'complex64') {
firstVals = createComplexTuples(firstVals);
lastVals = createComplexTuples(lastVals);
}
return ['[' + firstVals.map(function (x, i) {
return valToString(x, padPerCol[i], dtype);
}).join(', ') + ', ..., ' + lastVals.map(function (x, i) {
return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype);
}).join(', ') + ']'];
}
var displayVals = dtype === 'complex64' ? createComplexTuples(vals) : Array.from(vals);
return ['[' + displayVals.map(function (x, i) {
return valToString(x, padPerCol[i], dtype);
}).join(', ') + ']'];
} // The array is rank 2 or more.
var subshape = shape.slice(1);
var substrides = strides.slice(1);
var stride = strides[0] * storagePerElement;
var lines = [];
if (size > FORMAT_LIMIT_NUM_VALS) {
for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
var start = i * stride;
var end = start + stride;
lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false
/* isLast */
));
}
lines.push('...');
for (var _i = size - FORMAT_NUM_FIRST_LAST_VALS; _i < size; _i++) {
var _start = _i * stride;
var _end = _start + stride;
lines.push.apply(lines, subTensorToString(vals.slice(_start, _end), subshape, dtype, substrides, padPerCol, _i === size - 1
/* isLast */
));
}
} else {
for (var _i2 = 0; _i2 < size; _i2++) {
var _start2 = _i2 * stride;
var _end2 = _start2 + stride;
lines.push.apply(lines, subTensorToString(vals.slice(_start2, _end2), subshape, dtype, substrides, padPerCol, _i2 === size - 1
/* isLast */
));
}
}
var sep = rank === 2 ? ',' : '';
lines[0] = '[' + lines[0] + sep;
for (var _i3 = 1; _i3 < lines.length - 1; _i3++) {
lines[_i3] = ' ' + lines[_i3] + sep;
}
var newLineSep = ',\n';
for (var _i4 = 2; _i4 < rank; _i4++) {
newLineSep += '\n';
}
lines[lines.length - 1] = ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
return lines;
}
function createComplexTuples(vals) {
var complexTuples = [];
for (var i = 0; i < vals.length; i += 2) {
complexTuples.push([vals[i], vals[i + 1]]);
}
return complexTuples;
}
/**
* A mutable object, similar to `tf.Tensor`, that allows users to set values
* at locations before converting to an immutable `tf.Tensor`.
*
* See `tf.buffer` for creating a tensor buffer.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
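// A minimal usage sketch, assuming the bundle is loaded as the global `tf`
// and using the public `tf.buffer` factory mentioned above:
//
//   const buf = tf.buffer([2, 2], 'float32');
//   buf.set(3, 0, 0);             // write 3 at location [0, 0]
//   console.log(buf.get(0, 0));   // 3
//   buf.toTensor().print();       // [[3, 0], [0, 0]]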
var TensorBuffer = /*#__PURE__*/function () {
function TensorBuffer(shape, dtype, values) {
var _this = this;
this.dtype = dtype;
this.shape = shape.slice();
this.size = sizeFromShape(shape);
if (values != null) {
var n = values.length;
assert(n === this.size, function () {
return "Length of values '" + n + "' does not match the size " + ("inferred by the shape '" + _this.size + "'.");
});
}
if (dtype === 'complex64') {
throw new Error("complex64 dtype TensorBuffers are not supported. Please create " + "a TensorBuffer for the real and imaginary parts separately and " + "call tf.complex(real, imag).");
}
this.values = values || getArrayFromDType(dtype, this.size);
this.strides = computeStrides(shape);
}
/**
* Sets a value in the buffer at a given location.
*
* @param value The value to set.
* @param locs The location indices.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
var _proto = TensorBuffer.prototype;
_proto.set = function set(value) {
var _this2 = this;
for (var _len = arguments.length, locs = new Array(_len > 1 ? _len - 1 : 0), _key = 1; _key < _len; _key++) {
locs[_key - 1] = arguments[_key];
}
if (locs.length === 0) {
locs = [0];
}
assert(locs.length === this.rank, function () {
return "The number of provided coordinates (" + locs.length + ") must " + ("match the rank (" + _this2.rank + ")");
});
var index = this.locToIndex(locs);
this.values[index] = value;
}
/**
* Returns the value in the buffer at the provided location.
*
* @param locs The location indices.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
;
_proto.get = function get() {
for (var _len2 = arguments.length, locs = new Array(_len2), _key2 = 0; _key2 < _len2; _key2++) {
locs[_key2] = arguments[_key2];
}
if (locs.length === 0) {
locs = [0];
}
var i = 0;
for (var _i = 0, _locs = locs; _i < _locs.length; _i++) {
var loc = _locs[_i];
if (loc < 0 || loc >= this.shape[i]) {
var msg = "Requested out of range element at " + locs + ". " + (" Buffer shape=" + this.shape);
throw new Error(msg);
}
i++;
}
var index = locs[locs.length - 1];
for (var _i2 = 0; _i2 < locs.length - 1; ++_i2) {
index += this.strides[_i2] * locs[_i2];
}
return this.values[index];
};
_proto.locToIndex = function locToIndex(locs) {
if (this.rank === 0) {
return 0;
} else if (this.rank === 1) {
return locs[0];
}
var index = locs[locs.length - 1];
for (var i = 0; i < locs.length - 1; ++i) {
index += this.strides[i] * locs[i];
}
return index;
};
_proto.indexToLoc = function indexToLoc(index) {
if (this.rank === 0) {
return [];
} else if (this.rank === 1) {
return [index];
}
var locs = new Array(this.shape.length);
for (var i = 0; i < locs.length - 1; ++i) {
locs[i] = Math.floor(index / this.strides[i]);
index -= locs[i] * this.strides[i];
}
locs[locs.length - 1] = index;
return locs;
};
/**
* Creates an immutable `tf.Tensor` object from the buffer.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
_proto.toTensor = function toTensor() {
return trackerFn().makeTensor(this.values, this.shape, this.dtype);
};
_createClass(TensorBuffer, [{
key: "rank",
get: function get() {
return this.shape.length;
}
}]);
return TensorBuffer;
}();
// For tracking tensor creation and disposal.
var trackerFn = null;
// Used by chaining methods to call into ops.
var opHandler = null;
// Used to warn about deprecated methods.
var deprecationWarningFn = null;
// This is here so that we can use this method on dev branches and keep the
// functionality at master.
// tslint:disable-next-line:no-unused-expression
[deprecationWarningFn];
/**
* An external consumer can register itself as the tensor tracker. This way
* the Tensor class can notify the tracker for every tensor created and
* disposed.
*/
function setTensorTracker(fn) {
trackerFn = fn;
}
/**
* An external consumer can register itself as the op handler. This way the
* Tensor class can have chaining methods that call into ops via the op
* handler.
*/
function setOpHandler(handler) {
opHandler = handler;
}
/**
* Sets the deprecation warning function to be used by this file. This way the
* Tensor class can be a leaf but still use the environment.
*/
function setDeprecationWarningFn(fn) {
deprecationWarningFn = fn;
}
/**
* A `tf.Tensor` object represents an immutable, multidimensional array of
* numbers that has a shape and a data type.
*
* For performance reasons, functions that create tensors do not necessarily
* perform a copy of the data passed to them (e.g. if the data is passed as a
* `Float32Array`), and changes to the data will change the tensor. This is not
* a feature and is not supported. To avoid this behavior, use the tensor before
* changing the input data or create a copy with `copy = tf.add(yourTensor, 0)`.
*
* See `tf.tensor` for details on how to create a `tf.Tensor`.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
var Tensor = /*#__PURE__*/function () {
function Tensor(shape, dtype, dataId, id) {
/** Whether this tensor has been globally kept. */
this.kept = false;
this.isDisposedInternal = false;
this.shape = shape.slice();
this.dtype = dtype || 'float32';
this.size = sizeFromShape(shape);
this.strides = computeStrides(shape);
this.dataId = dataId;
this.id = id;
this.rankType = this.rank < 5 ? this.rank.toString() : 'higher';
}
var _proto2 = Tensor.prototype;
/**
* Returns a promise of `tf.TensorBuffer` that holds the underlying data.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
_proto2.buffer =
/*#__PURE__*/
function () {
var _buffer = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var vals;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.next = 2;
return this.data();
case 2:
vals = _context.sent;
return _context.abrupt("return", opHandler.buffer(this.shape, this.dtype, vals));
case 4:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function buffer() {
return _buffer.apply(this, arguments);
}
return buffer;
}()
/**
* Returns a `tf.TensorBuffer` that holds the underlying data.
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
;
_proto2.bufferSync = function bufferSync() {
return opHandler.buffer(this.shape, this.dtype, this.dataSync());
}
/**
* Returns the tensor data as a nested array. The transfer of data is done
* asynchronously.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
;
_proto2.array =
/*#__PURE__*/
function () {
var _array = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var vals;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.data();
case 2:
vals = _context2.sent;
return _context2.abrupt("return", toNestedArray(this.shape, vals, this.dtype === 'complex64'));
case 4:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function array() {
return _array.apply(this, arguments);
}
return array;
}()
/**
* Returns the tensor data as a nested array. The transfer of data is done
* synchronously.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
;
_proto2.arraySync = function arraySync() {
return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64');
}
/**
* Asynchronously downloads the values from the `tf.Tensor`. Returns a
* promise of `TypedArray` that resolves when the computation has finished.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
;
_proto2.data =
/*#__PURE__*/
function () {
var _data = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3() {
var data, bytes;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
this.throwIfDisposed();
data = trackerFn().read(this.dataId);
if (!(this.dtype === 'string')) {
_context3.next = 13;
break;
}
_context3.next = 5;
return data;
case 5:
bytes = _context3.sent;
_context3.prev = 6;
return _context3.abrupt("return", bytes.map(function (b) {
return decodeString(b);
}));
case 10:
_context3.prev = 10;
_context3.t0 = _context3["catch"](6);
throw new Error('Failed to decode the string bytes into utf-8. ' + 'To get the original bytes, call tensor.bytes().');
case 13:
return _context3.abrupt("return", data);
case 14:
case "end":
return _context3.stop();
}
}
}, _callee3, this, [[6, 10]]);
}));
function data() {
return _data.apply(this, arguments);
}
return data;
}()
/**
* Synchronously downloads the values from the `tf.Tensor`. This blocks the
* UI thread until the values are ready, which can cause performance issues.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
;
_proto2.dataSync = function dataSync() {
this.throwIfDisposed();
var data = trackerFn().readSync(this.dataId);
if (this.dtype === 'string') {
try {
return data.map(function (b) {
return decodeString(b);
});
} catch (_a) {
throw new Error('Failed to decode the string bytes into utf-8. ' + 'To get the original bytes, call tensor.bytes().');
}
}
return data;
}
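// Sketch of the two download paths above, assuming the bundle is loaded as
// the global `tf`:
//
//   const t = tf.tensor1d([1, 2, 3]);
//   t.data().then(vals => console.log(vals));   // async; resolves to a TypedArray
//   console.log(t.dataSync());                  // sync; blocks until values are ready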
/** Returns the underlying bytes of the tensor's data. */
;
_proto2.bytes =
/*#__PURE__*/
function () {
var _bytes = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4() {
var data;
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
this.throwIfDisposed();
_context4.next = 3;
return trackerFn().read(this.dataId);
case 3:
data = _context4.sent;
if (!(this.dtype === 'string')) {
_context4.next = 8;
break;
}
return _context4.abrupt("return", data);
case 8:
return _context4.abrupt("return", new Uint8Array(data.buffer));
case 9:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function bytes() {
return _bytes.apply(this, arguments);
}
return bytes;
}()
/**
* Disposes `tf.Tensor` from memory.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
;
_proto2.dispose = function dispose() {
if (this.isDisposed) {
return;
}
trackerFn().disposeTensor(this);
this.isDisposedInternal = true;
};
_proto2.throwIfDisposed = function throwIfDisposed() {
if (this.isDisposed) {
throw new Error("Tensor is disposed.");
}
}
/**
* Prints the `tf.Tensor`. See `tf.print` for details.
*
* @param verbose Whether to print verbose information about the tensor,
* including dtype and size.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
;
_proto2.print = function print(verbose) {
if (verbose === void 0) {
verbose = false;
}
return opHandler.print(this, verbose);
}
/**
* Returns a copy of the tensor. See `tf.clone` for details.
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
;
_proto2.clone = function clone() {
this.throwIfDisposed();
return opHandler.clone(this);
}
/**
* Returns a human-readable description of the tensor. Useful for logging.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
;
_proto2.toString = function toString(verbose) {
if (verbose === void 0) {
verbose = false;
}
var vals = this.dataSync();
return tensorToString(vals, this.shape, this.dtype, verbose);
};
_proto2.cast = function cast(dtype) {
this.throwIfDisposed();
return opHandler.cast(this, dtype);
};
_proto2.variable = function variable(trainable, name, dtype) {
if (trainable === void 0) {
trainable = true;
}
this.throwIfDisposed();
return trackerFn().makeVariable(this, trainable, name, dtype);
};
_createClass(Tensor, [{
key: "rank",
get: function get() {
return this.shape.length;
}
}, {
key: "isDisposed",
get: function get() {
return this.isDisposedInternal;
}
}]);
return Tensor;
}();
Object.defineProperty(Tensor, Symbol.hasInstance, {
value: function value(instance) {
// Implementation note: we should use properties of the object that will be
// defined before the constructor body has finished executing (methods).
// This is because when this code is transpiled by babel, babel will call
// classCallCheck before the constructor body is run.
// See https://github.com/tensorflow/tfjs/issues/3384 for backstory.
return !!instance && instance.data != null && instance.dataSync != null && instance.throwIfDisposed != null;
}
});
function getGlobalTensorClass() {
// Use getGlobal so that we can augment the Tensor class across package
// boundaries because the node resolution algorithm may result in different modules
// being returned for this file depending on the path they are loaded from.
return getGlobal('Tensor', function () {
return Tensor;
});
} // Global side effect. Cache global reference to Tensor class
getGlobalTensorClass();
/**
* A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
var Variable = /*#__PURE__*/function (_Tensor) {
_inheritsLoose(Variable, _Tensor);
function Variable(initialValue, trainable, name, tensorId) {
var _this3;
_this3 = _Tensor.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this;
_this3.trainable = trainable;
_this3.name = name;
return _this3;
}
/**
* Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have
* the same shape and dtype as the old `tf.Tensor`.
*
* @param newValue New tensor to be assigned to this variable.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
var _proto3 = Variable.prototype;
_proto3.assign = function assign(newValue) {
if (newValue.dtype !== this.dtype) {
throw new Error("dtype of the new value (" + newValue.dtype + ") and " + ("previous value (" + this.dtype + ") must match"));
}
if (!arraysEqual(newValue.shape, this.shape)) {
throw new Error("shape of the new value (" + newValue.shape + ") and " + ("previous value (" + this.shape + ") must match"));
}
trackerFn().disposeTensor(this);
this.dataId = newValue.dataId;
trackerFn().incRef(this, null
/* backend */
);
};
_proto3.dispose = function dispose() {
trackerFn().disposeVariable(this);
this.isDisposedInternal = true;
};
return Variable;
}(Tensor);
Object.defineProperty(Variable, Symbol.hasInstance, {
value: function value(instance) {
return instance instanceof Tensor && instance.assign != null && instance.assign instanceof Function;
}
});
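// A minimal sketch of Variable usage via the public `tf.variable` factory,
// assuming the bundle is loaded as the global `tf`:
//
//   const v = tf.variable(tf.tensor1d([1, 2, 3]));
//   v.assign(tf.tensor1d([4, 5, 6]));   // new value must match shape and dtype
//   v.print();                          // [4, 5, 6]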
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
(function (Rank) {
Rank["R0"] = "R0";
Rank["R1"] = "R1";
Rank["R2"] = "R2";
Rank["R3"] = "R3";
Rank["R4"] = "R4";
Rank["R5"] = "R5";
Rank["R6"] = "R6";
})(exports.Rank || (exports.Rank = {}));
// Looks for upcasting types. Used, for example, in operations with mixed dtype
// inputs.
var UpcastInt32AndMap;
(function (UpcastInt32AndMap) {
UpcastInt32AndMap["float32"] = "float32";
UpcastInt32AndMap["int32"] = "int32";
UpcastInt32AndMap["bool"] = "int32";
UpcastInt32AndMap["complex64"] = "complex64";
})(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
var UpcastBoolAndMap;
(function (UpcastBoolAndMap) {
UpcastBoolAndMap["float32"] = "float32";
UpcastBoolAndMap["int32"] = "int32";
UpcastBoolAndMap["bool"] = "bool";
UpcastBoolAndMap["complex64"] = "complex64";
})(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
var UpcastFloat32AndMap;
(function (UpcastFloat32AndMap) {
UpcastFloat32AndMap["float32"] = "float32";
UpcastFloat32AndMap["int32"] = "float32";
UpcastFloat32AndMap["bool"] = "float32";
UpcastFloat32AndMap["complex64"] = "complex64";
})(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
var UpcastComplex64AndMap;
(function (UpcastComplex64AndMap) {
UpcastComplex64AndMap["float32"] = "complex64";
UpcastComplex64AndMap["int32"] = "complex64";
UpcastComplex64AndMap["bool"] = "complex64";
UpcastComplex64AndMap["complex64"] = "complex64";
})(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
var upcastTypeMap = {
'float32': UpcastFloat32AndMap,
'int32': UpcastInt32AndMap,
'bool': UpcastBoolAndMap,
'complex64': UpcastComplex64AndMap
};
function upcastType(typeA, typeB) {
if (typeA === 'string' || typeB === 'string') {
if (typeA === 'string' && typeB === 'string') {
return 'string';
}
throw new Error("Can not upcast " + typeA + " with " + typeB);
}
return upcastTypeMap[typeA][typeB];
}
/** Returns the output type after summation. */
function sumOutType(type) {
return upcastType(type, 'int32');
}
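// Upcasting in practice, assuming the bundle is loaded as the global `tf`
// (mixed-dtype ops rely on upcastType, and reductions use sumOutType):
//
//   const a = tf.tensor1d([1, 2], 'int32');
//   const b = tf.tensor1d([0.5, 0.5]);                      // float32 by default
//   console.log(tf.add(a, b).dtype);                        // 'float32'
//   console.log(tf.tensor1d([1, 0], 'bool').sum().dtype);   // 'int32' (sumOutType)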
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function makeTypesMatch(a, b) {
if (a.dtype === b.dtype) {
return [a, b];
}
var dtype = upcastType(a.dtype, b.dtype);
return [a.cast(dtype), b.cast(dtype)];
}
function assertTypesMatch(a, b) {
assert(a.dtype === b.dtype, function () {
return "The dtypes of the first(" + a.dtype + ") and" + (" second(" + b.dtype + ") input must match");
});
}
function isTensorInList(tensor, tensorList) {
return tensorList.some(function (x) {
return x.id === tensor.id;
});
}
/**
* Extracts any `Tensor`s found within the provided object.
*
* @param container an object that may be a `Tensor` or may directly contain
* `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it
* is safe to pass any object here, except that `Promise`s are not
* supported.
* @returns An array of `Tensors` found within the passed object. If the
* argument is simply a `Tensor`, a list containing that `Tensor` is
* returned. If the object is not a `Tensor` or does not
* contain `Tensors`, an empty list is returned.
*/
function getTensorsInContainer(result) {
var list = [];
var seen = new Set();
walkTensorContainer(result, list, seen);
return list;
}
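// getTensorsInContainer is what lets scoped helpers such as tf.tidy keep
// tensors nested inside a returned object or array. An illustrative sketch,
// assuming the bundle is loaded as the global `tf`:
//
//   const out = tf.tidy(() => {
//     return {first: tf.scalar(1), second: tf.scalar(2)};   // both survive the scope
//   });
//   out.first.print();   // 1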
function walkTensorContainer(container, list, seen) {
if (container == null) {
return;
}
if (container instanceof Tensor) {
list.push(container);
return;
}
if (!isIterable(container)) {
return;
} // Iteration over keys works also for arrays.
var iterable = container;
for (var k in iterable) {
var val = iterable[k];
if (!seen.has(val)) {
seen.add(val);
walkTensorContainer(val, list, seen);
}
}
} // tslint:disable-next-line:no-any
function isIterable(obj) {
return Array.isArray(obj) || typeof obj === 'object';
}
var tensor_util = {
__proto__: null,
makeTypesMatch: makeTypesMatch,
assertTypesMatch: assertTypesMatch,
isTensorInList: isTensorInList,
getTensorsInContainer: getTensorsInContainer
};
function isRegisteredKernelInvocation(kernelInvocation) {
return kernelInvocation.kernelName != null;
}
var EngineState = /*#__PURE__*/function () {
function EngineState() {
// Public since optimizers will use it.
this.registeredVariables = {};
this.nextTapeNodeId = 0;
this.numBytes = 0;
this.numTensors = 0;
this.numStringTensors = 0;
this.numDataBuffers = 0;
// Number of nested tf.grad() statements when computing higher-order
// gradients. E.g. `1` for first-order gradients and `2` for second-order
// gradients. Used to track if the tape should be removed after a backprop.
this.gradientDepth = 0;
// Number of nested kernel calls. When kernel depth is greater than 1, we turn
// off the tape.
this.kernelDepth = 0;
this.scopeStack = [];
/**
* Keeps track of the number of data moves during a kernel execution. We
* maintain a stack since kernels can call other kernels, recursively.
*/
this.numDataMovesStack = [];
this.nextScopeId = 0;
this.tensorInfo = new WeakMap();
this.profiling = false;
this.activeProfile = {
newBytes: 0,
newTensors: 0,
peakBytes: 0,
kernels: [],
result: null,
get kernelNames() {
return Array.from(new Set(this.kernels.map(function (k) {
return k.name;
})));
}
};
}
var _proto = EngineState.prototype;
_proto.dispose = function dispose() {
for (var variableName in this.registeredVariables) {
this.registeredVariables[variableName].dispose();
}
};
return EngineState;
}();
var Engine = /*#__PURE__*/function () {
function Engine(ENV) {
this.ENV = ENV;
this.registry = {};
this.registryFactory = {};
this.pendingBackendInitId = 0;
this.state = new EngineState();
}
var _proto2 = Engine.prototype;
_proto2.ready = /*#__PURE__*/function () {
var _ready = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var sortedBackends, i, backendName, success;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!(this.pendingBackendInit != null)) {
_context.next = 2;
break;
}
return _context.abrupt("return", this.pendingBackendInit.then(function () {}));
case 2:
if (!(this.backendInstance != null)) {
_context.next = 4;
break;
}
return _context.abrupt("return");
case 4:
sortedBackends = this.getSortedBackends();
i = 0;
case 6:
if (!(i < sortedBackends.length)) {
_context.next = 18;
break;
}
backendName = sortedBackends[i];
_context.next = 10;
return this.initializeBackend(backendName).success;
case 10:
success = _context.sent;
if (!success) {
_context.next = 15;
break;
}
_context.next = 14;
return this.setBackend(backendName);
case 14:
return _context.abrupt("return");
case 15:
i++;
_context.next = 6;
break;
case 18:
throw new Error("Could not initialize any backends, all backend initializations " + "failed.");
case 19:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function ready() {
return _ready.apply(this, arguments);
}
return ready;
}();
_proto2.backendNames = function backendNames() {
return Object.keys(this.registryFactory);
};
_proto2.findBackend = function findBackend(backendName) {
if (!(backendName in this.registry)) {
// If the backend hasn't been initialized but we have a registry entry for
// it, initialize it and return it.
if (backendName in this.registryFactory) {
var _this$initializeBacke = this.initializeBackend(backendName),
asyncInit = _this$initializeBacke.asyncInit;
if (asyncInit) {
// Backend is not ready yet.
return null;
}
} else {
return null;
}
}
return this.registry[backendName];
};
_proto2.findBackendFactory = function findBackendFactory(backendName) {
if (!(backendName in this.registryFactory)) {
return null;
}
return this.registryFactory[backendName].factory;
};
_proto2.registerBackend = function registerBackend(backendName, factory, priority) {
if (priority === void 0) {
priority = 1;
}
if (backendName in this.registryFactory) {
warn(backendName + " backend was already registered. " + "Reusing existing backend factory.");
return false;
}
this.registryFactory[backendName] = {
factory: factory,
priority: priority
};
return true;
};
_proto2.setBackend = /*#__PURE__*/function () {
var _setBackend = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(backendName) {
var _this$initializeBacke2, success, asyncInit, result;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
if (!(this.registryFactory[backendName] == null)) {
_context2.next = 2;
break;
}
throw new Error("Backend name '" + backendName + "' not found in registry");
case 2:
this.backendName = backendName;
if (!(this.registry[backendName] == null)) {
_context2.next = 16;
break;
}
this.backendInstance = null;
_this$initializeBacke2 = this.initializeBackend(backendName), success = _this$initializeBacke2.success, asyncInit = _this$initializeBacke2.asyncInit;
if (!asyncInit) {
_context2.next = 12;
break;
}
_context2.next = 9;
return success;
case 9:
_context2.t0 = _context2.sent;
_context2.next = 13;
break;
case 12:
_context2.t0 = success;
case 13:
result = _context2.t0;
if (result) {
_context2.next = 16;
break;
}
return _context2.abrupt("return", false);
case 16:
this.backendInstance = this.registry[backendName];
this.setupRegisteredKernels(); // Reset the profiler.
this.profiler = new Profiler(this.backendInstance);
return _context2.abrupt("return", true);
case 20:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function setBackend(_x) {
return _setBackend.apply(this, arguments);
}
return setBackend;
}();
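// Backend selection sketch using the public wrappers exposed elsewhere in the
// bundle (tf.ready, tf.setBackend, tf.getBackend), assuming the global `tf`
// and an enclosing async function:
//
//   await tf.ready();                       // initialize the best registered backend
//   console.log(tf.getBackend());           // e.g. 'webgl' or 'cpu'
//   const ok = await tf.setBackend('cpu');  // explicit switch; resolves to a boolean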
_proto2.setupRegisteredKernels = function setupRegisteredKernels() {
var _this = this;
var kernels = getKernelsForBackend(this.backendName);
kernels.forEach(function (kernel) {
if (kernel.setupFunc != null) {
kernel.setupFunc(_this.backendInstance);
}
});
};
_proto2.disposeRegisteredKernels = function disposeRegisteredKernels(backendName) {
var _this2 = this;
var kernels = getKernelsForBackend(backendName);
kernels.forEach(function (kernel) {
if (kernel.disposeFunc != null) {
kernel.disposeFunc(_this2.registry[backendName]);
}
});
}
/**
* Initializes a backend by looking up the backend name in the factory
* registry and calling the factory method. Returns a boolean representing
* whether the initialization of the backend succeeded. Throws an error if
* there is no backend in the factory registry.
*/
;
_proto2.initializeBackend = function initializeBackend(backendName) {
var _this3 = this;
var registryFactoryEntry = this.registryFactory[backendName];
if (registryFactoryEntry == null) {
throw new Error("Cannot initialize backend " + backendName + ", no registration found.");
}
try {
var backend = registryFactoryEntry.factory();
/* Test if the factory returns a promise.
Done in a more liberal way than
previous 'Promise.resolve(backend)===backend'
as we needed to account for custom Promise
implementations (e.g. Angular) */
if (backend && !(backend instanceof KernelBackend) && typeof backend.then === 'function') {
var promiseId = ++this.pendingBackendInitId;
var success = backend.then(function (backendInstance) {
// Outdated promise. Another backend was set in the meantime.
if (promiseId < _this3.pendingBackendInitId) {
return false;
}
_this3.registry[backendName] = backendInstance;
_this3.pendingBackendInit = null;
return true;
}).catch(function (err) {
// Outdated promise. Another backend was set in the meantime.
if (promiseId < _this3.pendingBackendInitId) {
return false;
}
_this3.pendingBackendInit = null;
warn("Initialization of backend " + backendName + " failed");
warn(err.stack || err.message);
return false;
});
this.pendingBackendInit = success;
return {
success: success,
asyncInit: true
};
} else {
this.registry[backendName] = backend;
return {
success: true,
asyncInit: false
};
}
} catch (err) {
warn("Initialization of backend " + backendName + " failed");
warn(err.stack || err.message);
return {
success: false,
asyncInit: false
};
}
};
_proto2.removeBackend = function removeBackend(backendName) {
if (!(backendName in this.registryFactory)) {
throw new Error(backendName + " backend not found in registry");
}
if (this.backendName === backendName && this.pendingBackendInit != null) {
// There is a pending promise of the backend we want to remove. Make it
// obsolete.
this.pendingBackendInitId++;
}
if (backendName in this.registry) {
this.disposeRegisteredKernels(backendName);
this.registry[backendName].dispose();
delete this.registry[backendName];
}
delete this.registryFactory[backendName]; // Unset the backend if it is active.
if (this.backendName === backendName) {
this.pendingBackendInit = null;
this.backendName = null;
this.backendInstance = null;
}
};
_proto2.getSortedBackends = function getSortedBackends() {
var _this4 = this;
if (Object.keys(this.registryFactory).length === 0) {
throw new Error('No backend found in registry.');
}
return Object.keys(this.registryFactory).sort(function (a, b) {
// Highest priority comes first.
return _this4.registryFactory[b].priority - _this4.registryFactory[a].priority;
});
};
_proto2.initializeBackendsAndReturnBest = function initializeBackendsAndReturnBest() {
var sortedBackends = this.getSortedBackends();
for (var i = 0; i < sortedBackends.length; i++) {
var backendName = sortedBackends[i];
var _this$initializeBacke3 = this.initializeBackend(backendName),
success = _this$initializeBacke3.success,
asyncInit = _this$initializeBacke3.asyncInit;
if (asyncInit || success) {
return {
name: backendName,
asyncInit: asyncInit
};
}
}
throw new Error("Could not initialize any backends, all backend initializations " + "failed.");
};
_proto2.moveData = function moveData(backend, dataId) {
var info = this.state.tensorInfo.get(dataId);
var srcBackend = info.backend;
var values = this.readSync(dataId);
var refCount = srcBackend.refCount(dataId); // Delete the tensor from the old backend and move it to the new
// backend.
srcBackend.disposeData(dataId, true);
info.backend = backend;
backend.move(dataId, values, info.shape, info.dtype, refCount);
if (this.shouldCheckForMemLeaks()) {
// Track the number of moves during a kernel execution to correctly
// detect memory leaks.
this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
}
};
_proto2.tidy = function tidy(nameOrFn, fn) {
var _this5 = this;
var name = null;
if (fn == null) {
// Called with only 1 argument.
if (typeof nameOrFn !== 'function') {
throw new Error('Please provide a function to tidy()');
}
fn = nameOrFn;
} else {
// Called with 2 arguments.
if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
throw new Error('When calling with two arguments, the first argument ' + 'to tidy() must be a string');
}
if (typeof fn !== 'function') {
throw new Error('When calling with two arguments, the 2nd argument ' + 'to tidy() must be a function');
}
name = nameOrFn; // TODO(nsthorat,smilkov): Do operation logging and performance
// profiling.
}
var result;
return this.scopedRun(function () {
return _this5.startScope(name);
}, function () {
return _this5.endScope(result);
}, function () {
result = fn();
if (result instanceof Promise) {
console.error('Cannot return a Promise inside of tidy.');
}
return result;
});
};
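// tidy() disposes intermediate tensors created inside the scope. A minimal
// sketch via the public `tf.tidy` wrapper, assuming the global `tf`:
//
//   const y = tf.tidy(() => {
//     const x = tf.tensor1d([1, 2, 3]);
//     return x.square();   // x is disposed when the scope ends; y is returned
//   });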
_proto2.scopedRun = function scopedRun(start, end, f) {
start();
try {
var res = f();
end();
return res;
} catch (ex) {
end();
throw ex;
}
};
_proto2.nextTensorId = function nextTensorId() {
return Engine.nextTensorId++;
};
_proto2.nextVariableId = function nextVariableId() {
return Engine.nextVariableId++;
}
/**
* This method is called instead of the public-facing tensor.clone() when
* saving a tensor for backwards pass. It makes sure to add the clone
* operation to the tape regardless of being called inside a kernel
* execution.
*/
;
_proto2.clone = function clone(x) {
var y = ENGINE.runKernel(Identity, {
x: x
});
var inputs = {
x: x
};
var grad = function grad(dy) {
return {
x: function x() {
var dtype = 'float32';
var gradInputs = {
x: dy
};
var attrs = {
dtype: dtype
};
return ENGINE.runKernel(Cast, gradInputs, // tslint:disable-next-line: no-unnecessary-type-assertion
attrs);
}
};
};
var saved = [];
this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
return y;
}
/**
* Execute a kernel with the given name and return the output tensor.
*
* @param kernelName The name of the kernel to execute.
* @param inputs A map of input names to tensors.
* @param attrs A map of attribute names to their values. An attribute is a
* primitive (non-tensor) input to the kernel.
* @param inputsToSave A list of tensors, inputs to save for the backprop
* computation.
* @param outputsToSave A list of booleans, specifying which output to save
* for the backprop computation. These are booleans since the output
* tensors are not visible to the user.
*/
;
_proto2.runKernel = function runKernel(kernelName, inputs, attrs) {
if (this.backendName == null) {
// backend has not been initialized yet (backend initialization is lazy and
// can be deferred until an op/kernel is run).
// The below getter has side effects that will try to initialize the
// backend and set properties like this.backendName
// tslint:disable-next-line: no-unused-expression
this.backend;
}
var hasKernel = getKernel(kernelName, this.backendName) != null;
if (!hasKernel) {
throw new Error("Kernel '" + kernelName + "' not registered for backend '" + this.backendName + "'");
}
return this.runKernelFunc({
kernelName: kernelName,
inputs: inputs,
attrs: attrs
});
};
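// Illustrative direct kernel invocation (normally ops call this for you).
// Assumes the global `tf` and that an 'Identity' kernel is registered for the
// active backend, as used by clone() above:
//
//   const x = tf.tensor1d([1, 2, 3]);
//   const y = tf.engine().runKernel('Identity', {x: x});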
_proto2.shouldCheckForMemLeaks = function shouldCheckForMemLeaks() {
return this.ENV.getBool('IS_TEST');
};
_proto2.checkKernelForMemLeak = function checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) {
var numDataIdsAfter = this.backend.numDataIds(); // Count the number of data ids associated with the result of the kernel.
var numOutputDataIds = 0;
outInfos.forEach(function (info) {
// Complex numbers allocate 3 data ids, one for 'real', one for
// 'imaginary', and one for the container that holds the former two.
numOutputDataIds += info.dtype === 'complex64' ? 3 : 1;
}); // Account for the number of moves during kernel execution. A "data move"
// can happen in the middle of a kernel execution, placing a new (key,value)
// pair in the data storage. Since data moves have net zero effect (we
// always remove the data from the old backend), we have to cancel them out
// when detecting memory leaks.
var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
if (dataIdsLeaked > 0) {
throw new Error("Backend '" + this.backendName + "' has an internal memory leak " + ("(" + dataIdsLeaked + " data ids) after running '" + kernelName + "'"));
}
}
/**
* Internal helper method to execute a kernel Func
*
* Use `runKernel` to execute kernels from outside of engine.
*/
;
_proto2.runKernelFunc = function runKernelFunc(kernelParams) {
var _this6 = this;
var outputs;
var saved = [];
var isTapeOn = this.isTapeOn();
var startingBytecount = this.state.numBytes;
var startingNumTensors = this.state.numTensors;
if (this.shouldCheckForMemLeaks()) {
this.state.numDataMovesStack.push(0);
}
var kernelFunc;
if (this.backendName == null) {
// backend has not been initialized yet (backend initialization is lazy and
// can be deferred until an op/kernel is run).
// The below getter has side effects that will try to initialize the
// backend and set properties like this.backendName
// tslint:disable-next-line: no-unused-expression
this.backend;
}
var out;
var kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ? kernelParams.kernelName : this.state.activeScope != null ? this.state.activeScope.name : ''; // Create the kernelFunc from either a registered kernel OR passed in
// forward/backward functions (used by custom grad). In this context a
// kernelFunc wraps a kernel implementation with some bookkeeping.
if (isRegisteredKernelInvocation(kernelParams)) {
var kernelName = kernelParams.kernelName,
_inputs = kernelParams.inputs,
_attrs = kernelParams.attrs;
if (this.backendName == null) {
// backend has not been initialized yet (backend initialization is lazy and
// can be deferred until an op/kernel is run).
// The below getter has side effects that will try to initialize the
// backend and set properties like this.backendName
// tslint:disable-next-line: no-unused-expression
this.backend;
}
var kernel = getKernel(kernelName, this.backendName);
assert(kernel != null, function () {
return "Cannot find registered kernel '" + kernelName + "' for backend '" + _this6.backendName + "'";
});
kernelFunc = function kernelFunc() {
var numDataIdsBefore = _this6.backend.numDataIds();
out = kernel.kernelFunc({
inputs: _inputs,
attrs: _attrs,
backend: _this6.backend
});
var outInfos = Array.isArray(out) ? out : [out];
if (_this6.shouldCheckForMemLeaks()) {
_this6.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
}
var outTensors = outInfos.map(function (outInfo) {
// todo (yassogba) remove this option (Tensor) when node backend
// methods have been modularized and they all return tensorInfo.
// TensorInfos do not have a rank attribute.
if (outInfo.rank != null) {
return outInfo;
}
var dataId = outInfo.dataId,
shape = outInfo.shape,
dtype = outInfo.dtype;
return _this6.makeTensorFromDataId(dataId, shape, dtype);
}); // Save any required inputs and outputs.
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since there would be no backprop for these tensors
// (which would otherwise dispose them).
if (isTapeOn) {
var tensorsToSave = _this6.getTensorsForGradient(kernelName, _inputs, outTensors);
saved = _this6.saveTensorsForBackwardMode(tensorsToSave);
}
return outTensors;
};
} else {
var forwardFunc = kernelParams.forwardFunc; // Running a customGrad op.
var saveFunc = function saveFunc(tensors) {
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since we would never run backprop, which disposes
// the kept tensors.
if (!isTapeOn) {
return;
}
saved = tensors.map(function (tensor) {
return _this6.keep(_this6.clone(tensor));
});
};
kernelFunc = function kernelFunc() {
var numDataIdsBefore = _this6.backend.numDataIds();
out = _this6.tidy(function () {
return forwardFunc(_this6.backend, saveFunc);
});
var outs = Array.isArray(out) ? out : [out];
if (_this6.shouldCheckForMemLeaks()) {
// Scope name is used to print a more helpful error message if needed.
_this6.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs);
}
return outs;
};
} //
// Run the kernelFunc. Optionally profiling it.
//
var inputs = kernelParams.inputs,
attrs = kernelParams.attrs;
var backwardsFunc = isRegisteredKernelInvocation(kernelParams) ? null : kernelParams.backwardsFunc;
var kernelProfile;
this.scopedRun( // Stop recording to a tape when running a kernel.
function () {
return _this6.state.kernelDepth++;
}, function () {
return _this6.state.kernelDepth--;
}, function () {
if (!_this6.ENV.getBool('DEBUG') && !_this6.state.profiling) {
outputs = kernelFunc();
} else {
kernelProfile = _this6.profiler.profileKernel(kernelOrScopeName, inputs, function () {
return kernelFunc();
});
if (_this6.ENV.getBool('DEBUG')) {
_this6.profiler.logKernelProfile(kernelProfile);
}
outputs = kernelProfile.outputs;
}
});
if (isTapeOn) {
this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs);
}
if (this.state.profiling) {
this.state.activeProfile.kernels.push({
name: kernelOrScopeName,
bytesAdded: this.state.numBytes - startingBytecount,
totalBytesSnapshot: this.state.numBytes,
tensorsAdded: this.state.numTensors - startingNumTensors,
totalTensorsSnapshot: this.state.numTensors,
inputShapes: Object.keys(inputs).map(function (key) {
return inputs[key] != null ? inputs[key].shape : null;
}),
outputShapes: outputs.map(function (item) {
return item.shape;
}),
kernelTimeMs: kernelProfile.timeMs,
extraInfo: kernelProfile.extraInfo
});
}
return Array.isArray(out) ? outputs : outputs[0];
}
/**
* Saves tensors used in forward mode for use in backward mode.
*
* @param tensors the list of tensors to save.
*/
;
_proto2.saveTensorsForBackwardMode = function saveTensorsForBackwardMode(tensors) {
var _this7 = this;
var saved = tensors.map(function (tensor) {
return _this7.keep(_this7.clone(tensor));
});
return saved;
}
/**
* Returns a list of tensors to save for a given gradient calculation.
*
* @param kernelName name of kernel to look up gradient for.
* @param inputs a map of input tensors.
* @param outputs an array of output tensors from forward mode of kernel.
*/
;
_proto2.getTensorsForGradient = function getTensorsForGradient(kernelName, inputs, outputs) {
var gradConfig = getGradient(kernelName);
if (gradConfig != null) {
var inputsToSave = gradConfig.inputsToSave || [];
var outputsToSave = gradConfig.outputsToSave || []; // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs
// specified in inputsToSave will be saved.
var inputTensorsToSave;
if (gradConfig.saveAllInputs) {
assert(Array.isArray(inputs), function () {
return 'saveAllInputs is true, expected inputs to be an array.';
});
inputTensorsToSave = Object.keys(inputs).map(function (key) {
return inputs[key];
});
} else {
inputTensorsToSave = inputsToSave.map(function (inputName) {
return inputs[inputName];
});
}
var outputTensorsToSave = outputs.filter(function (_, i) {
return outputsToSave[i];
});
return inputTensorsToSave.concat(outputTensorsToSave);
} // We return an empty list rather than throw an error because the kernel we
// are looking up may not actually be relevant to backpropagating through the
// overall function
//
// See 'does not error if irrelevant (pruned) ops are missing grads' test
// in gradients_test.ts for an example.
return [];
}
/**
* Internal method used by public APIs for tensor creation. Makes a new
* tensor with the provided shape, dtype and values. It always
* creates a new data id and writes the values to the underlying backend.
*/
;
_proto2.makeTensor = function makeTensor(values, shape, dtype, backend) {
if (values == null) {
throw new Error('Values passed to engine.makeTensor() are null');
}
dtype = dtype || 'float32';
backend = backend || this.backend;
var backendVals = values;
if (dtype === 'string' && isString(values[0])) {
backendVals = values.map(function (d) {
return encodeString(d);
});
}
var dataId = backend.write(backendVals, shape, dtype);
var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
this.trackTensor(t, backend); // Count bytes for string tensors.
if (dtype === 'string') {
var info = this.state.tensorInfo.get(dataId);
var newBytes = bytesFromStringArray(backendVals);
this.state.numBytes += newBytes - info.bytes;
info.bytes = newBytes;
}
return t;
}
/**
* Internal method used by backends. Makes a new tensor
* that is a wrapper around an existing data id. It doesn't create
* a new data id, only increments the ref count used in memory tracking.
*/
;
_proto2.makeTensorFromDataId = function makeTensorFromDataId(dataId, shape, dtype, backend) {
dtype = dtype || 'float32';
var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
this.trackTensor(t, backend);
return t;
};
_proto2.makeVariable = function makeVariable(initialValue, trainable, name, dtype) {
if (trainable === void 0) {
trainable = true;
}
name = name || this.nextVariableId().toString();
if (dtype != null && dtype !== initialValue.dtype) {
initialValue = initialValue.cast(dtype);
}
var v = new Variable(initialValue, trainable, name, this.nextTensorId());
if (this.state.registeredVariables[v.name] != null) {
throw new Error("Variable with name " + v.name + " was already registered");
}
this.state.registeredVariables[v.name] = v;
this.incRef(v, this.backend);
return v;
};
_proto2.trackTensor = function trackTensor(a, backend) {
this.state.numTensors++;
if (a.dtype === 'string') {
this.state.numStringTensors++;
} // Bytes for complex numbers are counted by their components. Bytes for
// string tensors are counted when writing values.
var bytes = 0;
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
bytes = a.size * bytesPerElement(a.dtype);
}
this.state.numBytes += bytes;
if (!this.state.tensorInfo.has(a.dataId)) {
this.state.numDataBuffers++;
this.state.tensorInfo.set(a.dataId, {
backend: backend || this.backend,
dtype: a.dtype,
shape: a.shape,
bytes: bytes
});
}
if (!(a instanceof Variable)) {
this.track(a);
}
} // Track the tensor by dataId and increase the refCount for the dataId in the
// backend.
// TODO(pyu10055): This is currently used by makeVariable method, to increase
// refCount on the backend for the dataId. It can potentially be replaced with
// Identity op instead of calling the backend directly.
;
_proto2.incRef = function incRef(a, backend) {
this.trackTensor(a, backend);
this.backend.incRef(a.dataId);
};
_proto2.removeDataId = function removeDataId(dataId, backend) {
if (this.state.tensorInfo.has(dataId) && this.state.tensorInfo.get(dataId).backend === backend) {
this.state.tensorInfo.delete(dataId);
this.state.numDataBuffers--;
}
};
_proto2.disposeTensor = function disposeTensor(a) {
if (!this.state.tensorInfo.has(a.dataId)) {
return;
}
var info = this.state.tensorInfo.get(a.dataId);
this.state.numTensors--;
if (a.dtype === 'string') {
this.state.numStringTensors--;
this.state.numBytes -= info.bytes;
} // Don't count bytes for complex numbers as they are counted by their
// components.
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
var bytes = a.size * bytesPerElement(a.dtype);
this.state.numBytes -= bytes;
} // Remove the reference to dataId if the backend disposes the data successfully
if (info.backend.disposeData(a.dataId)) {
this.removeDataId(a.dataId, info.backend);
} // TODO(nsthorat): Construct an error and save the stack trace for
// debugging when in debug mode. Creating a stack trace is too expensive
// to do unconditionally.
};
_proto2.disposeVariables = function disposeVariables() {
for (var varName in this.state.registeredVariables) {
var v = this.state.registeredVariables[varName];
this.disposeVariable(v);
}
};
_proto2.disposeVariable = function disposeVariable(v) {
this.disposeTensor(v);
if (this.state.registeredVariables[v.name] != null) {
delete this.state.registeredVariables[v.name];
}
};
_proto2.memory = function memory() {
var info = this.backend.memory();
info.numTensors = this.state.numTensors;
info.numDataBuffers = this.state.numDataBuffers;
info.numBytes = this.state.numBytes;
if (this.state.numStringTensors > 0) {
info.unreliable = true;
if (info.reasons == null) {
info.reasons = [];
}
info.reasons.push('Memory usage by string tensors is approximate ' + '(2 bytes per character)');
}
return info;
};
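// Quick memory snapshot via the public wrapper, assuming the global `tf`:
//
//   const m = tf.memory();
//   console.log(m.numTensors, m.numDataBuffers, m.numBytes);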
_proto2.profile = /*#__PURE__*/function () {
var _profile = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(query) {
var startBytes, startNumTensors, _iterator, _step, kernel;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
this.state.profiling = true;
startBytes = this.state.numBytes;
startNumTensors = this.state.numTensors;
this.state.activeProfile.kernels = [];
_context3.next = 6;
return query();
case 6:
this.state.activeProfile.result = _context3.sent;
this.state.profiling = false;
this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function (d) {
return d.totalBytesSnapshot;
}));
this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
this.state.activeProfile.newTensors = this.state.numTensors - startNumTensors;
_iterator = _createForOfIteratorHelperLoose(this.state.activeProfile.kernels);
case 12:
if ((_step = _iterator()).done) {
_context3.next = 22;
break;
}
kernel = _step.value;
_context3.next = 16;
return kernel.kernelTimeMs;
case 16:
kernel.kernelTimeMs = _context3.sent;
_context3.next = 19;
return kernel.extraInfo;
case 19:
kernel.extraInfo = _context3.sent;
case 20:
_context3.next = 12;
break;
case 22:
return _context3.abrupt("return", this.state.activeProfile);
case 23:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function profile(_x2) {
return _profile.apply(this, arguments);
}
return profile;
}();
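// Profiling sketch via the public `tf.profile` wrapper, assuming the global
// `tf` and an enclosing async function (the awaited query runs with
// this.state.profiling enabled):
//
//   const info = await tf.profile(() => tf.tensor1d([1, 2, 3]).square());
//   console.log(info.newTensors, info.newBytes, info.peakBytes, info.kernels.length);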
_proto2.isTapeOn = function isTapeOn() {
return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
};
_proto2.addTapeNode = function addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
var _this8 = this;
var tapeNode = {
id: this.state.nextTapeNodeId++,
kernelName: kernelName,
inputs: inputs,
outputs: outputs,
saved: saved
};
var gradConfig = getGradient(kernelName);
if (gradConfig != null) {
gradientsFunc = gradConfig.gradFunc;
}
if (gradientsFunc != null) {
tapeNode.gradient = function (dys) {
// TODO(smilkov): To optimize back-prop, pass dys that are not used in
// the backprop graph to the user as null instead of zeros
dys = dys.map(function (dy, i) {
if (dy == null) {
var output = outputs[i];
var vals = makeZerosTypedArray(output.size, output.dtype);
return _this8.makeTensor(vals, output.shape, output.dtype);
}
return dy;
}); // Grad functions of ops with single outputs expect a dy, while ops
// with multiple outputs expect dys (array of dy).
return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
};
}
this.state.activeTape.push(tapeNode);
};
_proto2.keep = function keep(result) {
result.kept = true;
return result;
};
_proto2.startTape = function startTape() {
if (this.state.gradientDepth === 0) {
this.state.activeTape = [];
}
this.state.gradientDepth++;
};
_proto2.endTape = function endTape() {
this.state.gradientDepth--;
}
/**
* Start a scope. Use this with endScope() to achieve the same functionality
* as scope() without the need for a function closure.
*/
;
_proto2.startScope = function startScope(name) {
var scopeInfo = {
track: [],
name: 'unnamed scope',
id: this.state.nextScopeId++
};
if (name) {
scopeInfo.name = name;
}
this.state.scopeStack.push(scopeInfo);
this.state.activeScope = scopeInfo;
}
/**
* End a scope. Use this with startScope() to achieve the same functionality
* as scope() without the need for a function closure.
*/
;
_proto2.endScope = function endScope(result) {
var _this9 = this;
var tensorsToTrackInParent = getTensorsInContainer(result);
var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function (t) {
return t.id;
})); // Dispose the arrays tracked in this scope.
for (var i = 0; i < this.state.activeScope.track.length; i++) {
var tensor = this.state.activeScope.track[i];
if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
tensor.dispose();
}
}
var oldScope = this.state.scopeStack.pop();
this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1]; // Track the current result in the parent scope.
tensorsToTrackInParent.forEach(function (tensor) {
// Only track the tensor if it was allocated in the inner scope and is not
// globally kept.
if (!tensor.kept && tensor.scopeId === oldScope.id) {
_this9.track(tensor);
}
});
}
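// Manual scoping sketch (an alternative to tf.tidy), assuming the global `tf`
// and that the engine is reachable via tf.engine():
//
//   tf.engine().startScope();
//   const t = tf.tensor1d([1, 2, 3]).square();
//   tf.engine().endScope();   // t is disposed unless kept or passed to endScope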
/**
* Returns gradients of `f` with respect to each of the `xs`. The gradients
* returned are of the same length as `xs`, but some might be null if `f`
* was not a function of that `x`. It also takes optional dy to multiply the
* gradient, which defaults to `1`.
*/
;
_proto2.gradients = function gradients(f, xs, dy, allowNoGradients) {
var _this10 = this;
if (allowNoGradients === void 0) {
allowNoGradients = false;
}
assert(xs.length > 0, function () {
return 'gradients() received an empty list of xs.';
});
if (dy != null && dy.dtype !== 'float32') {
throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'");
}
var y = this.scopedRun(function () {
return _this10.startTape();
}, function () {
return _this10.endTape();
}, function () {
return _this10.tidy('forward', f);
});
assert(y instanceof Tensor, function () {
return 'The result y returned by f() must be a tensor.';
}); // Filter out the nodes that don't connect x => y.
var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' + 'that the f you passed encloses all operations that lead from x ' + 'to y.');
}
return this.tidy('backward', function () {
var accumulatedGradientMap = {};
accumulatedGradientMap[y.id] = dy == null ? ones(y.shape) : dy; // Backprop gradients through the filtered nodes.
backpropagateGradients(accumulatedGradientMap, filteredTape, // Pass the tidy function to avoid circular dep with `tape.ts`.
function (f) {
return _this10.tidy(f);
}, // Pass an add function to avoid a circular dep with `tape.ts`.
add);
var grads = xs.map(function (x) {
return accumulatedGradientMap[x.id];
});
if (_this10.state.gradientDepth === 0) {
// This means that we are not computing higher-order gradients
// and can clean up the tape.
_this10.state.activeTape.forEach(function (node) {
for (var _iterator2 = _createForOfIteratorHelperLoose(node.saved), _step2; !(_step2 = _iterator2()).done;) {
var tensor = _step2.value;
tensor.dispose();
}
});
_this10.state.activeTape = null;
}
return {
value: y,
grads: grads
};
});
};
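/**
 * Illustrative sketch (not part of the library): how the engine-level
 * `gradients` method above is typically exercised. The tensor value and the
 * squaring function are hypothetical; the public `tf.grad`/`tf.grads` ops
 * wrap this same machinery.
 *
 * ```js
 * const x = ENGINE.makeTensor(new Float32Array([3]), [1], 'float32');
 * const {value, grads} = ENGINE.gradients(() => x.square(), [x]);
 * // `value` holds f(x) = 9 and `grads[0]` holds df/dx = 2 * x = 6,
 * // since `dy` defaults to a tensor of ones.
 * ```
 */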
_proto2.customGrad = function customGrad(f) {
var _this11 = this;
assert(isFunction(f), function () {
return 'The f passed in customGrad(f) must be a function.';
});
return function () {
for (var _len = arguments.length, inputs = new Array(_len), _key = 0; _key < _len; _key++) {
inputs[_key] = arguments[_key];
}
assert(inputs.every(function (t) {
return t instanceof Tensor;
}), function () {
return 'The args passed in customGrad(f)(x1, x2,...) must all be ' + 'tensors';
});
var res;
var inputMap = {};
inputs.forEach(function (input, i) {
inputMap[i] = input;
});
var forwardFunc = function forwardFunc(_, save) {
res = f.apply(void 0, [].concat(inputs, [save]));
assert(res.value instanceof Tensor, function () {
return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.value` is a tensor';
});
assert(isFunction(res.gradFunc), function () {
return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function.';
});
return res.value;
};
var backwardsFunc = function backwardsFunc(dy, saved) {
var gradRes = res.gradFunc(dy, saved);
var grads = Array.isArray(gradRes) ? gradRes : [gradRes];
assert(grads.length === inputs.length, function () {
return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function that returns ' + 'the same number of tensors as inputs passed to f(...).';
});
assert(grads.every(function (t) {
return t instanceof Tensor;
}), function () {
return 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function that returns ' + 'a list of only tensors.';
});
var gradMap = {};
grads.forEach(function (grad, i) {
gradMap[i] = function () {
return grad;
};
});
return gradMap;
};
return _this11.runKernelFunc({
forwardFunc: forwardFunc,
backwardsFunc: backwardsFunc,
inputs: inputMap
});
};
};
_proto2.readSync = function readSync(dataId) {
// Route the read to the correct backend.
var info = this.state.tensorInfo.get(dataId);
return info.backend.readSync(dataId);
};
_proto2.read = function read(dataId) {
// Route the read to the correct backend.
var info = this.state.tensorInfo.get(dataId);
return info.backend.read(dataId);
};
_proto2.time = /*#__PURE__*/function () {
var _time = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(query) {
var start, timingInfo;
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
start = now();
_context4.next = 3;
return this.backend.time(query);
case 3:
timingInfo = _context4.sent;
timingInfo.wallMs = now() - start;
return _context4.abrupt("return", timingInfo);
case 6:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function time(_x3) {
return _time.apply(this, arguments);
}
return time;
}()
/**
* Tracks a Tensor in the current scope to be automatically cleaned up
* when the current scope ends, and returns the value.
*
* @param result The Tensor to track in the current scope.
*/
;
_proto2.track = function track(result) {
if (this.state.activeScope != null) {
result.scopeId = this.state.activeScope.id;
this.state.activeScope.track.push(result);
}
return result;
};
/**
* Resets the engine state. Removes all backends but does not remove
* registered backend factories.
*/
_proto2.reset = function reset() {
// Make any pending promise obsolete.
this.pendingBackendInitId++;
this.state.dispose();
this.ENV.reset();
this.state = new EngineState();
for (var backendName in this.registry) {
this.disposeRegisteredKernels(backendName);
this.registry[backendName].dispose();
delete this.registry[backendName];
}
this.backendName = null;
this.backendInstance = null;
this.pendingBackendInit = null;
};
_createClass(Engine, [{
key: "backend",
get: function get() {
if (this.pendingBackendInit != null) {
throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make " + "sure to await tf.ready() or await tf.setBackend() before calling " + "other methods");
}
if (this.backendInstance == null) {
var _this$initializeBacke4 = this.initializeBackendsAndReturnBest(),
name = _this$initializeBacke4.name,
asyncInit = _this$initializeBacke4.asyncInit;
if (asyncInit) {
throw new Error("The highest priority backend '" + name + "' has not yet been " + "initialized. Make sure to await tf.ready() or " + "await tf.setBackend() before calling other methods");
}
this.setBackend(name);
}
return this.backendInstance;
}
}, {
key: "registeredVariables",
get: function get() {
return this.state.registeredVariables;
}
}]);
return Engine;
}();
Engine.nextTensorId = 0;
Engine.nextVariableId = 0;
function ones(shape) {
var values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
return ENGINE.makeTensor(values, shape, 'float32');
}
function getOrMakeEngine() {
var ns = getGlobalNamespace();
if (ns._tfengine == null) {
var environment = new Environment(ns);
ns._tfengine = new Engine(environment);
}
setEnvironmentGlobal(ns._tfengine.ENV); // Tell the current tensor interface that the global engine is responsible
// for tracking.
setTensorTracker(function () {
return ns._tfengine;
});
return ns._tfengine;
}
var ENGINE = getOrMakeEngine();
/**
* An implementation of the add op for use within engine and tape.
*
* This allows us to avoid a circular dependency between add.ts and engine.
* It is exported to be available in tape tests.
*/
function add(a, b) {
// We duplicate Add here to avoid a circular dependency with add.ts.
var inputs = {
a: a,
b: b
};
return ENGINE.runKernel(Add, inputs);
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line:no-any
function _isNavigatorDefined() {
return typeof navigator !== 'undefined' && navigator != null;
}
function isMobile(nav) {
if (nav || _isNavigatorDefined()) {
if (!nav) {
nav = navigator;
}
if (nav.product === 'ReactNative') {
return true;
} // tslint:disable-next-line:no-any
var a = nav.userAgent || nav.vendor || (typeof window !== 'undefined' ? window.opera : ''); // Use `navigator.userAgentData.mobile` as fallback.
if (!a) {
// tslint:disable-next-line:no-any
var navAny = nav;
return navAny.userAgentData && navAny.userAgentData.mobile;
} // tslint:disable-next-line:max-line-length
return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(a) || // tslint:disable-next-line:max-line-length
/1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(a.substr(0, 4));
}
return false;
}
function isBrowser() {
return typeof window !== 'undefined' && window.document != null || //@ts-ignore
typeof WorkerGlobalScope !== 'undefined';
}
var device_util = {
__proto__: null,
isMobile: isMobile,
isBrowser: isBrowser
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ENV = env();
/**
* This file contains environment-related flag registrations.
*/
/** Whether to enable debug mode. */
ENV.registerFlag('DEBUG', function () {
return false;
}, function (debugValue) {
if (debugValue) {
console.warn('Debugging mode is ON. The output of every math call will ' + 'be downloaded to CPU and checked for NaNs. ' + 'This significantly impacts performance.');
}
});
/** Whether we are in a browser (as opposed to, say, Node.js) environment. */
ENV.registerFlag('IS_BROWSER', function () {
return isBrowser();
});
/** Whether we are in a Node.js (as opposed to a browser) environment. */
ENV.registerFlag('IS_NODE', function () {
return typeof process !== 'undefined' && typeof process.versions !== 'undefined' && typeof process.versions.node !== 'undefined';
});
/** Whether this browser is Chrome. */
ENV.registerFlag('IS_CHROME', function () {
return typeof navigator !== 'undefined' && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor);
});
/**
* True when the environment is "production" where we disable safety checks
* to gain performance.
*/
ENV.registerFlag('PROD', function () {
return false;
});
/**
* Whether to do sanity checks when inferring a shape from user-provided
* values, used when creating a new tensor.
*/
ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', function () {
return ENV.getBool('DEBUG');
});
/** Whether deprecation warnings are enabled. */
ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', function () {
return true;
});
/** True if running unit tests. */
ENV.registerFlag('IS_TEST', function () {
return false;
});
/** Whether to check computation result for errors. */
ENV.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', function () {
return true;
});
/** Whether the backend needs to wrap input to imageBitmap. */
ENV.registerFlag('WRAP_TO_IMAGEBITMAP', function () {
return false;
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function inferShape(val, dtype) {
var firstElem = val;
if (isTypedArray$1(val)) {
return dtype === 'string' ? [] : [val.length];
}
if (!Array.isArray(val)) {
return []; // Scalar.
}
var shape = [];
while (Array.isArray(firstElem) || isTypedArray$1(firstElem) && dtype !== 'string') {
shape.push(firstElem.length);
firstElem = firstElem[0];
}
if (Array.isArray(val) && env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
deepAssertShapeConsistency(val, shape, []);
}
return shape;
}
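/**
 * Example (illustrative; not called by the library itself): `inferShape`
 * walks the first element at each nesting level, so nested arrays yield
 * their dimensions, a TypedArray yields its length and a scalar yields [].
 *
 * ```js
 * inferShape(3);                         // []
 * inferShape([1, 2, 3]);                 // [3]
 * inferShape([[1, 2], [3, 4]]);          // [2, 2]
 * inferShape(new Float32Array([1, 2]));  // [2]
 * ```
 */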
function deepAssertShapeConsistency(val, shape, indices) {
indices = indices || [];
if (!Array.isArray(val) && !isTypedArray$1(val)) {
assert(shape.length === 0, function () {
return "Element arr[" + indices.join('][') + "] is a primitive, " + ("but should be an array/TypedArray of " + shape[0] + " elements");
});
return;
}
assert(shape.length > 0, function () {
return "Element arr[" + indices.join('][') + "] should be a primitive, " + ("but is an array of " + val.length + " elements");
});
assert(val.length === shape[0], function () {
return "Element arr[" + indices.join('][') + "] should have " + shape[0] + " " + ("elements, but has " + val.length + " elements");
});
var subShape = shape.slice(1);
for (var i = 0; i < val.length; ++i) {
deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
}
}
function assertDtype(expectedDtype, actualDType, argName, functionName) {
if (expectedDtype === 'string_or_numeric') {
return;
}
if (expectedDtype == null) {
throw new Error("Expected dtype cannot be null.");
}
if (expectedDtype !== 'numeric' && expectedDtype !== actualDType || expectedDtype === 'numeric' && actualDType === 'string') {
throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " + ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor"));
}
}
function convertToTensor(x, argName, functionName, parseAsDtype) {
if (parseAsDtype === void 0) {
parseAsDtype = 'numeric';
}
if (x instanceof Tensor) {
assertDtype(parseAsDtype, x.dtype, argName, functionName);
return x;
}
var inferredDtype = inferDtype(x); // If the user expects a bool/int/float, use that info to update the
// inferredDtype when it is not a string.
if (inferredDtype !== 'string' && ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
inferredDtype = parseAsDtype;
}
assertDtype(parseAsDtype, inferredDtype, argName, functionName);
if (x == null || !isTypedArray$1(x) && !Array.isArray(x) && typeof x !== 'number' && typeof x !== 'boolean' && typeof x !== 'string') {
var type = x == null ? 'null' : x.constructor.name;
throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " + ("Tensor or TensorLike, but got '" + type + "'"));
}
var inferredShape = inferShape(x, inferredDtype);
if (!isTypedArray$1(x) && !Array.isArray(x)) {
x = [x];
}
var skipTypedArray = true;
var values = inferredDtype !== 'string' ? toTypedArray(x, inferredDtype) : flatten(x, [], skipTypedArray);
return ENGINE.makeTensor(values, inferredShape, inferredDtype);
}
function convertToTensorArray(arg, argName, functionName, parseAsDtype) {
if (parseAsDtype === void 0) {
parseAsDtype = 'numeric';
}
if (!Array.isArray(arg)) {
throw new Error("Argument " + argName + " passed to " + functionName + " must be a " + '`Tensor[]` or `TensorLike[]`');
}
var tensors = arg;
return tensors.map(function (t, i) {
return convertToTensor(t, argName + "[" + i + "]", functionName, parseAsDtype);
});
}
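/**
 * Illustrative sketch (the op name 'myOp' is hypothetical): `convertToTensor`
 * accepts either an existing `Tensor` or a TensorLike value (number, boolean,
 * string, nested array or TypedArray), validates its dtype against
 * `parseAsDtype`, and returns a `Tensor`. `convertToTensorArray` maps the
 * same conversion over an array of such values.
 *
 * ```js
 * const a = convertToTensor([1, 2, 3], 'a', 'myOp');           // float32, shape [3]
 * const xs = convertToTensorArray([[1, 2], a], 'xs', 'myOp');  // [Tensor, Tensor]
 * ```
 */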
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var OP_SCOPE_SUFFIX = '__op';
/**
* Used for wrapping functions that perform math operations on
* Tensors. The function will be wrapped in a named scope that cleans all
* memory usage after the function is done.
*/
function op(f) {
var keys = Object.keys(f);
if (keys.length !== 1) {
throw new Error("Please provide an object with a single key " + "(operation name) mapping to a function. Got an object with " + (keys.length + " keys."));
}
var opName = keys[0];
var fn = f[opName]; // Strip the underscore from the end of the function name.
if (opName.endsWith('_')) {
opName = opName.substring(0, opName.length - 1);
} // add an __op suffix to distinguish ops from kernels in tf.profile
opName = opName + OP_SCOPE_SUFFIX; // tslint:disable-next-line:no-any
var f2 = function f2() {
ENGINE.startScope(opName);
try {
var result = fn.apply(void 0, arguments);
if (isPromise(result)) {
console.error('Cannot return a Promise inside of tidy.');
}
ENGINE.endScope(result);
return result;
} catch (ex) {
ENGINE.endScope(null);
throw ex;
}
};
Object.defineProperty(f2, 'name', {
value: opName,
configurable: true
}); // tslint:disable-next-line:no-any
return f2;
}
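/**
 * Illustrative sketch (the `double_` op is hypothetical): `op` expects an
 * object with a single key whose name ends in '_', strips the underscore,
 * appends the '__op' suffix used by tf.profile, and wraps the function in a
 * scope so intermediate tensors are cleaned up when the call returns.
 *
 * ```js
 * const double = op({double_: (x) => add(x, x)});
 * console.log(double.name);  // 'double__op'
 * ```
 */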
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Converts two real numbers to a complex number.
*
* Given a tensor `real` representing the real part of a complex number, and a
* tensor `imag` representing the imaginary part of a complex number, this
* operation returns complex numbers elementwise of the form [r0, i0, r1, i1],
* where r represents the real part and i represents the imag part.
*
* The input tensors real and imag must have the same shape.
*
* ```js
* const real = tf.tensor1d([2.25, 3.25]);
* const imag = tf.tensor1d([4.75, 5.75]);
* const complex = tf.complex(real, imag);
*
* complex.print();
* ```
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function complex_(real, imag) {
var $real = convertToTensor(real, 'real', 'complex');
var $imag = convertToTensor(imag, 'imag', 'complex');
assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, " + $real.shape + " and " + $imag.shape + ", " + "must match in call to tf.complex().");
var inputs = {
real: $real,
imag: $imag
};
return ENGINE.runKernel(Complex, inputs);
}
var complex = op({
complex_: complex_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** This is shared code across all tensor creation methods. */
function makeTensor(values, shape, inferredShape, dtype) {
if (dtype == null) {
dtype = inferDtype(values);
}
if (dtype === 'complex64') {
throw new Error("Cannot construct a complex64 tensor directly. " + "Please use tf.complex(real, imag).");
}
if (!isTypedArray$1(values) && !Array.isArray(values) && typeof values !== 'number' && typeof values !== 'boolean' && typeof values !== 'string') {
throw new Error('values passed to tensor(values) must be a number/boolean/string or ' + 'an array of numbers/booleans/strings, or a TypedArray');
}
if (shape != null) {
assertNonNegativeIntegerDimensions(shape);
var providedSize = sizeFromShape(shape);
var inferredSize = sizeFromShape(inferredShape);
assert(providedSize === inferredSize, function () {
return "Based on the provided shape, [" + shape + "], the tensor should have " + (providedSize + " values but has " + inferredSize);
});
for (var i = 0; i < inferredShape.length; ++i) {
var inferred = inferredShape[i];
var flatDimsDontMatch = i === inferredShape.length - 1 ? inferred !== sizeFromShape(shape.slice(i)) : true;
assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function () {
return "Error creating a new Tensor. Inferred shape " + ("(" + inferredShape + ") does not match the provided ") + ("shape (" + shape + "). ");
});
}
}
if (!isTypedArray$1(values) && !Array.isArray(values)) {
values = [values];
}
shape = shape || inferredShape;
values = dtype !== 'string' ? toTypedArray(values, dtype) : flatten(values, [], true);
return ENGINE.makeTensor(values, shape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with the provided values, shape and dtype.
*
* ```js
* // Pass an array of values to create a vector.
* tf.tensor([1, 2, 3, 4]).print();
* ```
*
* ```js
* // Pass a nested array of values to make a matrix or a higher
* // dimensional tensor.
* tf.tensor([[1, 2], [3, 4]]).print();
* ```
*
* ```js
* // Pass a flat array and specify a shape yourself.
* tf.tensor([1, 2, 3, 4], [2, 2]).print();
* ```
*
* @param values The values of the tensor. Can be nested array of numbers,
* or a flat array, or a `TypedArray`. If the values are strings,
* they will be encoded as utf-8 and kept as `Uint8Array[]`.
* @param shape The shape of the tensor. Optional. If not provided,
* it is inferred from `values`.
* @param dtype The data type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function tensor(values, shape, dtype) {
var inferredShape = inferShape(values, dtype);
return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/* Type definitions for exporting and importing of models. */
/**
* A map from Tensor dtype to number of bytes per element of the Tensor.
*/
var DTYPE_VALUE_SIZE_MAP = {
'float32': 4,
'float16': 2,
'int32': 4,
'uint16': 2,
'uint8': 1,
'bool': 1,
'complex64': 8
};
/** Number of bytes reserved for the length of the string (32-bit integer). */
var NUM_BYTES_STRING_LENGTH = 4;
/**
* Encode a map from names to weight values as an ArrayBuffer, along with an
* `Array` of `WeightsManifestEntry` as specification of the encoded weights.
*
* This function does not perform sharding.
*
* This function is the reverse of `decodeWeights`.
*
* @param tensors A map ("dict") from names to tensors.
* @param group Group to which the weights belong (optional).
* @returns A `Promise` of
* - A flat `ArrayBuffer` with all the binary values of the `Tensor`s
* concatenated.
* - An `Array` of `WeightsManifestEntry`s, carrying information including
* tensor names, `dtype`s and shapes.
* @throws Error: on unsupported tensor `dtype`.
*/
function encodeWeights(_x, _x2) {
return _encodeWeights.apply(this, arguments);
}
/**
* Decode flat ArrayBuffer as weights.
*
* This function does not handle sharding.
*
* This function is the reverse of `encodeWeights`.
*
* @param buffer A flat ArrayBuffer carrying the binary values of the tensors
* concatenated in the order specified in `specs`.
* @param specs Specifications of the names, dtypes and shapes of the tensors
* whose values are encoded by `buffer`.
* @return A map from tensor name to tensor value, with the names corresponding
* to names in `specs`.
* @throws Error, if any of the tensors has unsupported dtype.
*/
function _encodeWeights() {
_encodeWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(tensors, group) {
var specs, dataPromises, names, _loop, i, tensorValues;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
// TODO(adarob, cais): Support quantization.
specs = [];
dataPromises = [];
names = Array.isArray(tensors) ? tensors.map(function (tensor) {
return tensor.name;
}) : Object.keys(tensors);
_loop = function _loop(i) {
var name = names[i];
var t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && t.dtype !== 'string' && t.dtype !== 'complex64') {
throw new Error("Unsupported dtype in weight '" + name + "': " + t.dtype);
}
var spec = {
name: name,
shape: t.shape,
dtype: t.dtype
};
if (t.dtype === 'string') {
var utf8bytes = new Promise( /*#__PURE__*/function () {
var _ref = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(resolve) {
var vals, totalNumBytes, bytes, offset, _i6, val, bytesOfLength;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.next = 2;
return t.bytes();
case 2:
vals = _context.sent;
totalNumBytes = vals.reduce(function (p, c) {
return p + c.length;
}, 0) + NUM_BYTES_STRING_LENGTH * vals.length;
bytes = new Uint8Array(totalNumBytes);
offset = 0;
for (_i6 = 0; _i6 < vals.length; _i6++) {
val = vals[_i6];
bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
bytes.set(bytesOfLength, offset);
offset += NUM_BYTES_STRING_LENGTH;
bytes.set(val, offset);
offset += val.length;
}
resolve(bytes);
case 8:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return function (_x5) {
return _ref.apply(this, arguments);
};
}());
dataPromises.push(utf8bytes);
} else {
dataPromises.push(t.data());
}
if (group != null) {
spec.group = group;
}
specs.push(spec);
};
for (i = 0; i < names.length; ++i) {
_loop(i);
}
_context2.next = 7;
return Promise.all(dataPromises);
case 7:
tensorValues = _context2.sent;
return _context2.abrupt("return", {
data: concatenateTypedArrays(tensorValues),
specs: specs
});
case 9:
case "end":
return _context2.stop();
}
}
}, _callee2);
}));
return _encodeWeights.apply(this, arguments);
}
function decodeWeights(buffer, specs) {
// TODO(adarob, cais): Support quantization.
var out = {};
var float16Decode;
var offset = 0;
for (var _iterator = _createForOfIteratorHelperLoose(specs), _step; !(_step = _iterator()).done;) {
var spec = _step.value;
var name = spec.name;
var dtype = spec.dtype;
var shape = spec.shape;
var size = sizeFromShape(shape);
var values = void 0;
if ('quantization' in spec) {
var quantization = spec.quantization;
if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
if (!('min' in quantization && 'scale' in quantization)) {
throw new Error("Weight " + spec.name + " with quantization " + quantization.dtype + " " + "doesn't have corresponding metadata min and scale.");
}
} else if (quantization.dtype === 'float16') {
if (dtype !== 'float32') {
throw new Error("Weight " + spec.name + " is quantized with " + quantization.dtype + " " + ("which only supports weights of type float32 not " + dtype + "."));
}
} else {
throw new Error("Weight " + spec.name + " has unknown " + ("quantization dtype " + quantization.dtype + ". ") + "Supported quantization dtypes are: " + "'uint8', 'uint16', and 'float16'.");
}
var quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
var byteBuffer = buffer.slice(offset, offset + size * quantizationSizeFactor);
var quantizedArray = quantization.dtype === 'uint8' ? new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer);
if (dtype === 'float32') {
if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
values = new Float32Array(quantizedArray.length);
for (var i = 0; i < quantizedArray.length; i++) {
var v = quantizedArray[i];
values[i] = v * quantization.scale + quantization.min;
}
} else if (quantization.dtype === 'float16') {
if (float16Decode === undefined) {
float16Decode = getFloat16Decoder();
}
values = float16Decode(quantizedArray);
} else {
throw new Error("Unsupported quantization type " + quantization.dtype + " " + "for weight type float32.");
}
} else if (dtype === 'int32') {
if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
throw new Error("Unsupported quantization type " + quantization.dtype + " " + "for weight type int32.");
}
values = new Int32Array(quantizedArray.length);
for (var _i = 0; _i < quantizedArray.length; _i++) {
var _v = quantizedArray[_i];
values[_i] = Math.round(_v * quantization.scale + quantization.min);
}
} else {
throw new Error("Unsupported dtype in weight '" + name + "': " + dtype);
}
offset += size * quantizationSizeFactor;
} else if (dtype === 'string') {
var _size = sizeFromShape(spec.shape);
values = [];
for (var _i2 = 0; _i2 < _size; _i2++) {
var byteLength = new Uint32Array(buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
offset += NUM_BYTES_STRING_LENGTH;
var bytes = new Uint8Array(buffer.slice(offset, offset + byteLength));
values.push(bytes);
offset += byteLength;
}
} else {
var dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
var _byteBuffer = buffer.slice(offset, offset + size * dtypeFactor);
if (dtype === 'float32') {
values = new Float32Array(_byteBuffer);
} else if (dtype === 'int32') {
values = new Int32Array(_byteBuffer);
} else if (dtype === 'bool') {
values = new Uint8Array(_byteBuffer);
} else if (dtype === 'complex64') {
values = new Float32Array(_byteBuffer);
var real = new Float32Array(values.length / 2);
var image = new Float32Array(values.length / 2);
for (var _i3 = 0; _i3 < real.length; _i3++) {
real[_i3] = values[_i3 * 2];
image[_i3] = values[_i3 * 2 + 1];
}
var realTensor = tensor(real, shape, 'float32');
var imageTensor = tensor(image, shape, 'float32');
out[name] = complex(realTensor, imageTensor);
realTensor.dispose();
imageTensor.dispose();
} else {
throw new Error("Unsupported dtype in weight '" + name + "': " + dtype);
}
offset += size * dtypeFactor;
}
if (dtype !== 'complex64') {
out[name] = tensor(values, shape, dtype);
}
}
return out;
}
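/**
 * Round-trip sketch (illustrative; run inside an async function): encoding
 * produces a flat ArrayBuffer plus per-tensor specs, and decoding with those
 * same specs reconstructs a name-to-tensor map. The public entry points are
 * `tf.io.encodeWeights` and `tf.io.decodeWeights`.
 *
 * ```js
 * const {data, specs} = await encodeWeights({w: tensor([1, 2, 3], [3], 'float32')});
 * const restored = decodeWeights(data, specs);
 * restored['w'].print();  // [1, 2, 3]
 * ```
 */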
/**
* Concatenate TypedArrays into an ArrayBuffer.
*/
function concatenateTypedArrays(xs) {
// TODO(adarob, cais): Support quantization.
if (xs === null) {
throw new Error("Invalid input value: " + JSON.stringify(xs));
}
var totalByteLength = 0; // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer`
// can have a different byte length from that of the `TypedArray` itself,
// for example, when the `TypedArray` is created from an offset in an
// `ArrayBuffer`. `normalizedXs` holds `TypedArray`s whose `buffer`s match
// the `TypedArray` in byte length. If an element of `xs` does not show
// this property, a new `TypedArray` that satisfies this property will be
// constructed and pushed into `normalizedXs`.
var normalizedXs = [];
xs.forEach(function (x) {
totalByteLength += x.byteLength; // tslint:disable:no-any
normalizedXs.push(x.byteLength === x.buffer.byteLength ? x : new x.constructor(x));
if (!(x instanceof Float32Array || x instanceof Int32Array || x instanceof Uint8Array)) {
throw new Error("Unsupported TypedArray subtype: " + x.constructor.name);
} // tslint:enable:no-any
});
var y = new Uint8Array(totalByteLength);
var offset = 0;
normalizedXs.forEach(function (x) {
y.set(new Uint8Array(x.buffer), offset);
offset += x.byteLength;
});
return y.buffer;
} // Use Buffer on Node.js instead of Blob/atob/btoa
var useNodeBuffer = typeof Buffer !== 'undefined' && (typeof Blob === 'undefined' || typeof atob === 'undefined' || typeof btoa === 'undefined');
/**
* Calculate the byte length of a JavaScript string.
*
* Note that a JavaScript string can contain wide characters, so the
* length of the string is not necessarily equal to the byte length.
*
* @param str Input string.
* @returns Byte length.
*/
function stringByteLength(str) {
if (useNodeBuffer) {
return Buffer.byteLength(str);
}
return new Blob([str]).size;
}
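/**
 * Example (illustrative): the byte length is computed over the UTF-8
 * encoding, so multi-byte characters contribute more than one byte each.
 *
 * ```js
 * stringByteLength('abc');  // 3
 * stringByteLength('€');    // 3 (one character, three UTF-8 bytes)
 * ```
 */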
/**
* Encode an ArrayBuffer as a base64 encoded string.
*
* @param buffer `ArrayBuffer` to be converted.
* @returns A string that base64-encodes `buffer`.
*/
function arrayBufferToBase64String(buffer) {
if (useNodeBuffer) {
return Buffer.from(buffer).toString('base64');
}
var buf = new Uint8Array(buffer);
var s = '';
for (var i = 0, l = buf.length; i < l; i++) {
s += String.fromCharCode(buf[i]);
}
return btoa(s);
}
/**
* Decode a base64 string as an ArrayBuffer.
*
* @param str Base64 string.
* @returns Decoded `ArrayBuffer`.
*/
function base64StringToArrayBuffer(str) {
if (useNodeBuffer) {
var buf = Buffer.from(str, 'base64');
return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
}
var s = atob(str);
var buffer = new Uint8Array(s.length);
for (var i = 0; i < s.length; ++i) {
buffer.set([s.charCodeAt(i)], i);
}
return buffer.buffer;
}
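/**
 * Round-trip sketch (illustrative): the two helpers above are inverses, so a
 * buffer survives encoding to base64 and decoding back.
 *
 * ```js
 * const original = new Uint8Array([1, 2, 255]).buffer;
 * const b64 = arrayBufferToBase64String(original);
 * new Uint8Array(base64StringToArrayBuffer(b64));  // [1, 2, 255]
 * ```
 */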
/**
* Concatenate a number of ArrayBuffers into one.
*
* @param buffers A number of array buffers to concatenate.
* @returns Result of concatenating `buffers` in order.
*/
function concatenateArrayBuffers(buffers) {
if (buffers.length === 1) {
return buffers[0];
}
var totalByteLength = 0;
buffers.forEach(function (buffer) {
totalByteLength += buffer.byteLength;
});
var temp = new Uint8Array(totalByteLength);
var offset = 0;
buffers.forEach(function (buffer) {
temp.set(new Uint8Array(buffer), offset);
offset += buffer.byteLength;
});
return temp.buffer;
}
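/**
 * Example (illustrative): concatenation preserves byte order, so buffers of
 * lengths m and n yield one buffer of length m + n.
 *
 * ```js
 * const a = new Uint8Array([1, 2]).buffer;
 * const b = new Uint8Array([3]).buffer;
 * new Uint8Array(concatenateArrayBuffers([a, b]));  // [1, 2, 3]
 * ```
 */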
/**
* Get the basename of a path.
*
* Behaves in a way analogous to Linux's basename command.
*
* @param path
*/
function basename(path) {
var SEPARATOR = '/';
path = path.trim();
while (path.endsWith(SEPARATOR)) {
path = path.slice(0, path.length - 1);
}
var items = path.split(SEPARATOR);
return items[items.length - 1];
}
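/**
 * Example (illustrative): trailing separators are trimmed before the last
 * path component is returned.
 *
 * ```js
 * basename('/models/my-model/model.json');  // 'model.json'
 * basename('/models/my-model/');            // 'my-model'
 * ```
 */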
/**
* Create `ModelJSON` from `ModelArtifacts`.
*
* @param artifacts Model artifacts, describing the model and its weights.
* @param manifest Weight manifest, describing where the weights of the
* `ModelArtifacts` are stored, and some metadata about them.
* @returns Object representing the `model.json` file describing the model
* artifacts and weights
*/
function getModelJSONForModelArtifacts(artifacts, manifest) {
var result = {
modelTopology: artifacts.modelTopology,
format: artifacts.format,
generatedBy: artifacts.generatedBy,
convertedBy: artifacts.convertedBy,
weightsManifest: manifest
};
if (artifacts.signature != null) {
result.signature = artifacts.signature;
}
if (artifacts.userDefinedMetadata != null) {
result.userDefinedMetadata = artifacts.userDefinedMetadata;
}
if (artifacts.modelInitializer != null) {
result.modelInitializer = artifacts.modelInitializer;
}
if (artifacts.trainingConfig != null) {
result.trainingConfig = artifacts.trainingConfig;
}
return result;
}
/**
* Create `ModelArtifacts` from a JSON file.
*
* @param modelJSON Object containing the parsed JSON of `model.json`
* @param loadWeights Function that takes the JSON file's weights manifest,
* reads weights from the listed path(s), and returns a Promise of the
* weight manifest entries along with the weights data.
* @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
*/
function getModelArtifactsForJSON(_x3, _x4) {
return _getModelArtifactsForJSON.apply(this, arguments);
}
/**
* Populate ModelArtifactsInfo fields for a model with JSON topology.
* @param modelArtifacts
* @returns A ModelArtifactsInfo object.
*/
function _getModelArtifactsForJSON() {
_getModelArtifactsForJSON = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(modelJSON, loadWeights) {
var modelArtifacts, _yield$loadWeights, weightSpecs, weightData;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
modelArtifacts = {
modelTopology: modelJSON.modelTopology,
format: modelJSON.format,
generatedBy: modelJSON.generatedBy,
convertedBy: modelJSON.convertedBy
};
if (modelJSON.trainingConfig != null) {
modelArtifacts.trainingConfig = modelJSON.trainingConfig;
}
if (!(modelJSON.weightsManifest != null)) {
_context3.next = 10;
break;
}
_context3.next = 5;
return loadWeights(modelJSON.weightsManifest);
case 5:
_yield$loadWeights = _context3.sent;
weightSpecs = _yield$loadWeights[0];
weightData = _yield$loadWeights[1];
modelArtifacts.weightSpecs = weightSpecs;
modelArtifacts.weightData = weightData;
case 10:
if (modelJSON.signature != null) {
modelArtifacts.signature = modelJSON.signature;
}
if (modelJSON.userDefinedMetadata != null) {
modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
}
if (modelJSON.modelInitializer != null) {
modelArtifacts.modelInitializer = modelJSON.modelInitializer;
}
return _context3.abrupt("return", modelArtifacts);
case 14:
case "end":
return _context3.stop();
}
}
}, _callee3);
}));
return _getModelArtifactsForJSON.apply(this, arguments);
}
function getModelArtifactsInfoForJSON(modelArtifacts) {
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error('Expected JSON model topology, received ArrayBuffer.');
}
return {
dateSaved: new Date(),
modelTopologyType: 'JSON',
modelTopologyBytes: modelArtifacts.modelTopology == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),
weightSpecsBytes: modelArtifacts.weightSpecs == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),
weightDataBytes: modelArtifacts.weightData == null ? 0 : modelArtifacts.weightData.byteLength
};
}
/**
* Computes mantissa table for casting Float16 to Float32
* See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
*
* @returns Uint32Array, 2048 mantissa lookup values.
*/
function computeFloat16MantisaTable() {
var convertMantissa = function convertMantissa(i) {
var m = i << 13;
var e = 0;
while ((m & 0x00800000) === 0) {
e -= 0x00800000;
m <<= 1;
}
m &= ~0x00800000;
e += 0x38800000;
return m | e;
};
var mantisaTable = new Uint32Array(2048);
mantisaTable[0] = 0;
for (var i = 1; i < 1024; i++) {
mantisaTable[i] = convertMantissa(i);
}
for (var _i4 = 1024; _i4 < 2048; _i4++) {
mantisaTable[_i4] = 0x38000000 + (_i4 - 1024 << 13);
}
return mantisaTable;
}
/**
* Computes exponent table for casting Float16 to Float32
* See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
*
* @returns Uint32Array, 64 exponent lookup values.
*/
function computeFloat16ExponentTable() {
var exponentTable = new Uint32Array(64);
exponentTable[0] = 0;
exponentTable[31] = 0x47800000;
exponentTable[32] = 0x80000000;
exponentTable[63] = 0xc7800000;
for (var i = 1; i < 31; i++) {
exponentTable[i] = i << 23;
}
for (var _i5 = 33; _i5 < 63; _i5++) {
exponentTable[_i5] = 0x80000000 + (_i5 - 32 << 23);
}
return exponentTable;
}
/**
* Computes offset table for casting Float16 to Float32
* See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
*
* @returns Uint32Array, 64 offset values.
*/
function computeFloat16OffsetTable() {
var offsetTable = new Uint32Array(64);
for (var i = 0; i < 64; i++) {
offsetTable[i] = 1024;
}
offsetTable[0] = offsetTable[32] = 0;
return offsetTable;
}
/**
* Retrieve a Float16 decoder which will decode a ByteArray of Float16 values
* to a Float32Array.
*
* @returns Function (buffer: Uint16Array) => Float32Array which decodes
* the Uint16Array of Float16 bytes to a Float32Array.
*/
function getFloat16Decoder() {
// Algorithm is based off of
// http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
// Cache lookup tables
var mantisaTable = computeFloat16MantisaTable();
var exponentTable = computeFloat16ExponentTable();
var offsetTable = computeFloat16OffsetTable();
return function (quantizedArray) {
var buffer = new ArrayBuffer(4 * quantizedArray.length);
var bufferUint32View = new Uint32Array(buffer);
for (var index = 0; index < quantizedArray.length; index++) {
var float16Bits = quantizedArray[index];
var float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] + exponentTable[float16Bits >> 10];
bufferUint32View[index] = float32Bits;
}
return new Float32Array(buffer);
};
}
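/**
 * Example (illustrative): the decoder maps raw float16 bit patterns, supplied
 * as a Uint16Array, to their float32 values. 0x3c00 is the half-precision
 * encoding of 1.0 and 0xc000 encodes -2.0.
 *
 * ```js
 * const decode = getFloat16Decoder();
 * decode(new Uint16Array([0x3c00, 0xc000]));  // Float32Array [1, -2]
 * ```
 */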
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var IORouterRegistry = /*#__PURE__*/function () {
function IORouterRegistry() {
this.saveRouters = [];
this.loadRouters = [];
}
IORouterRegistry.getInstance = function getInstance() {
if (IORouterRegistry.instance == null) {
IORouterRegistry.instance = new IORouterRegistry();
}
return IORouterRegistry.instance;
}
/**
* Register a save-handler router.
*
* @param saveRouter A function that maps a URL-like string onto an instance
* of `IOHandler` with the `save` method defined or `null`.
*/
;
IORouterRegistry.registerSaveRouter = function registerSaveRouter(saveRouter) {
IORouterRegistry.getInstance().saveRouters.push(saveRouter);
}
/**
* Register a load-handler router.
*
* @param loadRouter A function that maps a URL-like string onto an instance
* of `IOHandler` with the `load` method defined or `null`.
*/
;
IORouterRegistry.registerLoadRouter = function registerLoadRouter(loadRouter) {
IORouterRegistry.getInstance().loadRouters.push(loadRouter);
}
/**
* Look up IOHandler for saving, given a URL-like string.
*
* @param url
* @returns All valid save handlers for `url`, given the currently
* registered save-handler routers.
*/
;
IORouterRegistry.getSaveHandlers = function getSaveHandlers(url) {
return IORouterRegistry.getHandlers(url, 'save');
}
/**
* Look up IOHandler for loading, given a URL-like string.
*
* @param url
* @param loadOptions Optional, custom load options.
* @returns All valid handlers for `url`, given the currently registered
* handler routers.
*/
;
IORouterRegistry.getLoadHandlers = function getLoadHandlers(url, loadOptions) {
return IORouterRegistry.getHandlers(url, 'load', loadOptions);
};
IORouterRegistry.getHandlers = function getHandlers(url, handlerType, loadOptions) {
var validHandlers = [];
var routers = handlerType === 'load' ? IORouterRegistry.getInstance().loadRouters : IORouterRegistry.getInstance().saveRouters;
routers.forEach(function (router) {
var handler = router(url, loadOptions);
if (handler !== null) {
validHandlers.push(handler);
}
});
return validHandlers;
};
return IORouterRegistry;
}();
var registerSaveRouter = function registerSaveRouter(saveRouter) {
  return IORouterRegistry.registerSaveRouter(saveRouter);
};
var registerLoadRouter = function registerLoadRouter(loadRouter) {
  return IORouterRegistry.registerLoadRouter(loadRouter);
};
var getSaveHandlers = function getSaveHandlers(url) {
return IORouterRegistry.getSaveHandlers(url);
};
var getLoadHandlers = function getLoadHandlers(url, loadOptions) {
return IORouterRegistry.getLoadHandlers(url, loadOptions);
};
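/**
 * Registration sketch (illustrative; `myHandler` and the 'custom://' scheme
 * are hypothetical): a router inspects a URL-like string and returns either
 * an IOHandler or null, and the registry collects every router that matches.
 *
 * ```js
 * registerSaveRouter((url) =>
 *     typeof url === 'string' && url.startsWith('custom://') ? myHandler : null);
 * const handlers = getSaveHandlers('custom://my-model');  // [myHandler]
 * ```
 */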
var DATABASE_NAME = 'tensorflowjs';
var DATABASE_VERSION = 1; // Model data and ModelArtifactsInfo (metadata) are stored in two separate
// stores for efficient access of the list of stored models and their metadata.
// 1. The object store for model data: topology, weights and weight manifests.
var MODEL_STORE_NAME = 'models_store'; // 2. The object store for ModelArtifactsInfo, including meta-information such
// as the type of topology (JSON vs binary), byte size of the topology, byte
// size of the weights, etc.
var INFO_STORE_NAME = 'model_info_store';
/**
* Delete the entire database for tensorflow.js, including the models store.
*/
function deleteDatabase() {
return _deleteDatabase.apply(this, arguments);
}
function _deleteDatabase() {
_deleteDatabase = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee5() {
var idbFactory;
return regeneratorRuntime.wrap(function _callee5$(_context5) {
while (1) {
switch (_context5.prev = _context5.next) {
case 0:
idbFactory = getIndexedDBFactory();
return _context5.abrupt("return", new Promise(function (resolve, reject) {
var deleteRequest = idbFactory.deleteDatabase(DATABASE_NAME);
deleteRequest.onsuccess = function () {
return resolve();
};
deleteRequest.onerror = function (error) {
return reject(error);
};
}));
case 2:
case "end":
return _context5.stop();
}
}
}, _callee5);
}));
return _deleteDatabase.apply(this, arguments);
}
function getIndexedDBFactory() {
if (!env().getBool('IS_BROWSER')) {
// TODO(cais): Add more info about what IOHandler subtypes are available.
// Maybe point to a doc page on the web and/or automatically determine
// the available IOHandlers and print them in the error message.
throw new Error('Failed to obtain IndexedDB factory because the current environment ' + 'is not a web browser.');
} // tslint:disable-next-line:no-any
var theWindow = typeof window === 'undefined' ? self : window;
var factory = theWindow.indexedDB || theWindow.mozIndexedDB || theWindow.webkitIndexedDB || theWindow.msIndexedDB || theWindow.shimIndexedDB;
if (factory == null) {
throw new Error('The current browser does not appear to support IndexedDB.');
}
return factory;
}
function setUpDatabase(openRequest) {
var db = openRequest.result;
db.createObjectStore(MODEL_STORE_NAME, {
keyPath: 'modelPath'
});
db.createObjectStore(INFO_STORE_NAME, {
keyPath: 'modelPath'
});
}
/**
* IOHandler subclass: Browser IndexedDB.
*
* See the doc string of `browserIndexedDB` for more details.
*/
var BrowserIndexedDB = /*#__PURE__*/function () {
function BrowserIndexedDB(modelPath) {
this.indexedDB = getIndexedDBFactory();
if (modelPath == null || !modelPath) {
throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.');
}
this.modelPath = modelPath;
}
var _proto = BrowserIndexedDB.prototype;
_proto.save = /*#__PURE__*/function () {
var _save = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(modelArtifacts) {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) {
_context.next = 2;
break;
}
throw new Error('BrowserIndexedDB.save() does not support saving model topology ' + 'in binary formats yet.');
case 2:
return _context.abrupt("return", this.databaseAction(this.modelPath, modelArtifacts));
case 3:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function save(_x) {
return _save.apply(this, arguments);
}
return save;
}();
_proto.load = /*#__PURE__*/function () {
var _load = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
return _context2.abrupt("return", this.databaseAction(this.modelPath));
case 1:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function load() {
return _load.apply(this, arguments);
}
return load;
}()
/**
* Perform database action to put model artifacts into or read model artifacts
* from IndexedDB object store.
*
* Whether the action is put or get depends on whether `modelArtifacts` is
* specified. If it is specified, the action will be put; otherwise the action
* will be get.
*
* @param modelPath A unique string path for the model.
* @param modelArtifacts If specified, it will be the model artifacts to be
* stored in IndexedDB.
* @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise`
* of `ModelArtifacts`, if the action is get.
*/
;
_proto.databaseAction = function databaseAction(modelPath, modelArtifacts) {
var _this = this;
return new Promise(function (resolve, reject) {
var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
openRequest.onupgradeneeded = function () {
return setUpDatabase(openRequest);
};
openRequest.onsuccess = function () {
var db = openRequest.result;
if (modelArtifacts == null) {
// Read model out from object store.
var modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');
var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
var getRequest = modelStore.get(_this.modelPath);
getRequest.onsuccess = function () {
if (getRequest.result == null) {
db.close();
return reject(new Error("Cannot find model with path '" + _this.modelPath + "' " + "in IndexedDB."));
} else {
resolve(getRequest.result.modelArtifacts);
}
};
getRequest.onerror = function (error) {
db.close();
return reject(getRequest.error);
};
modelTx.oncomplete = function () {
return db.close();
};
} else {
// Put model into object store.
var modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); // First, put ModelArtifactsInfo into info store.
var infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
var infoStore = infoTx.objectStore(INFO_STORE_NAME);
var putInfoRequest = infoStore.put({
modelPath: _this.modelPath,
modelArtifactsInfo: modelArtifactsInfo
});
var _modelTx;
putInfoRequest.onsuccess = function () {
// Second, put model data into model store.
_modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
var modelStore = _modelTx.objectStore(MODEL_STORE_NAME);
var putModelRequest = modelStore.put({
modelPath: _this.modelPath,
modelArtifacts: modelArtifacts,
modelArtifactsInfo: modelArtifactsInfo
});
putModelRequest.onsuccess = function () {
return resolve({
modelArtifactsInfo: modelArtifactsInfo
});
};
putModelRequest.onerror = function (error) {
// If the put-model request fails, roll back the info entry as
// well.
infoStore = infoTx.objectStore(INFO_STORE_NAME);
var deleteInfoRequest = infoStore.delete(_this.modelPath);
deleteInfoRequest.onsuccess = function () {
db.close();
return reject(putModelRequest.error);
};
deleteInfoRequest.onerror = function (error) {
db.close();
return reject(putModelRequest.error);
};
};
};
putInfoRequest.onerror = function (error) {
db.close();
return reject(putInfoRequest.error);
};
infoTx.oncomplete = function () {
if (_modelTx == null) {
db.close();
} else {
_modelTx.oncomplete = function () {
return db.close();
};
}
};
}
};
openRequest.onerror = function (error) {
return reject(openRequest.error);
};
});
};
return BrowserIndexedDB;
}();
BrowserIndexedDB.URL_SCHEME = 'indexeddb://';
var indexedDBRouter = function indexedDBRouter(url) {
if (!env().getBool('IS_BROWSER')) {
return null;
} else {
if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));
} else {
return null;
}
}
};
IORouterRegistry.registerSaveRouter(indexedDBRouter);
IORouterRegistry.registerLoadRouter(indexedDBRouter);
/**
* Creates a browser IndexedDB IOHandler for saving and loading models.
*
* ```js
* const model = tf.sequential();
* model.add(
* tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
*
* const saveResult = await model.save('indexeddb://MyModel');
* console.log(saveResult);
* ```
*
* @param modelPath A unique identifier for the model to be saved. Must be a
* non-empty string.
* @returns An instance of `BrowserIndexedDB` (subclass of `IOHandler`),
* which can be used with, e.g., `tf.Model.save`.
*/
function browserIndexedDB(modelPath) {
return new BrowserIndexedDB(modelPath);
}
function maybeStripScheme(key) {
return key.startsWith(BrowserIndexedDB.URL_SCHEME) ? key.slice(BrowserIndexedDB.URL_SCHEME.length) : key;
}
var BrowserIndexedDBManager = /*#__PURE__*/function () {
function BrowserIndexedDBManager() {
this.indexedDB = getIndexedDBFactory();
}
var _proto2 = BrowserIndexedDBManager.prototype;
_proto2.listModels = /*#__PURE__*/function () {
var _listModels = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3() {
var _this2 = this;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
return _context3.abrupt("return", new Promise(function (resolve, reject) {
var openRequest = _this2.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
openRequest.onupgradeneeded = function () {
return setUpDatabase(openRequest);
};
openRequest.onsuccess = function () {
var db = openRequest.result;
var tx = db.transaction(INFO_STORE_NAME, 'readonly');
var store = tx.objectStore(INFO_STORE_NAME); // tslint:disable:max-line-length
// Need to cast `store` as `any` here because TypeScript's DOM
// library does not have the `getAll()` method even though the
// method is supported in the latest version of most mainstream
// browsers:
// https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll
// tslint:enable:max-line-length
// tslint:disable-next-line:no-any
var getAllInfoRequest = store.getAll();
getAllInfoRequest.onsuccess = function () {
var out = {};
for (var _iterator = _createForOfIteratorHelperLoose(getAllInfoRequest.result), _step; !(_step = _iterator()).done;) {
var item = _step.value;
out[item.modelPath] = item.modelArtifactsInfo;
}
resolve(out);
};
getAllInfoRequest.onerror = function (error) {
db.close();
return reject(getAllInfoRequest.error);
};
tx.oncomplete = function () {
return db.close();
};
};
openRequest.onerror = function (error) {
return reject(openRequest.error);
};
}));
case 1:
case "end":
return _context3.stop();
}
}
}, _callee3);
}));
function listModels() {
return _listModels.apply(this, arguments);
}
return listModels;
}();
_proto2.removeModel = /*#__PURE__*/function () {
var _removeModel = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(path) {
var _this3 = this;
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
path = maybeStripScheme(path);
return _context4.abrupt("return", new Promise(function (resolve, reject) {
var openRequest = _this3.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
openRequest.onupgradeneeded = function () {
return setUpDatabase(openRequest);
};
openRequest.onsuccess = function () {
var db = openRequest.result;
var infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
var infoStore = infoTx.objectStore(INFO_STORE_NAME);
var getInfoRequest = infoStore.get(path);
var modelTx;
getInfoRequest.onsuccess = function () {
if (getInfoRequest.result == null) {
db.close();
return reject(new Error("Cannot find model with path '" + path + "' " + "in IndexedDB."));
} else {
// First, delete the entry in the info store.
var deleteInfoRequest = infoStore.delete(path);
var deleteModelData = function deleteModelData() {
// Second, delete the entry in the model store.
modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
var modelStore = modelTx.objectStore(MODEL_STORE_NAME);
var deleteModelRequest = modelStore.delete(path);
deleteModelRequest.onsuccess = function () {
return resolve(getInfoRequest.result.modelArtifactsInfo);
};
deleteModelRequest.onerror = function (error) {
return reject(getInfoRequest.error);
};
}; // Proceed with deleting model data regardless of whether deletion
// of info data succeeds or not.
deleteInfoRequest.onsuccess = deleteModelData;
deleteInfoRequest.onerror = function (error) {
deleteModelData();
db.close();
return reject(getInfoRequest.error);
};
}
};
getInfoRequest.onerror = function (error) {
db.close();
return reject(getInfoRequest.error);
};
infoTx.oncomplete = function () {
if (modelTx == null) {
db.close();
} else {
modelTx.oncomplete = function () {
return db.close();
};
}
};
};
openRequest.onerror = function (error) {
return reject(openRequest.error);
};
}));
case 2:
case "end":
return _context4.stop();
}
}
}, _callee4);
}));
function removeModel(_x2) {
return _removeModel.apply(this, arguments);
}
return removeModel;
}();
return BrowserIndexedDBManager;
}();
var PATH_SEPARATOR = '/';
var PATH_PREFIX = 'tensorflowjs_models';
var INFO_SUFFIX = 'info';
var MODEL_TOPOLOGY_SUFFIX = 'model_topology';
var WEIGHT_SPECS_SUFFIX = 'weight_specs';
var WEIGHT_DATA_SUFFIX = 'weight_data';
var MODEL_METADATA_SUFFIX = 'model_metadata';
/**
* Purge all tensorflow.js-saved model artifacts from local storage.
*
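 * A usage sketch (assumes a browser environment where `window.localStorage`
 * is available):
 *
 * ```js
 * const purgedModelPaths = purgeLocalStorageArtifacts();
 * console.log(purgedModelPaths);
 * ```
 *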
* @returns Paths of the models purged.
*/
function purgeLocalStorageArtifacts() {
if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' || typeof window.localStorage === 'undefined') {
throw new Error('purgeLocalStorageArtifacts() cannot proceed because local storage is ' + 'unavailable in the current environment.');
}
var LS = window.localStorage;
var purgedModelPaths = [];
for (var i = 0; i < LS.length; ++i) {
var key = LS.key(i);
var prefix = PATH_PREFIX + PATH_SEPARATOR;
if (key.startsWith(prefix) && key.length > prefix.length) {
LS.removeItem(key);
var modelName = getModelPathFromKey(key);
if (purgedModelPaths.indexOf(modelName) === -1) {
purgedModelPaths.push(modelName);
}
}
}
return purgedModelPaths;
}
function getModelKeys(path) {
return {
info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
};
}
function removeItems(keys) {
for (var _i = 0, _Object$values = Object.values(keys); _i < _Object$values.length; _i++) {
var key = _Object$values[_i];
window.localStorage.removeItem(key);
}
}
/**
* Get model path from a local-storage key.
*
* E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1'
*
* @param key
*/
function getModelPathFromKey(key) {
var items = key.split(PATH_SEPARATOR);
if (items.length < 3) {
throw new Error("Invalid key format: " + key);
}
return items.slice(1, items.length - 1).join(PATH_SEPARATOR);
}
function maybeStripScheme$1(key) {
return key.startsWith(BrowserLocalStorage.URL_SCHEME) ? key.slice(BrowserLocalStorage.URL_SCHEME.length) : key;
}
/**
* IOHandler subclass: Browser Local Storage.
*
* See the doc string to `browserLocalStorage` for more details.
*/
var BrowserLocalStorage = /*#__PURE__*/function () {
function BrowserLocalStorage(modelPath) {
if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' || typeof window.localStorage === 'undefined') {
// TODO(cais): Add more info about what IOHandler subtypes are
// available.
// Maybe point to a doc page on the web and/or automatically determine
// the available IOHandlers and print them in the error message.
throw new Error('The current environment does not support local storage.');
}
this.LS = window.localStorage;
if (modelPath == null || !modelPath) {
throw new Error('For local storage, modelPath must not be null, undefined or empty.');
}
this.modelPath = modelPath;
this.keys = getModelKeys(this.modelPath);
}
/**
* Save model artifacts to browser local storage.
*
* See the documentation to `browserLocalStorage` for details on the saved
* artifacts.
*
* @param modelArtifacts The model artifacts to be stored.
* @returns An instance of SaveResult.
*/
var _proto = BrowserLocalStorage.prototype;
_proto.save =
/*#__PURE__*/
function () {
var _save = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(modelArtifacts) {
var topology, weightSpecs, modelArtifactsInfo, metadata;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) {
_context.next = 4;
break;
}
throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + 'in binary formats yet.');
case 4:
topology = JSON.stringify(modelArtifacts.modelTopology);
weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);
modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
_context.prev = 7;
this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));
this.LS.setItem(this.keys.topology, topology);
this.LS.setItem(this.keys.weightSpecs, weightSpecs);
this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData)); // Note that JSON.stringify doesn't write out keys that have undefined
// values, so for some keys, we set undefined instead of a null-ish
// value.
metadata = {
format: modelArtifacts.format,
generatedBy: modelArtifacts.generatedBy,
convertedBy: modelArtifacts.convertedBy,
signature: modelArtifacts.signature != null ? modelArtifacts.signature : undefined,
userDefinedMetadata: modelArtifacts.userDefinedMetadata != null ? modelArtifacts.userDefinedMetadata : undefined,
modelInitializer: modelArtifacts.modelInitializer != null ? modelArtifacts.modelInitializer : undefined,
trainingConfig: modelArtifacts.trainingConfig != null ? modelArtifacts.trainingConfig : undefined
};
this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata));
return _context.abrupt("return", {
modelArtifactsInfo: modelArtifactsInfo
});
case 17:
_context.prev = 17;
_context.t0 = _context["catch"](7);
// If saving failed, clean up all items saved so far.
removeItems(this.keys);
throw new Error("Failed to save model '" + this.modelPath + "' to local storage: " + "size quota being exceeded is a possible cause of this failure: " + ("modelTopologyBytes=" + modelArtifactsInfo.modelTopologyBytes + ", ") + ("weightSpecsBytes=" + modelArtifactsInfo.weightSpecsBytes + ", ") + ("weightDataBytes=" + modelArtifactsInfo.weightDataBytes + "."));
case 21:
case "end":
return _context.stop();
}
}
}, _callee, this, [[7, 17]]);
}));
function save(_x) {
return _save.apply(this, arguments);
}
return save;
}()
/**
* Load a model from local storage.
*
* See the documentation to `browserLocalStorage` for details on the saved
* artifacts.
*
* @returns The loaded model (if loading succeeds).
*/
;
_proto.load =
/*#__PURE__*/
function () {
var _load = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var info, out, topology, weightSpecs, metadataString, metadata, weightDataBase64;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
info = JSON.parse(this.LS.getItem(this.keys.info));
if (!(info == null)) {
_context2.next = 3;
break;
}
throw new Error("In local storage, there is no model with name '" + this.modelPath + "'");
case 3:
if (!(info.modelTopologyType !== 'JSON')) {
_context2.next = 5;
break;
}
throw new Error('BrowserLocalStorage does not support loading non-JSON model ' + 'topology yet.');
case 5:
out = {}; // Load topology.
topology = JSON.parse(this.LS.getItem(this.keys.topology));
if (!(topology == null)) {
_context2.next = 9;
break;
}
throw new Error("In local storage, the topology of model '" + this.modelPath + "' " + "is missing.");
case 9:
out.modelTopology = topology; // Load weight specs.
weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
if (!(weightSpecs == null)) {
_context2.next = 13;
break;
}
throw new Error("In local storage, the weight specs of model '" + this.modelPath + "' " + "are missing.");
case 13:
out.weightSpecs = weightSpecs; // Load meta-data fields.
metadataString = this.LS.getItem(this.keys.modelMetadata);
if (metadataString != null) {
metadata = JSON.parse(metadataString);
out.format = metadata.format;
out.generatedBy = metadata.generatedBy;
out.convertedBy = metadata.convertedBy;
if (metadata.signature != null) {
out.signature = metadata.signature;
}
if (metadata.userDefinedMetadata != null) {
out.userDefinedMetadata = metadata.userDefinedMetadata;
}
if (metadata.modelInitializer != null) {
out.modelInitializer = metadata.modelInitializer;
}
if (metadata.trainingConfig != null) {
out.trainingConfig = metadata.trainingConfig;
}
} // Load weight data.
weightDataBase64 = this.LS.getItem(this.keys.weightData);
if (!(weightDataBase64 == null)) {
_context2.next = 19;
break;
}
throw new Error("In local storage, the binary weight values of model " + ("'" + this.modelPath + "' are missing."));
case 19:
out.weightData = base64StringToArrayBuffer(weightDataBase64);
return _context2.abrupt("return", out);
case 21:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function load() {
return _load.apply(this, arguments);
}
return load;
}();
return BrowserLocalStorage;
}();
BrowserLocalStorage.URL_SCHEME = 'localstorage://';
var localStorageRouter = function localStorageRouter(url) {
if (!env().getBool('IS_BROWSER')) {
return null;
} else {
if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {
return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length));
} else {
return null;
}
}
};
IORouterRegistry.registerSaveRouter(localStorageRouter);
IORouterRegistry.registerLoadRouter(localStorageRouter);
/**
* Factory function for local storage IOHandler.
*
* This `IOHandler` supports both `save` and `load`.
*
 * For each model's saved artifacts, the following items are saved to local
 * storage:
 *   - `${PATH_PREFIX}/${modelPath}/info`: Contains meta-info about the
 *     model, such as date saved, type of the topology, size in bytes, etc.
 *   - `${PATH_PREFIX}/${modelPath}/model_topology`: Model topology. For
 *     Keras-style models, this is a stringified JSON.
 *   - `${PATH_PREFIX}/${modelPath}/weight_specs`: Weight specs of the
 *     model, can be used to decode the saved binary weight values (see
 *     item below).
 *   - `${PATH_PREFIX}/${modelPath}/weight_data`: Concatenated binary
 *     weight values, stored as a base64-encoded string.
 *   - `${PATH_PREFIX}/${modelPath}/model_metadata`: Model metadata such as
 *     format, generatedBy, convertedBy and, when present, the signature,
 *     user-defined metadata and training config.
 *
 * Saving may throw an `Error` if the total size of the artifacts exceeds the
* browser-specific quota.
*
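 * A usage sketch (the model name 'MyLocalModel' is only an example):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(
 *     tf.layers.dense({units: 1, inputShape: [10], activation: 'sigmoid'}));
 *
 * const saveResult = await model.save('localstorage://MyLocalModel');
 * console.log(saveResult);
 * ```
 *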
* @param modelPath A unique identifier for the model to be saved. Must be a
* non-empty string.
* @returns An instance of `IOHandler`, which can be used with, e.g.,
* `tf.Model.save`.
*/
function browserLocalStorage(modelPath) {
return new BrowserLocalStorage(modelPath);
}
var BrowserLocalStorageManager = /*#__PURE__*/function () {
function BrowserLocalStorageManager() {
assert(env().getBool('IS_BROWSER'), function () {
return 'Current environment is not a web browser';
});
assert(typeof window === 'undefined' || typeof window.localStorage !== 'undefined', function () {
return 'Current browser does not appear to support localStorage';
});
this.LS = window.localStorage;
}
var _proto2 = BrowserLocalStorageManager.prototype;
_proto2.listModels = /*#__PURE__*/function () {
var _listModels = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3() {
var out, prefix, suffix, i, key, modelPath;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
out = {};
prefix = PATH_PREFIX + PATH_SEPARATOR;
suffix = PATH_SEPARATOR + INFO_SUFFIX;
for (i = 0; i < this.LS.length; ++i) {
key = this.LS.key(i);
if (key.startsWith(prefix) && key.endsWith(suffix)) {
modelPath = getModelPathFromKey(key);
out[modelPath] = JSON.parse(this.LS.getItem(key));
}
}
return _context3.abrupt("return", out);
case 5:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function listModels() {
return _listModels.apply(this, arguments);
}
return listModels;
}();
_proto2.removeModel = /*#__PURE__*/function () {
var _removeModel = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(path) {
var keys, info;
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
path = maybeStripScheme$1(path);
keys = getModelKeys(path);
if (!(this.LS.getItem(keys.info) == null)) {
_context4.next = 4;
break;
}
throw new Error("Cannot find model at path '" + path + "'");
case 4:
info = JSON.parse(this.LS.getItem(keys.info));
removeItems(keys);
return _context4.abrupt("return", info);
case 7:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function removeModel(_x2) {
return _removeModel.apply(this, arguments);
}
return removeModel;
}();
return BrowserLocalStorageManager;
}();
var URL_SCHEME_SUFFIX = '://';
var ModelStoreManagerRegistry = /*#__PURE__*/function () {
function ModelStoreManagerRegistry() {
this.managers = {};
}
ModelStoreManagerRegistry.getInstance = function getInstance() {
if (ModelStoreManagerRegistry.instance == null) {
ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry();
}
return ModelStoreManagerRegistry.instance;
}
/**
 * Register a `ModelStoreManager` for a URL scheme.
 *
 * @param scheme The URL scheme that the manager handles, e.g.,
 *   'localstorage://'. A trailing '://' is stripped before registration.
 * @param manager The `ModelStoreManager` instance to register for `scheme`.
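 *
 * A registration sketch (this mirrors how the built-in local-storage manager
 * is registered later in this file):
 *
 * ```js
 * ModelStoreManagerRegistry.registerManager(
 *     BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());
 * ```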
*/
;
ModelStoreManagerRegistry.registerManager = function registerManager(scheme, manager) {
assert(scheme != null, function () {
return 'scheme must not be undefined or null.';
});
if (scheme.endsWith(URL_SCHEME_SUFFIX)) {
scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX));
}
assert(scheme.length > 0, function () {
return 'scheme must not be an empty string.';
});
var registry = ModelStoreManagerRegistry.getInstance();
assert(registry.managers[scheme] == null, function () {
return "A model store manager is already registered for scheme '" + scheme + "'.";
});
registry.managers[scheme] = manager;
};
ModelStoreManagerRegistry.getManager = function getManager(scheme) {
var manager = this.getInstance().managers[scheme];
if (manager == null) {
throw new Error("Cannot find model manager for scheme '" + scheme + "'");
}
return manager;
};
ModelStoreManagerRegistry.getSchemes = function getSchemes() {
return Object.keys(this.getInstance().managers);
};
return ModelStoreManagerRegistry;
}();
/**
* Helper method for parsing a URL string into a scheme and a path.
*
* @param url E.g., 'localstorage://my-model'
* @returns A dictionary with two fields: scheme and path.
* Scheme: e.g., 'localstorage' in the example above.
* Path: e.g., 'my-model' in the example above.
*/
function parseURL$1(url) {
if (url.indexOf(URL_SCHEME_SUFFIX) === -1) {
throw new Error("The url string provided does not contain a scheme. " + "Supported schemes are: " + ("" + ModelStoreManagerRegistry.getSchemes().join(',')));
}
return {
scheme: url.split(URL_SCHEME_SUFFIX)[0],
path: url.split(URL_SCHEME_SUFFIX)[1]
};
}
function cloneModelInternal(_x, _x2, _x3) {
return _cloneModelInternal.apply(this, arguments);
}
/**
* List all models stored in registered storage mediums.
*
* For a web browser environment, the registered mediums are Local Storage and
* IndexedDB.
*
* ```js
* // First create and save a model.
* const model = tf.sequential();
* model.add(tf.layers.dense(
* {units: 1, inputShape: [10], activation: 'sigmoid'}));
* await model.save('localstorage://demo/management/model1');
*
* // Then list existing models.
* console.log(JSON.stringify(await tf.io.listModels()));
*
* // Delete the model.
* await tf.io.removeModel('localstorage://demo/management/model1');
*
* // List models again.
* console.log(JSON.stringify(await tf.io.listModels()));
* ```
*
* @returns A `Promise` of a dictionary mapping URLs of existing models to
* their model artifacts info. URLs include medium-specific schemes, e.g.,
 * 'indexeddb://my/model/1'. Model artifacts info includes the type of the
* model's topology, byte sizes of the topology, weights, etc.
*
* @doc {
* heading: 'Models',
* subheading: 'Management',
* namespace: 'io',
* ignoreCI: true
* }
*/
function _cloneModelInternal() {
_cloneModelInternal = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(sourceURL, destURL, deleteSource) {
var loadHandlers, loadHandler, saveHandlers, saveHandler, sourceScheme, sourcePath, sameMedium, modelArtifacts, saveResult;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (deleteSource === void 0) {
deleteSource = false;
}
assert(sourceURL !== destURL, function () {
return "Old path and new path are the same: '" + sourceURL + "'";
});
loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL);
assert(loadHandlers.length > 0, function () {
return "Copying failed because no load handler is found for source URL " + sourceURL + ".";
});
assert(loadHandlers.length < 2, function () {
return "Copying failed because more than one (" + loadHandlers.length + ") " + ("load handlers were found for source URL " + sourceURL + ".");
});
loadHandler = loadHandlers[0];
saveHandlers = IORouterRegistry.getSaveHandlers(destURL);
assert(saveHandlers.length > 0, function () {
return "Copying failed because no save handler is found for destination " + ("URL " + destURL + ".");
});
assert(saveHandlers.length < 2, function () {
return "Copying failed because more than one (" + saveHandlers.length + ") " + ("save handlers were found for destination URL " + destURL + ".");
});
saveHandler = saveHandlers[0];
sourceScheme = parseURL$1(sourceURL).scheme;
sourcePath = parseURL$1(sourceURL).path;
sameMedium = sourceScheme === parseURL$1(destURL).scheme;
_context.next = 15;
return loadHandler.load();
case 15:
modelArtifacts = _context.sent;
if (!(deleteSource && sameMedium)) {
_context.next = 19;
break;
}
_context.next = 19;
return ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath);
case 19:
_context.next = 21;
return saveHandler.save(modelArtifacts);
case 21:
saveResult = _context.sent;
if (!(deleteSource && !sameMedium)) {
_context.next = 25;
break;
}
_context.next = 25;
return ModelStoreManagerRegistry.getManager(sourceScheme).removeModel(sourcePath);
case 25:
return _context.abrupt("return", saveResult.modelArtifactsInfo);
case 26:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _cloneModelInternal.apply(this, arguments);
}
function listModels() {
return _listModels.apply(this, arguments);
}
/**
 * Remove a model specified by URL from a registered storage medium.
*
* ```js
* // First create and save a model.
* const model = tf.sequential();
* model.add(tf.layers.dense(
* {units: 1, inputShape: [10], activation: 'sigmoid'}));
* await model.save('localstorage://demo/management/model1');
*
* // Then list existing models.
* console.log(JSON.stringify(await tf.io.listModels()));
*
* // Delete the model.
* await tf.io.removeModel('localstorage://demo/management/model1');
*
* // List models again.
* console.log(JSON.stringify(await tf.io.listModels()));
* ```
*
* @param url A URL to a stored model, with a scheme prefix, e.g.,
* 'localstorage://my-model-1', 'indexeddb://my/model/2'.
* @returns ModelArtifactsInfo of the deleted model (if and only if deletion
* is successful).
* @throws Error if deletion fails, e.g., if no model exists at `path`.
*
* @doc {
* heading: 'Models',
* subheading: 'Management',
* namespace: 'io',
* ignoreCI: true
* }
*/
function _listModels() {
_listModels = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var schemes, out, _iterator, _step, scheme, schemeOut, path, url;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
schemes = ModelStoreManagerRegistry.getSchemes();
out = {};
_iterator = _createForOfIteratorHelperLoose(schemes);
case 3:
if ((_step = _iterator()).done) {
_context2.next = 11;
break;
}
scheme = _step.value;
_context2.next = 7;
return ModelStoreManagerRegistry.getManager(scheme).listModels();
case 7:
schemeOut = _context2.sent;
for (path in schemeOut) {
url = scheme + URL_SCHEME_SUFFIX + path;
out[url] = schemeOut[path];
}
case 9:
_context2.next = 3;
break;
case 11:
return _context2.abrupt("return", out);
case 12:
case "end":
return _context2.stop();
}
}
}, _callee2);
}));
return _listModels.apply(this, arguments);
}
function removeModel(_x4) {
return _removeModel.apply(this, arguments);
}
/**
* Copy a model from one URL to another.
*
* This function supports:
*
* 1. Copying within a storage medium, e.g.,
* `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')`
* 2. Copying between two storage mediums, e.g.,
* `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')`
*
* ```js
* // First create and save a model.
* const model = tf.sequential();
* model.add(tf.layers.dense(
* {units: 1, inputShape: [10], activation: 'sigmoid'}));
* await model.save('localstorage://demo/management/model1');
*
* // Then list existing models.
* console.log(JSON.stringify(await tf.io.listModels()));
*
* // Copy the model, from Local Storage to IndexedDB.
* await tf.io.copyModel(
* 'localstorage://demo/management/model1',
* 'indexeddb://demo/management/model1');
*
* // List models again.
* console.log(JSON.stringify(await tf.io.listModels()));
*
* // Remove both models.
* await tf.io.removeModel('localstorage://demo/management/model1');
* await tf.io.removeModel('indexeddb://demo/management/model1');
* ```
*
* @param sourceURL Source URL of copying.
* @param destURL Destination URL of copying.
* @returns ModelArtifactsInfo of the copied model (if and only if copying
* is successful).
* @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or
 * if `sourceURL` and `destURL` are identical.
*
* @doc {
* heading: 'Models',
* subheading: 'Management',
* namespace: 'io',
* ignoreCI: true
* }
*/
function _removeModel() {
_removeModel = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(url) {
var schemeAndPath, manager;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
schemeAndPath = parseURL$1(url);
manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme);
return _context3.abrupt("return", manager.removeModel(schemeAndPath.path));
case 3:
case "end":
return _context3.stop();
}
}
}, _callee3);
}));
return _removeModel.apply(this, arguments);
}
function copyModel(_x5, _x6) {
return _copyModel.apply(this, arguments);
}
/**
* Move a model from one URL to another.
*
* This function supports:
*
* 1. Moving within a storage medium, e.g.,
* `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')`
* 2. Moving between two storage mediums, e.g.,
* `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')`
*
* ```js
* // First create and save a model.
* const model = tf.sequential();
* model.add(tf.layers.dense(
* {units: 1, inputShape: [10], activation: 'sigmoid'}));
* await model.save('localstorage://demo/management/model1');
*
* // Then list existing models.
* console.log(JSON.stringify(await tf.io.listModels()));
*
* // Move the model, from Local Storage to IndexedDB.
* await tf.io.moveModel(
* 'localstorage://demo/management/model1',
* 'indexeddb://demo/management/model1');
*
* // List models again.
* console.log(JSON.stringify(await tf.io.listModels()));
*
* // Remove the moved model.
* await tf.io.removeModel('indexeddb://demo/management/model1');
* ```
*
* @param sourceURL Source URL of moving.
* @param destURL Destination URL of moving.
 * @returns ModelArtifactsInfo of the moved model (if and only if the move
 * is successful).
 * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or
 * if `sourceURL` and `destURL` are identical.
*
* @doc {
* heading: 'Models',
* subheading: 'Management',
* namespace: 'io',
* ignoreCI: true
* }
*/
function _copyModel() {
_copyModel = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(sourceURL, destURL) {
var deleteSource;
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
deleteSource = false;
return _context4.abrupt("return", cloneModelInternal(sourceURL, destURL, deleteSource));
case 2:
case "end":
return _context4.stop();
}
}
}, _callee4);
}));
return _copyModel.apply(this, arguments);
}
function moveModel(_x7, _x8) {
return _moveModel.apply(this, arguments);
}
function _moveModel() {
_moveModel = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee5(sourceURL, destURL) {
var deleteSource;
return regeneratorRuntime.wrap(function _callee5$(_context5) {
while (1) {
switch (_context5.prev = _context5.next) {
case 0:
deleteSource = true;
return _context5.abrupt("return", cloneModelInternal(sourceURL, destURL, deleteSource));
case 2:
case "end":
return _context5.stop();
}
}
}, _callee5);
}));
return _moveModel.apply(this, arguments);
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
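// Platform implementation for browser environments: `fetch` delegates to the
// global fetch(), now() uses performance.now(), and encode()/decode() use the
// browser's TextEncoder/TextDecoder (UTF-8 only for encoding).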
var PlatformBrowser = /*#__PURE__*/function () {
function PlatformBrowser() {}
var _proto = PlatformBrowser.prototype;
_proto.fetch = function (_fetch) {
function fetch(_x, _x2) {
return _fetch.apply(this, arguments);
}
fetch.toString = function () {
return _fetch.toString();
};
return fetch;
}(function (path, init) {
return fetch(path, init);
});
_proto.now = function now() {
return performance.now();
};
_proto.encode = function encode(text, encoding) {
if (encoding !== 'utf-8' && encoding !== 'utf8') {
throw new Error("Browser's encoder only supports utf-8, but got " + encoding);
}
if (this.textEncoder == null) {
this.textEncoder = new TextEncoder();
}
return this.textEncoder.encode(text);
};
_proto.decode = function decode(bytes, encoding) {
return new TextDecoder(encoding).decode(bytes);
};
return PlatformBrowser;
}();
if (env().get('IS_BROWSER')) {
env().setPlatform('browser', new PlatformBrowser()); // Register LocalStorage IOHandler
try {
ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager());
} catch (err) {} // Register IndexedDB IOHandler
try {
ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager());
} catch (err) {}
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var getNodeFetch = {
// tslint:disable-next-line:no-require-imports
importFetch: function importFetch() {
return require('node-fetch');
}
};
var systemFetch; // These getters and setters are for testing so we don't export a mutable
// variable.
function resetSystemFetch() {
systemFetch = null;
}
function setSystemFetch(fetchFn) {
systemFetch = fetchFn;
}
function getSystemFetch() {
return systemFetch;
}
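// Platform implementation for Node.js environments: fetch() prefers a global
// fetch if one is set on the environment, falling back to the 'node-fetch'
// package; now() is based on process.hrtime(); text encoding/decoding uses
// the 'util' module's TextEncoder/TextDecoder.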
var PlatformNode = /*#__PURE__*/function () {
function PlatformNode() {
// tslint:disable-next-line:no-require-imports
this.util = require('util'); // According to the spec, the built-in encoder can do only UTF-8 encoding.
// https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder
this.textEncoder = new this.util.TextEncoder();
}
var _proto = PlatformNode.prototype;
_proto.fetch = function fetch(path, requestInits) {
if (env().global.fetch != null) {
return env().global.fetch(path, requestInits);
}
if (systemFetch == null) {
systemFetch = getNodeFetch.importFetch();
}
return systemFetch(path, requestInits);
};
_proto.now = function now() {
var time = process.hrtime();
return time[0] * 1000 + time[1] / 1000000;
};
_proto.encode = function encode(text, encoding) {
if (encoding !== 'utf-8' && encoding !== 'utf8') {
throw new Error("Node built-in encoder only supports utf-8, but got " + encoding);
}
return this.textEncoder.encode(text);
};
_proto.decode = function decode(bytes, encoding) {
if (bytes.length === 0) {
return '';
}
return new this.util.TextDecoder(encoding).decode(bytes);
};
return PlatformNode;
}();
if (env().get('IS_NODE')) {
env().setPlatform('node', new PlatformNode());
}
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`.
*
 * The values are stored in CPU memory as a `TypedArray`. Fill the buffer
 * using `buffer.set()`, or by directly modifying `buffer.values`.
*
* When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with
* those values.
*
* ```js
* // Create a buffer and set values at particular indices.
* const buffer = tf.buffer([2, 2]);
* buffer.set(3, 0, 0);
* buffer.set(5, 1, 0);
*
* // Convert the buffer back to a tensor.
* buffer.toTensor().print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param dtype The dtype of the buffer. Defaults to 'float32'.
* @param values The values of the buffer as `TypedArray`. Defaults to
* zeros.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function buffer(shape, dtype, values) {
if (dtype === void 0) {
dtype = 'float32';
}
dtype = dtype || 'float32';
assertNonNegativeIntegerDimensions(shape);
return new TensorBuffer(shape, dtype, values);
}
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Casts a `tf.Tensor` to a new dtype.
*
* ```js
* const x = tf.tensor1d([1.5, 2.5, 3]);
* tf.cast(x, 'int32').print();
* ```
 * @param x The input tensor to be cast.
* @param dtype The dtype to cast the input tensor to.
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function cast_(x, dtype) {
var $x = convertToTensor(x, 'x', 'cast'); // Sanity checks.
if (!isValidDtype(dtype)) {
throw new Error("Failed to cast to unknown dtype " + dtype);
}
if (dtype === 'string' && $x.dtype !== 'string' || dtype !== 'string' && $x.dtype === 'string') {
throw new Error('Only strings can be cast to strings');
}
var inputs = {
x: $x
};
var attrs = {
dtype: dtype
};
return ENGINE.runKernel(Cast, inputs, attrs);
}
var cast = op({
cast_: cast_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a new tensor with the same values and shape as the specified
* tensor.
*
* ```js
* const x = tf.tensor([1, 2]);
*
* x.clone().print();
* ```
*
* @param x The tensor to clone.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function clone_(x) {
var $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric');
var inputs = {
x: $x
}; // Note this op is called tf.identity in python. Hence the kernel name used
// here.
return ENGINE.runKernel(Identity, inputs);
}
var clone = op({
clone_: clone_
});
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Prints information about the `tf.Tensor` including its data.
*
* ```js
* const verbose = true;
* tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose);
* ```
* @param x The tensor to be printed.
 * @param verbose Whether to print verbose information about the `Tensor`,
* including dtype and size.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function print(x, verbose) {
if (verbose === void 0) {
verbose = false;
}
console.log(x.toString(verbose));
}
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getOrMakeEngine(); // Register backend-agnostic flags.
var opHandler$1 = {
buffer: buffer,
cast: cast,
clone: clone,
print: print
};
setOpHandler(opHandler$1);
var DEFAULT_FILE_NAME_PREFIX = 'model';
var DEFAULT_JSON_EXTENSION_NAME = '.json';
var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';
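// Defers calling `f` to the next macrotask via setTimeout. Used below so that
// multiple download-anchor clicks are not dispatched synchronously (Firefox
// only keeps the last download when they are).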
function defer$1(f) {
return new Promise(function (resolve) {
return setTimeout(resolve);
}).then(f);
}
var BrowserDownloads = /*#__PURE__*/function () {
function BrowserDownloads(fileNamePrefix) {
if (!env().getBool('IS_BROWSER')) {
// TODO(cais): Provide info on what IOHandlers are available under the
// current environment.
throw new Error('browserDownloads() cannot proceed because the current environment ' + 'is not a browser.');
}
if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
}
if (fileNamePrefix == null || fileNamePrefix.length === 0) {
fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
}
this.modelJsonFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
this.weightDataFileName = fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
}
var _proto = BrowserDownloads.prototype;
_proto.save = /*#__PURE__*/function () {
var _save = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(modelArtifacts) {
var weightsURL, weightsManifest, modelJSON, modelJsonURL, jsonAnchor, weightDataAnchor;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!(typeof document === 'undefined')) {
_context.next = 2;
break;
}
throw new Error('Browser downloads are not supported in ' + 'this environment since `document` is not present');
case 2:
weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], {
type: 'application/octet-stream'
}));
if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) {
_context.next = 7;
break;
}
throw new Error('BrowserDownloads.save() does not support saving model topology ' + 'in binary formats yet.');
case 7:
weightsManifest = [{
paths: ['./' + this.weightDataFileName],
weights: modelArtifacts.weightSpecs
}];
modelJSON = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
modelJsonURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelJSON)], {
type: 'application/json'
})); // If anchor elements are not provided, create them without attaching them
// to parents, so that the downloaded file names can be controlled.
jsonAnchor = this.modelJsonAnchor == null ? document.createElement('a') : this.modelJsonAnchor;
jsonAnchor.download = this.modelJsonFileName;
jsonAnchor.href = modelJsonURL; // Trigger downloads by evoking a click event on the download anchors.
// When multiple downloads are started synchronously, Firefox will only
// save the last one.
_context.next = 15;
return defer$1(function () {
return jsonAnchor.dispatchEvent(new MouseEvent('click'));
});
case 15:
if (!(modelArtifacts.weightData != null)) {
_context.next = 21;
break;
}
weightDataAnchor = this.weightDataAnchor == null ? document.createElement('a') : this.weightDataAnchor;
weightDataAnchor.download = this.weightDataFileName;
weightDataAnchor.href = weightsURL;
_context.next = 21;
return defer$1(function () {
return weightDataAnchor.dispatchEvent(new MouseEvent('click'));
});
case 21:
return _context.abrupt("return", {
modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts)
});
case 22:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function save(_x) {
return _save.apply(this, arguments);
}
return save;
}();
return BrowserDownloads;
}();
BrowserDownloads.URL_SCHEME = 'downloads://';
var BrowserFiles = /*#__PURE__*/function () {
function BrowserFiles(files) {
if (files == null || files.length < 1) {
throw new Error("When calling browserFiles, at least 1 file is required, " + ("but received " + files));
}
this.jsonFile = files[0];
this.weightsFiles = files.slice(1);
}
var _proto2 = BrowserFiles.prototype;
_proto2.load = /*#__PURE__*/function () {
var _load = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var _this = this;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
return _context2.abrupt("return", new Promise(function (resolve, reject) {
var jsonReader = new FileReader();
jsonReader.onload = function (event) {
// tslint:disable-next-line:no-any
var modelJSON = JSON.parse(event.target.result);
var modelTopology = modelJSON.modelTopology;
if (modelTopology == null) {
reject(new Error("modelTopology field is missing from file " + _this.jsonFile.name));
return;
}
var weightsManifest = modelJSON.weightsManifest;
if (weightsManifest == null) {
reject(new Error("weightManifest field is missing from file " + _this.jsonFile.name));
return;
}
if (_this.weightsFiles.length === 0) {
resolve({
modelTopology: modelTopology
});
return;
}
var modelArtifactsPromise = getModelArtifactsForJSON(modelJSON, function (weightsManifest) {
return _this.loadWeights(weightsManifest);
});
resolve(modelArtifactsPromise);
};
jsonReader.onerror = function (error) {
return reject("Failed to read model topology and weights manifest JSON " + ("from file '" + _this.jsonFile.name + "'. BrowserFiles supports loading ") + "Keras-style tf.Model artifacts only.");
};
jsonReader.readAsText(_this.jsonFile);
}));
case 1:
case "end":
return _context2.stop();
}
}
}, _callee2);
}));
function load() {
return _load.apply(this, arguments);
}
return load;
}();
_proto2.loadWeights = function loadWeights(weightsManifest) {
var _this2 = this;
var weightSpecs = [];
var paths = [];
for (var _iterator = _createForOfIteratorHelperLoose(weightsManifest), _step; !(_step = _iterator()).done;) {
var entry = _step.value;
weightSpecs.push.apply(weightSpecs, entry.weights);
paths.push.apply(paths, entry.paths);
}
var pathToFile = this.checkManifestAndWeightFiles(weightsManifest);
var promises = paths.map(function (path) {
return _this2.loadWeightsFile(path, pathToFile[path]);
});
return Promise.all(promises).then(function (buffers) {
return [weightSpecs, concatenateArrayBuffers(buffers)];
});
};
_proto2.loadWeightsFile = function loadWeightsFile(path, file) {
return new Promise(function (resolve, reject) {
var weightFileReader = new FileReader();
weightFileReader.onload = function (event) {
// tslint:disable-next-line:no-any
var weightData = event.target.result;
resolve(weightData);
};
weightFileReader.onerror = function (error) {
return reject("Failed to read weights data from file of path '" + path + "'.");
};
weightFileReader.readAsArrayBuffer(file);
});
}
/**
* Check the compatibility between weights manifest and weight files.
*/
;
_proto2.checkManifestAndWeightFiles = function checkManifestAndWeightFiles(manifest) {
var _this3 = this;
var basenames = [];
var fileNames = this.weightsFiles.map(function (file) {
return basename(file.name);
});
var pathToFile = {};
for (var _iterator2 = _createForOfIteratorHelperLoose(manifest), _step2; !(_step2 = _iterator2()).done;) {
var group = _step2.value;
group.paths.forEach(function (path) {
var pathBasename = basename(path);
if (basenames.indexOf(pathBasename) !== -1) {
throw new Error("Duplicate file basename found in weights manifest: " + ("'" + pathBasename + "'"));
}
basenames.push(pathBasename);
if (fileNames.indexOf(pathBasename) === -1) {
throw new Error("Weight file with basename '" + pathBasename + "' is not provided.");
} else {
pathToFile[path] = _this3.weightsFiles[fileNames.indexOf(pathBasename)];
}
});
}
if (basenames.length !== this.weightsFiles.length) {
throw new Error("Mismatch in the number of files in weights manifest " + ("(" + basenames.length + ") and the number of weight files provided ") + ("(" + this.weightsFiles.length + ")."));
}
return pathToFile;
};
return BrowserFiles;
}();
var browserDownloadsRouter = function browserDownloadsRouter(url) {
if (!env().getBool('IS_BROWSER')) {
return null;
} else {
if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {
return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));
} else {
return null;
}
}
};
IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
/**
* Creates an IOHandler that triggers file downloads from the browser.
*
 * The returned `IOHandler` instance can be used with model exporting methods
 * such as `tf.Model.save`; it supports saving only.
*
* ```js
* const model = tf.sequential();
* model.add(tf.layers.dense(
* {units: 1, inputShape: [10], activation: 'sigmoid'}));
* const saveResult = await model.save('downloads://mymodel');
* // This will trigger downloading of two files:
* // 'mymodel.json' and 'mymodel.weights.bin'.
* console.log(saveResult);
* ```
*
* @param fileNamePrefix Prefix name of the files to be downloaded. For use with
* `tf.Model`, `fileNamePrefix` should follow either of the following two
* formats:
* 1. `null` or `undefined`, in which case the default file
* names will be used:
* - 'model.json' for the JSON file containing the model topology and
* weights manifest.
* - 'model.weights.bin' for the binary file containing the binary weight
* values.
* 2. A single string or an Array of a single string, as the file name prefix.
* For example, if `'foo'` is provided, the downloaded JSON
* file and binary weights file will be named 'foo.json' and
* 'foo.weights.bin', respectively.
* @param config Additional configuration for triggering downloads.
* @returns An instance of `BrowserDownloads` `IOHandler`.
*
* @doc {
* heading: 'Models',
* subheading: 'Loading',
* namespace: 'io',
* ignoreCI: true
* }
*/
function browserDownloads(fileNamePrefix) {
if (fileNamePrefix === void 0) {
fileNamePrefix = 'model';
}
return new BrowserDownloads(fileNamePrefix);
}
/**
* Creates an IOHandler that loads model artifacts from user-selected files.
*
* This method can be used for loading from files such as user-selected files
* in the browser.
* When used in conjunction with `tf.loadLayersModel`, an instance of
* `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
*
* ```js
* // Note: This code snippet won't run properly without the actual file input
* // elements in the HTML DOM.
*
* // Suppose there are two HTML file input (`<input type="file" ...>`)
* // elements.
* const uploadJSONInput = document.getElementById('upload-json');
* const uploadWeightsInput = document.getElementById('upload-weights');
* const model = await tf.loadLayersModel(tf.io.browserFiles(
* [uploadJSONInput.files[0], uploadWeightsInput.files[0]]));
* ```
*
* @param files `File`s to load from. Currently, this function supports only
* loading from files that contain Keras-style models (i.e., `tf.Model`s), for
* which an `Array` of `File`s is expected (in that order):
* - A JSON file containing the model topology and weight manifest.
* - Optionally, One or more binary files containing the binary weights.
* These files must have names that match the paths in the `weightsManifest`
* contained by the aforementioned JSON file, or errors will be thrown
* during loading. These weights files have the same format as the ones
* generated by `tensorflowjs_converter` that comes with the `tensorflowjs`
* Python PIP package. If no weights files are provided, only the model
* topology will be loaded from the JSON file above.
 * @returns An instance of `BrowserFiles` `IOHandler`.
*
* @doc {
* heading: 'Models',
* subheading: 'Loading',
* namespace: 'io',
* ignoreCI: true
* }
*/
function browserFiles(files) {
return new BrowserFiles(files);
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Monitor Promise.all progress, fire onProgress callback function.
*
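 * A usage sketch (`urlA` and `urlB` are placeholder URLs):
 *
 * ```js
 * const promises = [fetch(urlA), fetch(urlB)];
 * const responses = await monitorPromisesProgress(
 *     promises, fraction => console.log(fraction));
 * ```
 *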
* @param promises Promise list going to be monitored
 * @param onProgress Callback function. Fired when a promise resolves.
* @param startFraction Optional fraction start. Default to 0.
* @param endFraction Optional fraction end. Default to 1.
*/
function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) {
checkPromises(promises);
startFraction = startFraction == null ? 0 : startFraction;
endFraction = endFraction == null ? 1 : endFraction;
checkFraction(startFraction, endFraction);
var resolvedPromise = 0;
var registerMonitor = function registerMonitor(promise) {
promise.then(function (value) {
var fraction = startFraction + ++resolvedPromise / promises.length * (endFraction - startFraction); // pass fraction as parameter to callback function.
onProgress(fraction);
return value;
});
return promise;
};
function checkPromises(promises) {
assert(promises != null && Array.isArray(promises) && promises.length > 0, function () {
return 'promises must be a non-empty array';
});
}
function checkFraction(startFraction, endFraction) {
assert(startFraction >= 0 && startFraction <= 1, function () {
return "Progress fraction must be in range [0, 1], but " + ("got startFraction " + startFraction);
});
assert(endFraction >= 0 && endFraction <= 1, function () {
return "Progress fraction must be in range [0, 1], but " + ("got endFraction " + endFraction);
});
assert(endFraction >= startFraction, function () {
return "startFraction must be no more than endFraction, but " + ("got startFraction " + startFraction + " and endFraction ") + ("" + endFraction);
});
}
return Promise.all(promises.map(registerMonitor));
}
/**
* Reads binary weights data from a number of URLs.
*
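 * A usage sketch (the URL is a placeholder; in this build the request options
 * and progress callback are passed as a single `loadOptions` object):
 *
 * ```js
 * const buffers = await loadWeightsAsArrayBuffer(
 *     ['https://example.com/model.weights.bin'],
 *     {onProgress: fraction => console.log(fraction)});
 * ```
 *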
* @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls.
* @param requestOptions RequestInit (options) for the HTTP requests.
* @param fetchFunc Optional overriding value for the `window.fetch` function.
* @param onProgress Optional, progress callback function, fired periodically
* before the load is completed.
* @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same
* length as `fetchURLs`.
*/
function loadWeightsAsArrayBuffer(_x, _x2) {
return _loadWeightsAsArrayBuffer.apply(this, arguments);
}
/**
* Reads a weights manifest JSON configuration, fetches the weights and
* returns them as `Tensor`s.
*
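 * A usage sketch (`manifest` is a parsed weights manifest JSON; the URL
 * prefix is a placeholder):
 *
 * ```js
 * const weightMap = await loadWeights(manifest, 'https://example.com/model');
 * ```
 *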
* @param manifest The weights manifest JSON.
* @param filePathPrefix The path prefix for filenames given in the manifest.
* Defaults to the empty string.
* @param weightNames The names of the weights to be fetched.
*/
function _loadWeightsAsArrayBuffer() {
_loadWeightsAsArrayBuffer = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(fetchURLs, loadOptions) {
var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, bufferPromises, bufferStartFraction, bufferEndFraction, buffers;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
if (loadOptions == null) {
loadOptions = {};
}
fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : loadOptions.fetchFunc; // Create the requests for all of the weights in parallel.
requests = fetchURLs.map(function (fetchURL) {
return fetchFunc(fetchURL, loadOptions.requestInit, {
isBinary: true
});
});
fetchStartFraction = 0;
fetchEndFraction = 0.5;
if (!(loadOptions.onProgress == null)) {
_context2.next = 11;
break;
}
_context2.next = 8;
return Promise.all(requests);
case 8:
_context2.t0 = _context2.sent;
_context2.next = 14;
break;
case 11:
_context2.next = 13;
return monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction);
case 13:
_context2.t0 = _context2.sent;
case 14:
responses = _context2.t0;
bufferPromises = responses.map(function (response) {
return response.arrayBuffer();
});
bufferStartFraction = 0.5;
bufferEndFraction = 1;
if (!(loadOptions.onProgress == null)) {
_context2.next = 24;
break;
}
_context2.next = 21;
return Promise.all(bufferPromises);
case 21:
_context2.t1 = _context2.sent;
_context2.next = 27;
break;
case 24:
_context2.next = 26;
return monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction);
case 26:
_context2.t1 = _context2.sent;
case 27:
buffers = _context2.t1;
return _context2.abrupt("return", buffers);
case 29:
case "end":
return _context2.stop();
}
}
}, _callee2);
}));
return _loadWeightsAsArrayBuffer.apply(this, arguments);
}
function loadWeights(_x3, _x4, _x5, _x6) {
return _loadWeights.apply(this, arguments);
}
/**
* Creates a function, which reads a weights manifest JSON configuration,
* fetches the weight files using the specified function and returns them as
* `Tensor`s.
*
* ```js
* // example for creating a nodejs weight loader, which reads the weight files
* // from disk using fs.readFileSync
*
* import * as fs from 'fs'
*
* const fetchWeightsFromDisk = (filePaths: string[]) =>
* filePaths.map(filePath => fs.readFileSync(filePath).buffer)
*
* const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk)
*
* const manifest = JSON.parse(
* fs.readFileSync('./my_model-weights_manifest').toString()
* )
* const weightMap = await loadWeights(manifest, './')
* ```
* @param fetchWeightsFunction The function used for fetching the weight files.
* @returns Weight loading function.
*/
function _loadWeights() {
_loadWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(manifest, filePathPrefix, weightNames, requestInit) {
var fetchWeights, loadWeights;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
if (filePathPrefix === void 0) {
filePathPrefix = '';
}
// TODO(nsthorat): Groups are currently fetched atomically. If you need a
// single weight from a group, the whole group will be fetched. At a future
// date, we should support fetching only the individual shards within a
// group that are needed to reconstruct the requested weight.
// TODO(cais): Use `decodeWeights` for implementation.
fetchWeights = function fetchWeights(fetchUrls) {
return loadWeightsAsArrayBuffer(fetchUrls, {
requestInit: requestInit
});
};
loadWeights = weightsLoaderFactory(fetchWeights);
return _context3.abrupt("return", loadWeights(manifest, filePathPrefix, weightNames));
case 4:
case "end":
return _context3.stop();
}
}
}, _callee3);
}));
return _loadWeights.apply(this, arguments);
}
function weightsLoaderFactory(fetchWeightsFunction) {
return /*#__PURE__*/function () {
var _ref = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(manifest, filePathPrefix, weightNames) {
var groupIndicesToFetchMap, groupWeightsToFetch, weightsFound, allManifestWeightNames, weightsNotFound, groupIndicesToFetch, fetchUrls, buffers, weightsTensorMap, bufferIndexOffset;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (filePathPrefix === void 0) {
filePathPrefix = '';
}
// Collect all the groups, weights, and their relative offsets to be
// fetched.
groupIndicesToFetchMap = manifest.map(function () {
return false;
});
groupWeightsToFetch = {};
weightsFound = weightNames != null ? weightNames.map(function () {
return false;
}) : [];
allManifestWeightNames = [];
manifest.forEach(function (manifestGroupConfig, groupIndex) {
var groupOffset = 0;
manifestGroupConfig.weights.forEach(function (weightsEntry) {
var rawDtype = 'quantization' in weightsEntry ? weightsEntry.quantization.dtype : weightsEntry.dtype;
var weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * sizeFromShape(weightsEntry.shape);
var enqueueWeightsForFetchingFn = function enqueueWeightsForFetchingFn() {
groupIndicesToFetchMap[groupIndex] = true;
if (groupWeightsToFetch[groupIndex] == null) {
groupWeightsToFetch[groupIndex] = [];
}
groupWeightsToFetch[groupIndex].push({
manifestEntry: weightsEntry,
groupOffset: groupOffset,
sizeBytes: weightsBytes
});
};
if (weightNames != null) {
weightNames.forEach(function (weightName, weightIndex) {
if (weightName === weightsEntry.name) {
enqueueWeightsForFetchingFn();
weightsFound[weightIndex] = true;
}
});
} else {
enqueueWeightsForFetchingFn();
}
allManifestWeightNames.push(weightsEntry.name);
groupOffset += weightsBytes;
});
});
if (weightsFound.every(function (found) {
return found;
})) {
_context.next = 9;
break;
}
weightsNotFound = weightNames.filter(function (_, i) {
return !weightsFound[i];
});
throw new Error("Could not find weights in manifest with names: " + (weightsNotFound.join(', ') + ". \n") + "Manifest JSON has weights with names: " + (allManifestWeightNames.join(', ') + "."));
case 9:
// Convert the one-hot boolean groupId => shouldFetch map to a list of group
// IDs.
groupIndicesToFetch = groupIndicesToFetchMap.reduce(function (accumulator, shouldFetch, i) {
if (shouldFetch) {
accumulator.push(i);
}
return accumulator;
}, []);
fetchUrls = [];
groupIndicesToFetch.forEach(function (i) {
manifest[i].paths.forEach(function (filepath) {
var fetchUrl = filePathPrefix + (!filePathPrefix.endsWith('/') ? '/' : '') + filepath;
fetchUrls.push(fetchUrl);
});
});
_context.next = 14;
return fetchWeightsFunction(fetchUrls);
case 14:
buffers = _context.sent;
weightsTensorMap = {};
bufferIndexOffset = 0;
groupIndicesToFetch.forEach(function (i) {
var numBuffers = manifest[i].paths.length;
var groupBytes = 0;
for (var _i = 0; _i < numBuffers; _i++) {
groupBytes += buffers[bufferIndexOffset + _i].byteLength;
} // Create a buffer for the whole group.
var groupBuffer = new ArrayBuffer(groupBytes);
var groupByteBuffer = new Uint8Array(groupBuffer);
var groupBufferOffset = 0;
for (var _i2 = 0; _i2 < numBuffers; _i2++) {
var buffer = new Uint8Array(buffers[bufferIndexOffset + _i2]);
groupByteBuffer.set(buffer, groupBufferOffset);
groupBufferOffset += buffer.byteLength;
}
var weightsEntries = groupWeightsToFetch[i];
weightsEntries.forEach(function (weightsEntry) {
var byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes);
var nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
for (var name in nameToTensorMap) {
weightsTensorMap[name] = nameToTensorMap[name];
}
});
bufferIndexOffset += numBuffers;
});
return _context.abrupt("return", weightsTensorMap);
case 19:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return function (_x7, _x8, _x9) {
return _ref.apply(this, arguments);
};
}();
}
var OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
var JSON_TYPE = 'application/json';
var HTTPRequest = /*#__PURE__*/function () {
function HTTPRequest(path, loadOptions) {
this.DEFAULT_METHOD = 'POST';
if (loadOptions == null) {
loadOptions = {};
}
this.weightPathPrefix = loadOptions.weightPathPrefix;
this.onProgress = loadOptions.onProgress;
this.weightUrlConverter = loadOptions.weightUrlConverter;
if (loadOptions.fetchFunc != null) {
assert(typeof loadOptions.fetchFunc === 'function', function () {
return 'Must pass a function that matches the signature of ' + '`fetch` (see ' + 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)';
});
this.fetch = loadOptions.fetchFunc;
} else {
this.fetch = env().platform.fetch;
}
assert(path != null && path.length > 0, function () {
return 'URL path for http must not be null, undefined or ' + 'empty.';
});
if (Array.isArray(path)) {
assert(path.length === 2, function () {
return 'URL paths for http must have a length of 2, ' + ("(actual length is " + path.length + ").");
});
}
this.path = path;
if (loadOptions.requestInit != null && loadOptions.requestInit.body != null) {
throw new Error('requestInit is expected to have no pre-existing body, but has one.');
}
this.requestInit = loadOptions.requestInit || {};
}
var _proto = HTTPRequest.prototype;
_proto.save = /*#__PURE__*/function () {
var _save = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(modelArtifacts) {
var init, weightsManifest, modelTopologyAndWeightManifest, response;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) {
_context.next = 2;
break;
}
throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' + 'in binary formats yet.');
case 2:
init = Object.assign({
method: this.DEFAULT_METHOD
}, this.requestInit);
init.body = new FormData();
weightsManifest = [{
paths: ['./model.weights.bin'],
weights: modelArtifacts.weightSpecs
}];
modelTopologyAndWeightManifest = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], {
type: JSON_TYPE
}), 'model.json');
if (modelArtifacts.weightData != null) {
init.body.append('model.weights.bin', new Blob([modelArtifacts.weightData], {
type: OCTET_STREAM_MIME_TYPE
}), 'model.weights.bin');
}
_context.next = 10;
return this.fetch(this.path, init);
case 10:
response = _context.sent;
if (!response.ok) {
_context.next = 15;
break;
}
return _context.abrupt("return", {
modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts),
responses: [response]
});
case 15:
throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + (response.status + "."));
case 16:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function save(_x) {
return _save.apply(this, arguments);
}
return save;
}()
/**
* Load model artifacts via HTTP request(s).
*
* See the documentation to `tf.io.http` for details on the saved
* artifacts.
*
* @returns The loaded model artifacts (if loading succeeds).
*/
;
_proto.load =
/*#__PURE__*/
function () {
var _load = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var _this = this;
var modelConfigRequest, modelJSON, message, modelTopology, weightsManifest;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.fetch(this.path, this.requestInit);
case 2:
modelConfigRequest = _context2.sent;
if (modelConfigRequest.ok) {
_context2.next = 5;
break;
}
throw new Error("Request to " + this.path + " failed with status code " + (modelConfigRequest.status + ". Please verify this URL points to ") + "the model JSON of the model to load.");
case 5:
_context2.prev = 5;
_context2.next = 8;
return modelConfigRequest.json();
case 8:
modelJSON = _context2.sent;
_context2.next = 16;
break;
case 11:
_context2.prev = 11;
_context2.t0 = _context2["catch"](5);
message = "Failed to parse model JSON of response from " + this.path + "."; // TODO(nsthorat): Remove this after some time when we're comfortable that
// .pb files are mostly gone.
if (this.path.endsWith('.pb')) {
message += ' Your path contains a .pb file extension. ' + 'Support for .pb models has been removed in TensorFlow.js 1.0 ' + 'in favor of .json models. You can re-convert your Python ' + 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' + 'or you can convert your .pb models with the \'pb2json\' ' + 'NPM script in the tensorflow/tfjs-converter repository.';
} else {
message += ' Please make sure the server is serving valid ' + 'JSON for this request.';
}
throw new Error(message);
case 16:
// We do not allow both modelTopology and weightsManifest to be missing.
modelTopology = modelJSON.modelTopology;
weightsManifest = modelJSON.weightsManifest;
if (!(modelTopology == null && weightsManifest == null)) {
_context2.next = 20;
break;
}
throw new Error("The JSON from HTTP path " + this.path + " contains neither model " + "topology or manifest for weights.");
case 20:
return _context2.abrupt("return", getModelArtifactsForJSON(modelJSON, function (weightsManifest) {
return _this.loadWeights(weightsManifest);
}));
case 21:
case "end":
return _context2.stop();
}
}
}, _callee2, this, [[5, 11]]);
}));
function load() {
return _load.apply(this, arguments);
}
return load;
}();
_proto.loadWeights = /*#__PURE__*/function () {
var _loadWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(weightsManifest) {
var weightPath, _parseUrl, prefix, suffix, pathPrefix, weightSpecs, _iterator, _step, entry, fetchURLs, urlPromises, _iterator2, _step2, weightsGroup, _iterator3, _step3, path, buffers;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
weightPath = Array.isArray(this.path) ? this.path[1] : this.path;
_parseUrl = parseUrl(weightPath), prefix = _parseUrl[0], suffix = _parseUrl[1];
pathPrefix = this.weightPathPrefix || prefix;
weightSpecs = [];
for (_iterator = _createForOfIteratorHelperLoose(weightsManifest); !(_step = _iterator()).done;) {
entry = _step.value;
weightSpecs.push.apply(weightSpecs, entry.weights);
}
fetchURLs = [];
urlPromises = [];
for (_iterator2 = _createForOfIteratorHelperLoose(weightsManifest); !(_step2 = _iterator2()).done;) {
weightsGroup = _step2.value;
for (_iterator3 = _createForOfIteratorHelperLoose(weightsGroup.paths); !(_step3 = _iterator3()).done;) {
path = _step3.value;
if (this.weightUrlConverter != null) {
urlPromises.push(this.weightUrlConverter(path));
} else {
fetchURLs.push(pathPrefix + path + suffix);
}
}
}
if (!this.weightUrlConverter) {
_context3.next = 15;
break;
}
_context3.t0 = fetchURLs.push;
_context3.t1 = fetchURLs;
_context3.next = 13;
return Promise.all(urlPromises);
case 13:
_context3.t2 = _context3.sent;
_context3.t0.apply.call(_context3.t0, _context3.t1, _context3.t2);
case 15:
_context3.next = 17;
return loadWeightsAsArrayBuffer(fetchURLs, {
requestInit: this.requestInit,
fetchFunc: this.fetch,
onProgress: this.onProgress
});
case 17:
buffers = _context3.sent;
return _context3.abrupt("return", [weightSpecs, concatenateArrayBuffers(buffers)]);
case 19:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function loadWeights(_x2) {
return _loadWeights.apply(this, arguments);
}
return loadWeights;
}();
return HTTPRequest;
}();
HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//;
/**
* Extract the prefix and suffix of the url, where the prefix is the path before
* the last file, and suffix is the search params after the last file.
* ```
* const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file'
* [prefix, suffix] = parseUrl(url)
* // prefix = 'http://tfhub.dev/model/1/'
* // suffix = '?tfjs-format=file'
* ```
* @param url the model url to be parsed.
*/
function parseUrl(url) {
var lastSlash = url.lastIndexOf('/');
var lastSearchParam = url.lastIndexOf('?');
var prefix = url.substring(0, lastSlash);
var suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : '';
return [prefix + '/', suffix];
}
function isHTTPScheme(url) {
return url.match(HTTPRequest.URL_SCHEME_REGEX) != null;
}
var httpRouter = function httpRouter(url, loadOptions) {
if (typeof fetch === 'undefined' && (loadOptions == null || loadOptions.fetchFunc == null)) {
// `http` uses `fetch` or `node-fetch`, if one wants to use it in
// an environment that is not the browser or node they have to setup a
// global fetch polyfill.
return null;
} else {
var isHTTP = true;
if (Array.isArray(url)) {
isHTTP = url.every(function (urlItem) {
return isHTTPScheme(urlItem);
});
} else {
isHTTP = isHTTPScheme(url);
}
if (isHTTP) {
return http(url, loadOptions);
}
}
return null;
};
IORouterRegistry.registerSaveRouter(httpRouter);
IORouterRegistry.registerLoadRouter(httpRouter);
/**
* Creates an IOHandler subtype that sends model artifacts to HTTP server.
*
* An HTTP request of the `multipart/form-data` mime type will be sent to the
* `path` URL. The form data includes artifacts that represent the topology
* and/or weights of the model. In the case of Keras-style `tf.Model`, two
* blobs (files) exist in form-data:
* - A JSON file consisting of `modelTopology` and `weightsManifest`.
* - A binary weights file consisting of the concatenated weight values.
* These files are in the same format as the one generated by
* [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).
*
* The following code snippet exemplifies the client-side code that uses this
* function:
*
* ```js
* const model = tf.sequential();
* model.add(
* tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
*
* const saveResult = await model.save(tf.io.http(
* 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}}));
* console.log(saveResult);
* ```
*
* If the default `POST` method is to be used, without any custom parameters
* such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`:
*
* ```js
* const saveResult = await model.save('http://model-server:5000/upload');
* ```
*
* The following GitHub Gist
* https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864
* implements a server based on [flask](https://github.com/pallets/flask) that
* can receive the request. Upon receiving the model artifacts via the request,
* this particular server reconstitutes instances of [Keras
* Models](https://keras.io/models/model/) in memory.
*
*
* @param path A URL path to the model.
* Can be an absolute HTTP path (e.g.,
* 'http://localhost:8000/model-upload') or a relative path (e.g.,
* './model-upload').
* @param requestInit Request configurations to be used when sending
* HTTP request to server using `fetch`. It can contain fields such as
* `method`, `credentials`, `headers`, `mode`, etc. See
* https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
* for more information. `requestInit` must not have a body, because the
* body will be set by TensorFlow.js. File blobs representing the model
* topology (filename: 'model.json') and the weights of the model (filename:
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
* `body`, an Error will be thrown.
* @param loadOptions Optional configuration for the loading. It includes the
* following fields:
* - weightPathPrefix Optional, this specifies the path prefix for weight
* files, by default this is calculated from the path param.
* - fetchFunc Optional, custom `fetch` function. E.g., in Node.js,
* the `fetch` from node-fetch can be used here.
* - onProgress Optional, progress callback function, fired periodically
* before the load is completed.
* @returns An instance of `IOHandler`.
*
* @doc {
* heading: 'Models',
* subheading: 'Loading',
* namespace: 'io',
* ignoreCI: true
* }
*/
function http(path, loadOptions) {
return new HTTPRequest(path, loadOptions);
}
/**
* Deprecated. Use `tf.io.http`.
* @param path
* @param loadOptions
*/
function browserHTTPRequest(path, loadOptions) {
return http(path, loadOptions);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PassthroughLoader = /*#__PURE__*/function () {
function PassthroughLoader(modelArtifacts) {
this.modelArtifacts = modelArtifacts;
}
var _proto = PassthroughLoader.prototype;
_proto.load = /*#__PURE__*/function () {
var _load = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
return _context.abrupt("return", this.modelArtifacts);
case 1:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function load() {
return _load.apply(this, arguments);
}
return load;
}();
return PassthroughLoader;
}();
var PassthroughSaver = /*#__PURE__*/function () {
function PassthroughSaver(saveHandler) {
this.saveHandler = saveHandler;
}
var _proto2 = PassthroughSaver.prototype;
_proto2.save = /*#__PURE__*/function () {
var _save = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(modelArtifacts) {
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
return _context2.abrupt("return", this.saveHandler(modelArtifacts));
case 1:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function save(_x) {
return _save.apply(this, arguments);
}
return save;
}();
return PassthroughSaver;
}();
/**
* Creates an IOHandler that loads model artifacts from memory.
*
* When used in conjunction with `tf.loadLayersModel`, an instance of
* `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts.
*
* ```js
* const model = await tf.loadLayersModel(tf.io.fromMemory(
* modelTopology, weightSpecs, weightData));
* ```
*
* @param modelArtifacts an object containing model topology (i.e., parsed from
* the JSON format).
* @param weightSpecs An array of `WeightsManifestEntry` objects describing the
* names, shapes, types, and quantization of the weight data.
* @param weightData A single `ArrayBuffer` containing the weight data,
* concatenated in the order described by the weightSpecs.
* @param trainingConfig Model training configuration. Optional.
*
* @returns A passthrough `IOHandler` that simply loads the provided data.
*/
function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
if (arguments.length === 1) {
var isModelArtifacts = modelArtifacts.modelTopology != null || modelArtifacts.weightSpecs != null;
if (isModelArtifacts) {
return new PassthroughLoader(modelArtifacts);
} else {
// Legacy support: with only modelTopology.
// TODO(cais): Remove this deprecated API.
console.warn('Please call tf.io.fromMemory() with only one argument. ' + 'The argument should be of type ModelArtifacts. ' + 'The multi-argument signature of tf.io.fromMemory() has been ' + 'deprecated and will be removed in a future release.');
return new PassthroughLoader({
modelTopology: modelArtifacts
});
}
} else {
// Legacy support.
// TODO(cais): Remove this deprecated API.
console.warn('Please call tf.io.fromMemory() with only one argument. ' + 'The argument should be of type ModelArtifacts. ' + 'The multi-argument signature of tf.io.fromMemory() has been ' + 'deprecated and will be removed in a future release.');
return new PassthroughLoader({
modelTopology: modelArtifacts,
weightSpecs: weightSpecs,
weightData: weightData,
trainingConfig: trainingConfig
});
}
}
/**
* Creates an IOHandler that passes saved model artifacts to a callback.
*
* ```js
* function handleSave(artifacts) {
* // ... do something with the artifacts ...
* return {modelArtifactsInfo: {...}, ...};
* }
*
* const saveResult = model.save(tf.io.withSaveHandler(handleSave));
* ```
*
* @param saveHandler A function that accepts a `ModelArtifacts` and returns a
* `SaveResult`.
*/
function withSaveHandler(saveHandler) {
return new PassthroughSaver(saveHandler);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var io = {
__proto__: null,
browserFiles: browserFiles,
browserHTTPRequest: browserHTTPRequest,
concatenateArrayBuffers: concatenateArrayBuffers,
decodeWeights: decodeWeights,
encodeWeights: encodeWeights,
fromMemory: fromMemory,
getLoadHandlers: getLoadHandlers,
getModelArtifactsForJSON: getModelArtifactsForJSON,
getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON,
getSaveHandlers: getSaveHandlers,
http: http,
isHTTPScheme: isHTTPScheme,
loadWeights: loadWeights,
registerLoadRouter: registerLoadRouter,
registerSaveRouter: registerSaveRouter,
weightsLoaderFactory: weightsLoaderFactory,
withSaveHandler: withSaveHandler,
copyModel: copyModel,
listModels: listModels,
moveModel: moveModel,
removeModel: removeModel
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the dot product of two matrices, A * B. These must be matrices.
*
* ```js
* const a = tf.tensor2d([1, 2], [1, 2]);
* const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* a.matMul(b).print(); // or tf.matMul(a, b)
* ```
* @param a First matrix in dot product operation.
* @param b Second matrix in dot product operation.
* @param transposeA If true, `a` is transposed before multiplication.
* @param transposeB If true, `b` is transposed before multiplication.
*
* @doc {heading: 'Operations', subheading: 'Matrices'}
*/
function matMul_(a, b, transposeA, transposeB) {
if (transposeA === void 0) {
transposeA = false;
}
if (transposeB === void 0) {
transposeB = false;
}
var $a = convertToTensor(a, 'a', 'matMul');
var $b = convertToTensor(b, 'b', 'matMul');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
var inputs = {
a: $a,
b: $b
};
var attrs = {
transposeA: transposeA,
transposeB: transposeB
};
return ENGINE.runKernel(BatchMatMul, inputs, attrs);
}
var matMul = op({
matMul_: matMul_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
* value `onValue` (defaults to 1), while all other locations take value
* `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank
* `R+1` with the last axis of size `depth`.
*
* ```js
* tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();
* ```
*
* @param indices `tf.Tensor` of indices with dtype `int32`.
* @param depth The depth of the one hot dimension.
* @param onValue A number used to fill in the output when the index matches
* the location.
* @param offValue A number used to fill in the output when the index does
* not match the location.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function oneHot_(indices, depth, onValue, offValue) {
if (onValue === void 0) {
onValue = 1;
}
if (offValue === void 0) {
offValue = 0;
}
if (depth < 2) {
throw new Error("Error in oneHot: depth must be >=2, but it is " + depth);
}
var $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
var inputs = {
indices: $indices
};
var attrs = {
depth: depth,
onValue: onValue,
offValue: offValue
};
return ENGINE.runKernel(OneHot, inputs, attrs);
}
var oneHot = op({
oneHot_: oneHot_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`.
*
* The returned `tf.Tensor`'s dimension `i` will correspond to the input
* dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`,
* where `n` is the rank of the input `tf.Tensor`. Hence by default, this
* operation performs a regular matrix transpose on 2-D input `tf.Tensor`s.
*
* ```js
* const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
*
* a.transpose().print(); // or tf.transpose(a)
* ```
*
* @param x The tensor to transpose.
* @param perm The permutation of the dimensions of a.
*
* @doc {heading: 'Operations', subheading: 'Matrices'}
*/
function transpose_(x, perm) {
var $x = convertToTensor(x, 'x', 'transpose');
if (perm == null) {
perm = $x.shape.map(function (s, i) {
return i;
}).reverse();
}
assert($x.rank === perm.length, function () {
return "Error in transpose: rank of input " + $x.rank + " " + ("must match length of perm " + perm + ".");
});
perm.forEach(function (axis) {
assert(axis >= 0 && axis < $x.rank, function () {
return "All entries in 'perm' must be between 0 and " + ($x.rank - 1) + (" but got " + perm);
});
});
if ($x.rank <= 1) {
return $x.clone();
}
var inputs = {
x: $x
};
var attrs = {
perm: perm
};
return ENGINE.runKernel(Transpose, inputs, attrs);
}
var transpose = op({
transpose_: transpose_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the confusion matrix from true labels and predicted labels.
*
* ```js
* const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32');
* const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32');
* const numClasses = 3;
* const out = tf.math.confusionMatrix(labels, predictions, numClasses);
* out.print();
* // Expected output matrix:
* // [[2, 0, 0],
* // [0, 1, 1],
* // [0, 0, 1]]
* ```
*
* @param labels The target labels, assumed to be 0-based integers
* for the classes. The shape is `[numExamples]`, where
* `numExamples` is the number of examples included.
* @param predictions The predicted classes, assumed to be
* 0-based integers for the classes. Must have the same shape as `labels`.
* @param numClasses Number of all classes, as an integer.
* Its value must be larger than the largest element in `labels` and
* `predictions`.
* @returns The confusion matrix as an int32-type 2D tensor. The value at
* row `r` and column `c` is the number of times examples of actual class
* `r` were predicted as class `c`.
*
* @doc {heading: 'Operations', subheading: 'Evaluation'}
*/
function confusionMatrix_(labels, predictions, numClasses) {
var $labels = convertToTensor(labels, 'labels', 'confusionMatrix');
var $predictions = convertToTensor(predictions, 'predictions', 'confusionMatrix');
assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), function () {
return "If provided, numClasses must be a positive integer, " + ("but got " + numClasses);
});
assert($labels.rank === 1, function () {
return "Expected the rank of labels to be 1, but got " + $labels.rank;
});
assert($predictions.rank === 1, function () {
return "Expected the rank of predictions to be 1, " + ("but got " + $predictions.rank);
});
assert($labels.shape[0] === $predictions.shape[0], function () {
return "Mismatch in the number of examples: " + ($labels.shape[0] + " vs. " + $predictions.shape[0] + ". ") + "Labels and predictions should have the same number of elements.";
});
assert(numClasses > 0 && Number.isInteger(numClasses), function () {
return "numClasses is required to be a positive integer, but got " + ("" + numClasses);
}); // TODO(cais): In the future, if oneHot supports tensors inputs for
// `numClasses`, `confusionMatrix` can make `numClasses` optional.
var oneHotLabels = oneHot(cast($labels, 'int32'), numClasses);
var oneHotPredictions = oneHot(cast($predictions, 'int32'), numClasses);
var oneHotLabelsT = transpose(oneHotLabels);
var product = matMul(oneHotLabelsT, oneHotPredictions);
return cast(product, 'int32');
}
var confusionMatrix = op({
confusionMatrix_: confusionMatrix_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var math = {
__proto__: null,
confusionMatrix: confusionMatrix
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates rank-3 `tf.Tensor` with the provided values, shape and dtype.
*
* The same functionality can be achieved with `tf.tensor`, but in general
* we recommend using `tf.tensor3d` as it makes the code more readable.
*
* ```js
* // Pass a nested array.
* tf.tensor3d([[[1], [2]], [[3], [4]]]).print();
* ```
* ```js
* // Pass a flat array and specify a shape.
* tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print();
* ```
*
* @param values The values of the tensor. Can be nested array of numbers,
* or a flat array, or a `TypedArray`.
* @param shape The shape of the tensor. If not provided, it is inferred from
* `values`.
* @param dtype The data type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function tensor3d(values, shape, dtype) {
assertNonNull(values);
if (shape != null && shape.length !== 3) {
throw new Error('tensor3d() requires shape to have three numbers');
}
var inferredShape = inferShape(values, dtype);
if (inferredShape.length !== 3 && inferredShape.length !== 1) {
throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray');
}
if (inferredShape.length === 1 && shape == null) {
throw new Error('tensor3d() requires shape to be provided when `values` ' + 'are a flat array');
}
return makeTensor(values, shape, inferredShape, dtype);
}
var fromPixels2DContext;
/**
* Creates a `tf.Tensor` from an image.
*
* ```js
* const image = new ImageData(1, 1);
* image.data[0] = 100;
* image.data[1] = 150;
* image.data[2] = 200;
* image.data[3] = 255;
*
* tf.browser.fromPixels(image).print();
* ```
*
* @param pixels The input image to construct the tensor from. The
* supported image types are all 4-channel. You can also pass in an image
* object with the following attributes:
* `{data: Uint8Array; width: number; height: number}`
* @param numChannels The number of channels of the output tensor. A
* numChannels value less than 4 allows you to ignore channels. Defaults to
* 3 (ignores alpha channel of input image).
*
* @returns A Tensor3D with the shape `[height, width, numChannels]`.
*
* @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
*/
function fromPixels_(pixels, numChannels) {
if (numChannels === void 0) {
numChannels = 3;
}
// Sanity checks.
if (numChannels > 4) {
throw new Error('Cannot construct Tensor with more than 4 channels from pixels.');
}
if (pixels == null) {
throw new Error('pixels passed to tf.browser.fromPixels() can not be null');
}
var isPixelData = false;
var isImageData = false;
var isVideo = false;
var isImage = false;
var isCanvasLike = false;
var isImageBitmap = false;
if (pixels.data instanceof Uint8Array) {
isPixelData = true;
} else if (typeof ImageData !== 'undefined' && pixels instanceof ImageData) {
isImageData = true;
} else if (typeof HTMLVideoElement !== 'undefined' && pixels instanceof HTMLVideoElement) {
isVideo = true;
} else if (typeof HTMLImageElement !== 'undefined' && pixels instanceof HTMLImageElement) {
isImage = true; // tslint:disable-next-line: no-any
} else if (pixels.getContext != null) {
isCanvasLike = true;
} else if (typeof ImageBitmap !== 'undefined' && pixels instanceof ImageBitmap) {
isImageBitmap = true;
} else {
throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' + "HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData " + "in browser, or OffscreenCanvas, ImageData in webworker" + " or {data: Uint8Array, width: number, height: number}, " + ("but was " + pixels.constructor.name));
}
if (isVideo) {
var HAVE_CURRENT_DATA_READY_STATE = 2;
if (isVideo && pixels.readyState < HAVE_CURRENT_DATA_READY_STATE) {
throw new Error('The video element has not loaded data yet. Please wait for ' + '`loadeddata` event on the <video> element.');
}
} // If the current backend has 'FromPixels' registered, it has a more
// efficient way of handling pixel uploads, so we call that.
var kernel = getKernel(FromPixels, ENGINE.backendName);
if (kernel != null) {
var inputs = {
pixels: pixels
};
var attrs = {
numChannels: numChannels
};
return ENGINE.runKernel(FromPixels, inputs, attrs);
}
var _ref = isVideo ? [pixels.videoWidth, pixels.videoHeight] : [pixels.width, pixels.height],
width = _ref[0],
height = _ref[1];
var vals;
if (isCanvasLike) {
vals = // tslint:disable-next-line:no-any
pixels.getContext('2d').getImageData(0, 0, width, height).data;
} else if (isImageData || isPixelData) {
vals = pixels.data;
} else if (isImage || isVideo || isImageBitmap) {
if (fromPixels2DContext == null) {
fromPixels2DContext = document.createElement('canvas').getContext('2d');
}
fromPixels2DContext.canvas.width = width;
fromPixels2DContext.canvas.height = height;
fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
vals = fromPixels2DContext.getImageData(0, 0, width, height).data;
}
var values;
if (numChannels === 4) {
values = new Int32Array(vals);
} else {
var numPixels = width * height;
values = new Int32Array(numPixels * numChannels);
for (var i = 0; i < numPixels; i++) {
for (var channel = 0; channel < numChannels; ++channel) {
values[i * numChannels + channel] = vals[i * 4 + channel];
}
}
}
var outShape = [height, width, numChannels];
return tensor3d(values, outShape, 'int32');
} // Helper functions for |fromPixelsAsync| to check whether the input can
// be wrapped into imageBitmap.
function isPixelData(pixels) {
return pixels != null && pixels.data instanceof Uint8Array;
}
function isImageBitmapFullySupported() {
return typeof window !== 'undefined' && typeof ImageBitmap !== 'undefined' && window.hasOwnProperty('createImageBitmap');
}
function isNonEmptyPixels(pixels) {
return pixels != null && pixels.width !== 0 && pixels.height !== 0;
}
function canWrapPixelsToImageBitmap(pixels) {
return isImageBitmapFullySupported() && !(pixels instanceof ImageBitmap) && isNonEmptyPixels(pixels) && !isPixelData(pixels);
}
/**
* Creates a `tf.Tensor` from an image in async way.
*
* ```js
* const image = new ImageData(1, 1);
* image.data[0] = 100;
* image.data[1] = 150;
* image.data[2] = 200;
* image.data[3] = 255;
*
* (await tf.browser.fromPixelsAsync(image)).print();
* ```
* This API is the async version of fromPixels. It will first check
* the |WRAP_TO_IMAGEBITMAP| flag, and try to wrap the input in an
* imageBitmap if the flag is set to true.
*
* @param pixels The input image to construct the tensor from. The
* supported image types are all 4-channel. You can also pass in an image
* object with the following attributes:
* `{data: Uint8Array; width: number; height: number}`
* @param numChannels The number of channels of the output tensor. A
* numChannels value less than 4 allows you to ignore channels. Defaults to
* 3 (ignores alpha channel of input image).
*
* @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true}
*/
function fromPixelsAsync(_x, _x2) {
return _fromPixelsAsync.apply(this, arguments);
}
/**
* Draws a `tf.Tensor` of pixel values to a byte array or optionally a
* canvas.
*
* When the dtype of the input is 'float32', we assume values in the range
* [0-1]. Otherwise, when input is 'int32', we assume values in the range
* [0-255].
*
* Returns a promise that resolves when the canvas has been drawn to.
*
* @param img A rank-2 tensor with shape `[height, width]`, or a rank-3 tensor
* of shape `[height, width, numChannels]`. If rank-2, draws grayscale. If
* rank-3, must have depth of 1, 3 or 4. When depth of 1, draws
* grayscale. When depth of 3, we draw with the first three components of
* the depth dimension corresponding to r, g, b and alpha = 1. When depth of
* 4, all four components of the depth dimension correspond to r, g, b, a.
* @param canvas The canvas to draw to.
*
* @doc {heading: 'Browser', namespace: 'browser'}
*/
function _fromPixelsAsync() {
_fromPixelsAsync = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(pixels, numChannels) {
var inputs, imageBitmap;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (numChannels === void 0) {
numChannels = 3;
}
inputs = null; // Check whether the backend needs to wrap |pixels| to imageBitmap and
// whether |pixels| can be wrapped to imageBitmap.
if (!(env().getBool('WRAP_TO_IMAGEBITMAP') && canWrapPixelsToImageBitmap(pixels))) {
_context.next = 15;
break;
}
_context.prev = 3;
_context.next = 6;
return createImageBitmap(pixels, {
premultiplyAlpha: 'none'
});
case 6:
imageBitmap = _context.sent;
_context.next = 12;
break;
case 9:
_context.prev = 9;
_context.t0 = _context["catch"](3);
imageBitmap = null;
case 12:
// createImageBitmap will clip the source size.
// In some cases, the input will have larger size than its content.
// E.g. new Image(10, 10) but with 1 x 1 content. Using
// createImageBitmap will clip the size from 10 x 10 to 1 x 1, which
// is not correct. We should avoid wrapping such a resource to
// imageBitmap.
if (imageBitmap != null && imageBitmap.width === pixels.width && imageBitmap.height === pixels.height) {
inputs = imageBitmap;
} else {
inputs = pixels;
}
_context.next = 16;
break;
case 15:
inputs = pixels;
case 16:
return _context.abrupt("return", fromPixels_(inputs, numChannels));
case 17:
case "end":
return _context.stop();
}
}
}, _callee, null, [[3, 9]]);
}));
return _fromPixelsAsync.apply(this, arguments);
}
function toPixels(_x3, _x4) {
return _toPixels.apply(this, arguments);
}
function _toPixels() {
_toPixels = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(img, canvas) {
var $img, originalImgTensor, _$img$shape$slice, height, width, depth, data, multiplier, bytes, i, rgba, d, value, j, ctx, imageData;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
$img = convertToTensor(img, 'img', 'toPixels');
if (!(img instanceof Tensor)) {
// Assume int32 if user passed a native array.
originalImgTensor = $img;
$img = cast(originalImgTensor, 'int32');
originalImgTensor.dispose();
}
if (!($img.rank !== 2 && $img.rank !== 3)) {
_context2.next = 4;
break;
}
throw new Error("toPixels only supports rank 2 or 3 tensors, got rank " + $img.rank + ".");
case 4:
_$img$shape$slice = $img.shape.slice(0, 2), height = _$img$shape$slice[0], width = _$img$shape$slice[1];
depth = $img.rank === 2 ? 1 : $img.shape[2];
if (!(depth > 4 || depth === 2)) {
_context2.next = 8;
break;
}
throw new Error("toPixels only supports depth of size " + ("1, 3 or 4 but got " + depth));
case 8:
if (!($img.dtype !== 'float32' && $img.dtype !== 'int32')) {
_context2.next = 10;
break;
}
throw new Error("Unsupported type for toPixels: " + $img.dtype + "." + " Please use float32 or int32 tensors.");
case 10:
_context2.next = 12;
return $img.data();
case 12:
data = _context2.sent;
multiplier = $img.dtype === 'float32' ? 255 : 1;
bytes = new Uint8ClampedArray(width * height * 4);
i = 0;
case 16:
if (!(i < height * width)) {
_context2.next = 41;
break;
}
rgba = [0, 0, 0, 255];
d = 0;
case 19:
if (!(d < depth)) {
_context2.next = 33;
break;
}
value = data[i * depth + d];
if (!($img.dtype === 'float32')) {
_context2.next = 26;
break;
}
if (!(value < 0 || value > 1)) {
_context2.next = 24;
break;
}
throw new Error("Tensor values for a float32 Tensor must be in the " + ("range [0 - 1] but encountered " + value + "."));
case 24:
_context2.next = 29;
break;
case 26:
if (!($img.dtype === 'int32')) {
_context2.next = 29;
break;
}
if (!(value < 0 || value > 255)) {
_context2.next = 29;
break;
}
throw new Error("Tensor values for a int32 Tensor must be in the " + ("range [0 - 255] but encountered " + value + "."));
case 29:
if (depth === 1) {
rgba[0] = value * multiplier;
rgba[1] = value * multiplier;
rgba[2] = value * multiplier;
} else {
rgba[d] = value * multiplier;
}
case 30:
d++;
_context2.next = 19;
break;
case 33:
j = i * 4;
bytes[j + 0] = Math.round(rgba[0]);
bytes[j + 1] = Math.round(rgba[1]);
bytes[j + 2] = Math.round(rgba[2]);
bytes[j + 3] = Math.round(rgba[3]);
case 38:
++i;
_context2.next = 16;
break;
case 41:
if (canvas != null) {
canvas.width = width;
canvas.height = height;
ctx = canvas.getContext('2d');
imageData = new ImageData(bytes, width, height);
ctx.putImageData(imageData, 0, 0);
}
if ($img !== img) {
$img.dispose();
}
return _context2.abrupt("return", bytes);
case 44:
case "end":
return _context2.stop();
}
}
}, _callee2);
}));
return _toPixels.apply(this, arguments);
}
var fromPixels = op({
fromPixels_: fromPixels_
});
var browser = {
__proto__: null,
fromPixelsAsync: fromPixelsAsync,
toPixels: toPixels,
fromPixels: fromPixels
};
/**
* Validate gather nd inputs.
*
* @param tensor The tensor that contains the source values.
* @param indices The tensor that contains the indices to slice the source.
*
* @returns [resultShape, numUpdates, sliceSize, strides]
*/
function prepareAndValidate(tensor, indices) {
var tensorRank = tensor.shape.length;
var indicesRank = indices.shape.length;
if (tensorRank < 1) {
throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' + (" but the rank was " + tensorRank + "."));
}
if (indicesRank < 1) {
throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' + (" but the rank was " + indicesRank + "."));
}
if (indices.dtype !== 'int32') {
throw new Error('tf.gatherND() expects the indices to be int32 type,' + (" but the dtype was " + indices.dtype + "."));
}
if (indices.shape[indicesRank - 1] > tensorRank) {
throw new Error('index innermost dimension length must be <= tensor rank; saw: ' + (indices.shape[indicesRank - 1] + " vs. " + tensorRank));
}
if (sizeFromShape(tensor.shape) === 0) {
throw new Error('Requested more than 0 entries, but input is empty.' + (" Input shape: " + tensor.shape + "."));
}
var indicesShape = indices.shape;
var sliceRank = indicesShape[indicesShape.length - 1]; // The result shape is
// indices.shape[:-1] + params.shape[indices.shape[-1]:]
var nResult = 1;
for (var i = 0; i < indicesShape.length - 1; ++i) {
nResult *= indicesShape[i];
}
var inputShape = tensor.shape;
var resultShape = indicesShape.slice();
resultShape.pop();
var sliceSize = 1;
for (var _i = sliceRank; _i < tensorRank; ++_i) {
sliceSize *= inputShape[_i];
resultShape.push(inputShape[_i]);
}
var strides = [].concat(computeStrides(tensor.shape).map(function (stride) {
return stride / sliceSize;
}), [1]).slice(0, sliceRank);
return [resultShape, nResult, sliceSize, strides];
}
var gather_nd_util = {
__proto__: null,
prepareAndValidate: prepareAndValidate
};
/**
* Check whether updates.shape = indices.shape[:batchDim] +
* shape[sliceDim:]
*
* @param shape The shape of the output tensor.
* @param indices The tensor that contains the indices for the updates.
* @param updates The tensor that contains the update values.
*/
function validateUpdateShape(shape, indices, updates) {
var sliceDim = indices.rank > 1 ? indices.shape[indices.rank - 1] : 1;
var batchDim = indices.rank > 1 ? indices.rank - 1 : 1;
var shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' + ("shape[sliceDim:], got updates.shape: " + updates.shape) + (", indices.shape: " + indices.shape + ", shape: " + shape) + (", sliceDim: " + sliceDim + ", and batchDim: " + batchDim + ".");
if (updates.rank < batchDim) {
throw new Error(shapeError + (" update.rank < " + batchDim + ". "));
}
if (shape.length < sliceDim + (updates.rank - batchDim)) {
throw new Error(shapeError + (" Output shape length < " + (sliceDim + (updates.rank - batchDim))));
}
if (updates.rank !== batchDim + shape.length - sliceDim) {
throw new Error(shapeError + (" update.rank != " + (batchDim + shape.length - sliceDim)));
}
for (var d = 0; d < batchDim; ++d) {
if (updates.shape[d] !== indices.shape[d]) {
throw new Error(shapeError + (" updates.shape[" + d + "] (" + updates.shape[d] + ") != indices.shape[" + d + "] (" + indices.shape[d] + ")."));
}
}
for (var _d = 0; _d < updates.rank - batchDim; ++_d) {
if (updates.shape[_d + batchDim] !== shape[_d + sliceDim]) {
throw new Error(shapeError + (" updates.shape[" + (_d + batchDim) + "] (" + updates.shape[_d + batchDim] + ") != shape[" + (_d + batchDim) + "] (" + shape[_d + batchDim] + ")"));
}
}
}
/**
* Validate scatter nd inputs.
*
* @param updates The tensor that contains the update values.
* @param indices The tensor that contains the indices for the update values.
* @param shape The shape of the output tensor.
*/
function validateInput(updates, indices, shape) {
if (indices.rank < 1) {
throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' + (" but the rank was " + indices.rank + "."));
}
if (updates.rank < 1) {
throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' + (" but the rank was " + updates.rank + "."));
}
if (indices.dtype !== 'int32') {
throw new Error("The dtype of 'indices' should be int32, but got dtype: " + indices.dtype);
}
if (shape.length < 1) {
throw new Error("Output rank must be greater or equal to 1, but got shape: " + shape);
}
if (shape.length === 0) {
if (indices.size === 0) {
throw new Error("Indices specified for empty output. indices shape: " + indices.shape);
}
if (updates.size === 0) {
throw new Error("Updates specified for empty output. updates shape: " + updates.shape);
}
}
validateUpdateShape(shape, indices, updates);
}
/**
* Calculate the shape information for the output.
*
* @param updates The tensor that contains the update values.
* @param indices The tensor that contains the indices for the update values.
* @param shape The shape of the output tensor.
*
* @returns ScatterShapeInfo
*/
function calculateShapes(updates, indices, shape) {
// Calculate the number of dimensions in indices
var indicesRank = indices.shape.length;
var sliceRank = indicesRank > 1 ? indices.shape[indicesRank - 1] : 1; // Calculate the number of elements that make up each slice of our updated
// tensor. This allows us to work with flattened tensors and copy over whole
// slices at a time.
var totalNd = shape.length;
var sliceSize = 1;
for (var i = sliceRank; i < totalNd; ++i) {
sliceSize *= shape[i];
}
var safeSliceDim = sliceRank < 1 ? 1 : sliceRank;
var numUpdates = sizeFromShape(indices.shape) / safeSliceDim;
var strides = [].concat(computeStrides(shape.slice(0, sliceRank)), [1]);
var outputSize = sizeFromShape(shape);
return {
sliceRank: sliceRank,
numUpdates: numUpdates,
sliceSize: sliceSize,
strides: strides,
outputSize: outputSize
};
}
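// Worked example for calculateShapes (illustrative numbers only):
//   updates.shape = [4, 3], indices.shape = [4, 2], shape = [5, 7, 3]
//   sliceRank  = 2 (last dimension of indices)
//   sliceSize  = 3 (product of shape[sliceRank:])
//   numUpdates = size([4, 2]) / sliceRank = 8 / 2 = 4
//   strides    = computeStrides([5, 7]) concatenated with [1] = [7, 1]
//   outputSize = 5 * 7 * 3 = 105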
var scatter_nd_util = {
__proto__: null,
validateUpdateShape: validateUpdateShape,
validateInput: validateInput,
calculateShapes: calculateShapes
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function assertParamsValid(input, begin, size) {
var inputRank = input.shape.length;
assert(inputRank === begin.length, function () {
return "Error in slice" + inputRank + "D: Length of begin " + begin + " must " + ("match the rank of the array (" + inputRank + ").");
});
assert(inputRank === size.length, function () {
return "Error in slice" + inputRank + "D: Length of size " + size + " must " + ("match the rank of the array (" + inputRank + ").");
});
var _loop = function _loop(i) {
assert(begin[i] + size[i] <= input.shape[i], function () {
return "Error in slice" + inputRank + "D: begin[" + i + "] + size[" + i + "] " + ("(" + (begin[i] + size[i]) + ") would overflow input.shape[" + i + "] (" + input.shape[i] + ")");
});
};
for (var i = 0; i < inputRank; ++i) {
_loop(i);
}
}
/** Converts a binary mask to an array of axes. Used in stridedSlice(). */
function maskToAxes(mask) {
var axes = [];
var axis = 0;
while (mask > 0) {
if (mask & 1) {
axes.push(axis);
}
mask /= 2;
axis++;
}
return axes;
}
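// Illustration (values chosen for the example):
//   maskToAxes(0b101)              -> [0, 2]  (bits 0 and 2 are set)
//   computeOutShape([0], [5], [2]) -> [3]     (ceil((5 - 0) / 2) = 3)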
/** Computes the output shape given the strided slice params. */
function computeOutShape(begin, end, strides) {
var size = [];
for (var axis = 0; axis < begin.length; axis++) {
size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);
}
return size;
} // Creates full selection at the elided dimensions. If the dimension matches
// the ellipsis mask, override the current stride value. Otherwise, insert.
function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) {
var newStrides = [].concat(strides);
for (var i = newStrides.length; i < inputShape.length; i++) {
newStrides.push(1);
}
for (var _i = 0; _i < numElidedAxes; _i++) {
if (_i === 0) {
newStrides[ellipsisInsertionIndex] = 1;
} else {
newStrides.splice(ellipsisInsertionIndex, 0
/* num elements to delete */
, 1
/* element to add */
);
newStrides.pop();
}
}
return newStrides;
}
function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) {
if (normalizedAxis <= ellipsisInsertionIndex) {
return normalizedAxis;
}
return normalizedAxis - (numElidedAxes - 1);
}
function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) {
var elidedAxes = [];
for (var i = 0; i < numElidedAxes; i++) {
elidedAxes.push(ellipsisInsertionIndex + i);
}
return elidedAxes;
} // Normalize the start, end and strides.
function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) {
var inputRank = inputShape.length;
var normalizedBegin = new Array(inputRank),
normalizedEnd = new Array(inputRank),
normalizedStrides = new Array(inputRank);
if (ellipsisAxes.length && numInterpolatedAxes > 0) {
var fullIndex = ellipsisAxes[0]; // The ellipsis applies to the masked index as well as any dimensions
// that are interpolated.
var numElidedAxes = numInterpolatedAxes + 1;
normalizedBegin = startIndicesWithElidedDims(beginMask, fullIndex, numElidedAxes, begin, inputShape);
normalizedEnd = stopIndicesWithElidedDims(endMask, fullIndex, numElidedAxes, end, inputShape);
normalizedStrides = stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape);
} else {
for (var axis = 0; axis < inputRank; axis++) {
normalizedBegin[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask);
normalizedEnd[axis] = stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask);
normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask);
}
}
return {
begin: normalizedBegin,
end: normalizedEnd,
strides: normalizedStrides
};
} // Creates full selection at the elided dimensions. If the dimension matches
// the ellipsis mask, override the current start value. Otherwise, insert.
function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) {
var newIndices = [].concat(inputShape);
var elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
for (var axis = 0; axis < newIndices.length; axis++) {
if (elidedAxes.indexOf(axis) > -1) {
newIndices[axis] = 0;
} else {
var originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
var originalValue = originalBegin[originalAxis];
if (beginMask & 1 << originalAxis) {
originalValue = 0;
}
newIndices[axis] = originalValue;
}
}
return newIndices;
} // Creates full selection at the elided dimensions. If the dimension matches
// the ellipsis mask, override the current stop value. Otherwise, insert.
function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) {
var newIndices = [].concat(inputShape);
var elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
for (var axis = 0; axis < newIndices.length; axis++) {
if (elidedAxes.indexOf(axis) > -1) {
newIndices[axis] = Number.MAX_SAFE_INTEGER;
} else {
var originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
var originalValue = originalEnd[originalAxis];
if (endMask & 1 << originalAxis) {
originalValue = Number.MAX_SAFE_INTEGER;
}
newIndices[axis] = originalValue;
}
}
for (var i = 0; i < newIndices.length; i++) {
// Handle negative indices
var axisSize = inputShape[i];
if (newIndices[i] < 0) {
newIndices[i] += axisSize;
}
newIndices[i] = clamp(0, newIndices[i], inputShape[i]);
}
return newIndices;
}
function stridesForAxis(strides, axis, ellipsisMask) {
var stride = strides[axis];
if (ellipsisMask & 1 << axis || stride == null) {
stride = 1;
}
return stride;
}
function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) {
// Begin with the specified index
var start = startIndices[axis];
var stride = strides[axis] || 1; // Check the axis bit from right of masked axes, or the begin index is not set
// for the axis.
if (beginMask & 1 << axis || ellipsisMask & 1 << axis || start == null) {
if (stride > 0) {
// Forward iteration - use the first element. These values will get
// clamped below (Note: We could have set them to 0 and axis_size-1, but
// use lowest() and max() to maintain symmetry with StopForAxis())
start = Number.MIN_SAFE_INTEGER;
} else {
// Backward iteration - use the last element.
start = Number.MAX_SAFE_INTEGER;
}
} // Handle negative indices
var axisSize = inputShape[axis];
if (start < 0) {
start += axisSize;
} // Clamping
start = clamp(0, start, axisSize - 1);
return start;
}
function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) {
// Begin with the specified index
var stop = stopIndices[axis];
var stride = strides[axis] || 1; // Check the axis bit from right of masked axes, or if the stop index is not
// set for this axis.
if (endMask & 1 << axis || ellipsisMask & 1 << axis || stop == null) {
if (stride > 0) {
// Forward iteration - use the last element. These values will get
// clamped below
stop = Number.MAX_SAFE_INTEGER;
} else {
// Backward iteration - use the first element.
stop = Number.MIN_SAFE_INTEGER;
}
} // Handle negative indices
var axisSize = inputShape[axis];
if (stop < 0) {
stop += axisSize;
} // Clamping
// Because the end index points one past the last element, we need slightly
// different clamping ranges depending on the direction.
if (stride > 0) {
// Forward iteration
stop = clamp(0, stop, axisSize);
} else {
// Backward iteration
stop = clamp(-1, stop, axisSize - 1);
}
return stop;
}
/**
 * Returns true if the slice occupies a continuous set of elements in the
* 'flat' space.
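 *
 * For example (shape, begin and size are plain number arrays):
 *
 * ```js
 * // Row 1 of a [2, 3] tensor occupies flat elements 3..5, so it is contiguous.
 * isSliceContinous([2, 3], [1, 0], [1, 3]); // true
 * // A single column is scattered across the flat buffer.
 * isSliceContinous([2, 3], [0, 1], [2, 1]); // false
 * ```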
*/
function isSliceContinous(shape, begin, size) {
// Index of the first axis that has size > 1.
var firstNonOneAxis = size.length;
for (var i = 0; i < size.length; i++) {
if (size[i] > 1) {
firstNonOneAxis = i;
break;
}
}
for (var _i2 = firstNonOneAxis + 1; _i2 < size.length; _i2++) {
if (begin[_i2] > 0 || size[_i2] !== shape[_i2]) {
return false;
}
}
return true;
}
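// Maps a multi-dimensional begin index to an offset into the flat buffer.
// A quick check, using element strides [12, 4, 1] for a [2, 3, 4] tensor:
//   computeFlatOffset([1, 2, 3], [12, 4, 1]); // 1 * 12 + 2 * 4 + 3 === 23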
function computeFlatOffset(begin, strides) {
var flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1;
for (var i = 0; i < begin.length - 1; i++) {
flatOffset += begin[i] * strides[i];
}
return flatOffset;
}
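// Normalizes the begin/size arguments of slice(): scalars are expanded, short
// arrays are padded, and -1 sizes are resolved against the input shape.
// For a tensor x of shape [3, 4] (illustrative):
//   parseSliceParams(x, 1);                // [[1, 0], [2, 4]]
//   parseSliceParams(x, [1, 2], [-1, 1]);  // [[1, 2], [2, 1]]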
function parseSliceParams(x, begin, size) {
// The following logic allows for more ergonomic calls.
var begin_;
var xRank = x.shape.length;
if (typeof begin === 'number') {
begin_ = [begin].concat(new Array(xRank - 1).fill(0));
} else if (begin.length < xRank) {
begin_ = begin.concat(new Array(xRank - begin.length).fill(0));
} else {
begin_ = begin.slice();
}
begin_.forEach(function (d) {
assert(d !== -1, function () {
return 'slice() does not support negative begin indexing.';
});
});
var size_;
if (size == null) {
size_ = new Array(xRank).fill(-1);
} else if (typeof size === 'number') {
size_ = [size].concat(new Array(xRank - 1).fill(-1));
} else if (size.length < xRank) {
size_ = size.concat(new Array(xRank - size.length).fill(-1));
} else {
size_ = size;
}
size_ = size_.map(function (d, i) {
if (d >= 0) {
return d;
} else {
assert(d === -1, function () {
return "Negative size values should be exactly -1 but got " + (d + " for the slice() size at index " + i + ".");
});
return x.shape[i] - begin_[i];
}
});
return [begin_, size_];
}
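// A sketch of the simplest sliceInfo case (no masks, unit strides) on a [3, 4]
// input, keeping rows 1..2 and every column:
//   sliceInfo([3, 4], [1, 0], [3, 4], [1, 1], 0, 0, 0, 0, 0);
//   // → { nonStrided: true, $begin: [1, 0], $end: [3, 4], $strides: [1, 1],
//   //     size: [2, 4], newShape: [3, 4], outShape: [2, 4] }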
function sliceInfo(xShape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
// make a copy because it may be modified further down.
var $begin = begin.slice();
var $end = end.slice();
var $strides = strides;
if (strides == null) {
$strides = new Array($begin.length);
}
var ellipsisAxes = maskToAxes(ellipsisMask);
if (ellipsisAxes.length > 1) {
throw new Error('Multiple ellipses in slice is not allowed.');
}
if (ellipsisMask !== 0 && newAxisMask !== 0) {
throw new Error('Using both ellipsisMask and newAxisMask is not yet supported.');
}
if (ellipsisMask !== 0 && shrinkAxisMask !== 0) {
throw new Error('Using both ellipsisMask and shrinkAxisMask is not yet supported.');
}
var numInterpolatedAxes = xShape.length - $begin.length; // Expand the dims of x based on the newAxisMask.
var expandAxes = maskToAxes(newAxisMask);
var newShape = xShape.slice();
expandAxes.forEach(function (axis) {
$begin[axis] = 0;
$end[axis] = 1;
newShape.splice(axis, 0, 1);
});
var _getNormalizedAxes = getNormalizedAxes(newShape, ellipsisAxes, numInterpolatedAxes, $begin, $end, $strides, beginMask, endMask, ellipsisMask),
normalizedBegin = _getNormalizedAxes.begin,
normalizedEnd = _getNormalizedAxes.end,
normalizedStrides = _getNormalizedAxes.strides;
$begin = normalizedBegin;
$end = normalizedEnd;
$strides = normalizedStrides;
var shrinkAxes = maskToAxes(shrinkAxisMask); // Adjust the ends based on the shrink mask.
shrinkAxes.forEach(function (axis) {
$end[axis] = $begin[axis] + 1;
$strides[axis] = 1;
}); // Figure out the output shape.
var size = computeOutShape($begin, $end, $strides); // Remove the axes based on shrinkMask.
var outShape = size.filter(function (_, axis) {
return shrinkAxes.indexOf(axis) === -1;
});
var nonStrided = $strides.every(function (v) {
return v === 1;
});
return {
nonStrided: nonStrided,
$begin: $begin,
$end: $end,
$strides: $strides,
size: size,
newShape: newShape,
outShape: outShape
};
}
var slice_util = {
__proto__: null,
assertParamsValid: assertParamsValid,
maskToAxes: maskToAxes,
computeOutShape: computeOutShape,
stridesWithElidedDims: stridesWithElidedDims,
getNormalizedAxes: getNormalizedAxes,
startIndicesWithElidedDims: startIndicesWithElidedDims,
stopIndicesWithElidedDims: stopIndicesWithElidedDims,
stridesForAxis: stridesForAxis,
startForAxis: startForAxis,
stopForAxis: stopForAxis,
isSliceContinous: isSliceContinous,
computeFlatOffset: computeFlatOffset,
parseSliceParams: parseSliceParams,
sliceInfo: sliceInfo
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Serializable defines the serialization contract.
*
 * TFJS requires serializable classes to return their className when asked,
 * in order to avoid issues with minification.
*/
var Serializable = /*#__PURE__*/function () {
function Serializable() {}
var _proto = Serializable.prototype;
/**
* Return the class name for this class to use in serialization contexts.
*
* Generally speaking this will be the same thing that constructor.name
* would have returned. However, the class name needs to be robust
* against minification for serialization/deserialization to work properly.
*
 * There are also places, such as initializers.VarianceScaling, where
* implementation details between different languages led to different
* class hierarchies and a non-leaf node is used for serialization purposes.
*/
_proto.getClassName = function getClassName() {
return this.constructor.className;
}
/**
* Creates an instance of T from a ConfigDict.
*
* This works for most descendants of serializable. A few need to
* provide special handling.
* @param cls A Constructor for the class to instantiate.
* @param config The Configuration for the object.
*/
/** @nocollapse */
;
Serializable.fromConfig = function fromConfig(cls, config) {
return new cls(config);
};
return Serializable;
}();
/**
* Maps string keys to class constructors.
*
 * Used during (de)serialization from the cross-language JSON format, which
 * requires that the class name in the serialization format match the class
 * name used in Python, should it exist.
*/
var SerializationMap = /*#__PURE__*/function () {
function SerializationMap() {
this.classNameMap = {};
}
/**
* Returns the singleton instance of the map.
*/
SerializationMap.getMap = function getMap() {
if (SerializationMap.instance == null) {
SerializationMap.instance = new SerializationMap();
}
return SerializationMap.instance;
}
/**
* Registers the class as serializable.
*/
;
SerializationMap.register = function register(cls) {
SerializationMap.getMap().classNameMap[cls.className] = [cls, cls.fromConfig];
};
return SerializationMap;
}();
/**
* Register a class with the serialization map of TensorFlow.js.
*
* This is often used for registering custom Layers, so they can be
* serialized and deserialized.
*
* Example:
*
* ```js
* class MyCustomLayer extends tf.layers.Layer {
* static className = 'MyCustomLayer';
*
* constructor(config) {
* super(config);
* }
* }
* tf.serialization.registerClass(MyCustomLayer);
* ```
*
* @param cls The class to be registered. It must have a public static member
* called `className` defined and the value must be a non-empty string.
*
* @doc {heading: 'Models', subheading: 'Serialization', ignoreCI: true}
*/
function registerClass(cls) {
assert(cls.className != null, function () {
return "Class being registered does not have the static className " + "property defined.";
});
assert(typeof cls.className === 'string', function () {
return "className is required to be a string, but got type " + typeof cls.className;
});
assert(cls.className.length > 0, function () {
return "Class being registered has an empty-string as its className, " + "which is disallowed.";
});
SerializationMap.register(cls);
}
var serialization = {
__proto__: null,
Serializable: Serializable,
SerializationMap: SerializationMap,
registerClass: registerClass
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TEST_EPSILON_FLOAT32 = 1e-3;
var TEST_EPSILON_FLOAT16 = 1e-1;
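// A typical way these helpers are used inside an async unit test (sketch):
//
//   const result = tf.add(tf.tensor1d([1, 2]), tf.tensor1d([3, 4]));
//   expectArraysClose(await result.data(), [4, 6]);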
function expectArraysClose(actual, expected, epsilon) {
if (epsilon == null) {
epsilon = testEpsilon();
}
return expectArraysPredicate(actual, expected, function (a, b) {
return areClose(a, b, epsilon);
});
}
function testEpsilon() {
return ENGINE.backend.floatPrecision() === 32 ? TEST_EPSILON_FLOAT32 : TEST_EPSILON_FLOAT16;
}
function expectArraysPredicate(actual, expected, predicate) {
var checkClassType = true;
if (isTypedArray$1(actual) || isTypedArray$1(expected)) {
checkClassType = false;
}
if (isTypedArray$1(actual) && isTypedArray$1(expected)) {
checkClassType = true;
}
if (checkClassType) {
var aType = actual.constructor.name;
var bType = expected.constructor.name;
if (aType !== bType) {
throw new Error("Arrays are of different type. Actual: " + aType + ". " + ("Expected: " + bType));
}
}
if (Array.isArray(actual) && Array.isArray(expected)) {
var actualShape = inferShape(actual);
var expectedShape = inferShape(expected);
if (!arraysEqual(actualShape, expectedShape)) {
throw new Error("Arrays have different shapes. " + ("Actual: [" + actualShape + "]. Expected: [" + expectedShape + "]"));
}
}
var actualFlat = isTypedArray$1(actual) ? actual : flatten(actual);
var expectedFlat = isTypedArray$1(expected) ? expected : flatten(expected);
if (actualFlat.length !== expectedFlat.length) {
throw new Error("Arrays have different lengths actual: " + actualFlat.length + " vs " + ("expected: " + expectedFlat.length + ".\n") + ("Actual: " + actualFlat + ".\n") + ("Expected: " + expectedFlat + "."));
}
for (var i = 0; i < expectedFlat.length; ++i) {
var a = actualFlat[i];
var e = expectedFlat[i];
if (!predicate(a, e)) {
throw new Error("Arrays differ: actual[" + i + "] = " + a + ", expected[" + i + "] = " + e + ".\n" + ("Actual: " + actualFlat + ".\n") + ("Expected: " + expectedFlat + "."));
}
}
}
function expectPromiseToFail(fn, done) {
fn().then(function () {
return done.fail();
}, function () {
return done();
});
}
function expectArraysEqual(actual, expected) {
var exp = typeof expected === 'string' || typeof expected === 'number' || typeof expected === 'boolean' ? [expected] : expected;
if (isString(actual) || isString(actual[0]) || isString(expected) || isString(expected[0])) {
// tslint:disable-next-line: triple-equals
return expectArraysPredicate(actual, exp, function (a, b) {
return a == b;
});
}
return expectArraysPredicate(actual, expected, function (a, b) {
return areClose(a, b, 0);
});
}
function expectNumbersClose(a, e, epsilon) {
if (epsilon == null) {
epsilon = testEpsilon();
}
if (!areClose(a, e, epsilon)) {
throw new Error("Numbers differ: actual === " + a + ", expected === " + e);
}
}
function areClose(a, e, epsilon) {
if (!isFinite(a) && !isFinite(e)) {
return true;
}
if (isNaN(a) || isNaN(e) || Math.abs(a - e) > epsilon) {
return false;
}
return true;
}
function expectValuesInRange(actual, low, high) {
for (var i = 0; i < actual.length; i++) {
if (actual[i] < low || actual[i] > high) {
throw new Error("Value out of range:" + actual[i] + " low: " + low + ", high: " + high);
}
}
}
function expectArrayBuffersEqual(actual, expected) {
// Safari & Jasmine don't like comparing ArrayBuffers directly. Wrapping in
// a Float32Array solves this issue.
expect(new Float32Array(actual)).toEqual(new Float32Array(expected));
}
/** Encodes strings into utf-8 bytes. */
function encodeStrings(a) {
for (var i = 0; i < a.length; i++) {
var val = a[i];
if (Array.isArray(val)) {
encodeStrings(val);
} else {
a[i] = encodeString(val);
}
}
return a;
}
var test_util = {
__proto__: null,
TEST_EPSILON_FLOAT16: TEST_EPSILON_FLOAT16,
expectArraysClose: expectArraysClose,
testEpsilon: testEpsilon,
expectPromiseToFail: expectPromiseToFail,
expectArraysEqual: expectArraysEqual,
expectNumbersClose: expectNumbersClose,
expectValuesInRange: expectValuesInRange,
expectArrayBuffersEqual: expectArrayBuffersEqual,
encodeStrings: encodeStrings
};
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
var version$1 = '3.9.0';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Enables production mode which disables correctness checks in favor of
* performance.
*
* @doc {heading: 'Environment'}
*/
function enableProdMode() {
env().set('PROD', true);
}
/**
* Enables debug mode which will log information about all executed kernels:
* the elapsed time of the kernel execution, as well as the rank, shape, and
* size of the output tensor.
*
* Debug mode will significantly slow down your application as it will
* download the result of every operation to the CPU. This should not be used in
* production. Debug mode does not affect the timing information of the kernel
* execution as we do not measure download time in the kernel execution time.
*
* See also: `tf.profile`, `tf.memory`.
*
* @doc {heading: 'Environment'}
*/
function enableDebugMode() {
env().set('DEBUG', true);
}
/** Globally disables deprecation warnings */
function disableDeprecationWarnings() {
env().set('DEPRECATION_WARNINGS_ENABLED', false);
console.warn("TensorFlow.js deprecation warnings have been disabled.");
}
/** Warn users about deprecated functionality. */
function deprecationWarn(msg) {
if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) {
console.warn(msg + ' You can disable deprecation warnings with ' + 'tf.disableDeprecationWarnings().');
}
}
setDeprecationWarningFn(deprecationWarn);
/**
* Dispose all variables kept in backend engine.
*
* @doc {heading: 'Environment'}
*/
function disposeVariables() {
ENGINE.disposeVariables();
}
/**
* It returns the global engine that keeps track of all tensors and backends.
*
* @doc {heading: 'Environment'}
*/
function engine() {
return ENGINE;
}
/**
* Returns memory info at the current time in the program. The result is an
* object with the following properties:
*
* - `numBytes`: Number of bytes allocated (undisposed) at this time.
* - `numTensors`: Number of unique tensors allocated.
* - `numDataBuffers`: Number of unique data buffers allocated
* (undisposed) at this time, which is ≤ the number of tensors
* (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same
* data buffer with `a`).
* - `unreliable`: True if the memory usage is unreliable. See `reasons` when
* `unreliable` is true.
* - `reasons`: `string[]`, reasons why the memory is unreliable, present if
* `unreliable` is true.
*
* WebGL Properties:
* - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at
* this time.
*
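 * A minimal sketch (the exact counts depend on what has already been
 * allocated):
 *
 * ```js
 * const before = tf.memory().numTensors;
 * const t = tf.tensor1d([1, 2, 3]);
 * console.log(tf.memory().numTensors - before); // 1
 * t.dispose();
 * ```
 *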
* @doc {heading: 'Performance', subheading: 'Memory'}
*/
function memory() {
return ENGINE.memory();
}
/**
* Executes the provided function `f()` and returns a promise that resolves
* with information about the function's memory use:
* - `newBytes`: the number of new bytes allocated
* - `newTensors`: the number of new tensors created
* - `peakBytes`: the peak number of bytes allocated
* - `kernels`: an array of objects for each kernel involved that reports
* their input and output shapes, number of bytes used, and number of new
* tensors created.
* - `kernelNames`: an array of unique strings with just the names of the
* kernels in the `kernels` array.
*
* ```js
* const profile = await tf.profile(() => {
* const x = tf.tensor1d([1, 2, 3]);
* let x2 = x.square();
* x2.dispose();
* x2 = x.square();
* x2.dispose();
* return x;
* });
*
* console.log(`newBytes: ${profile.newBytes}`);
* console.log(`newTensors: ${profile.newTensors}`);
* console.log(`byte usage over all kernels: ${profile.kernels.map(k =>
* k.totalBytesSnapshot)}`);
* ```
*
*
* @doc {heading: 'Performance', subheading: 'Profile'}
*/
function profile(f) {
return ENGINE.profile(f);
}
/**
* Executes the provided function `fn` and after it is executed, cleans up all
* intermediate tensors allocated by `fn` except those returned by `fn`.
* `fn` must not return a Promise (async functions not allowed). The returned
* result can be a complex object.
*
* Using this method helps avoid memory leaks. In general, wrap calls to
* operations in `tf.tidy` for automatic memory cleanup.
*
* NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to
* dispose variables, please use `tf.disposeVariables` or call dispose()
* directly on variables.
*
* ```js
* // y = 2 ^ 2 + 1
* const y = tf.tidy(() => {
* // a, b, and one will be cleaned up when the tidy ends.
* const one = tf.scalar(1);
* const a = tf.scalar(2);
* const b = a.square();
*
* console.log('numTensors (in tidy): ' + tf.memory().numTensors);
*
* // The value returned inside the tidy function will return
* // through the tidy, in this case to the variable y.
* return b.add(one);
* });
*
* console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
* y.print();
* ```
*
* @param nameOrFn The name of the closure, or the function to execute.
* If a name is provided, the 2nd argument should be the function.
* If debug mode is on, the timing and the memory usage of the function
* will be tracked and displayed on the console using the provided name.
* @param fn The function to execute.
*
* @doc {heading: 'Performance', subheading: 'Memory'}
*/
function tidy(nameOrFn, fn) {
return ENGINE.tidy(nameOrFn, fn);
}
/**
* Disposes any `tf.Tensor`s found within the provided object.
*
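 * A short sketch (both tensors below end up disposed):
 *
 * ```js
 * const a = tf.tensor1d([1, 2, 3]);
 * const b = {first: tf.scalar(1), second: tf.scalar(2)};
 *
 * tf.dispose(a);
 * tf.dispose(b); // walks the object and disposes the tensors it holds
 * ```
 *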
* @param container an object that may be a `tf.Tensor` or may directly
* contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If
* the object is not a `tf.Tensor` or does not contain `Tensors`, nothing
* happens. In general it is safe to pass any object here, except that
* `Promise`s are not supported.
*
* @doc {heading: 'Performance', subheading: 'Memory'}
*/
function dispose(container) {
var tensors = getTensorsInContainer(container);
tensors.forEach(function (tensor) {
return tensor.dispose();
});
}
/**
* Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed
* automatically.
*
* ```js
* let b;
* const y = tf.tidy(() => {
* const one = tf.scalar(1);
* const a = tf.scalar(2);
*
* // b will not be cleaned up by the tidy. a and one will be cleaned up
* // when the tidy ends.
* b = tf.keep(a.square());
*
* console.log('numTensors (in tidy): ' + tf.memory().numTensors);
*
* // The value returned inside the tidy function will return
* // through the tidy, in this case to the variable y.
* return b.add(one);
* });
*
* console.log('numTensors (outside tidy): ' + tf.memory().numTensors);
* console.log('y:');
* y.print();
* console.log('b:');
* b.print();
* ```
*
* @param result The tensor to keep from being disposed.
*
* @doc {heading: 'Performance', subheading: 'Memory'}
*/
function keep(result) {
return ENGINE.keep(result);
}
/**
* Executes `f()` and returns a promise that resolves with timing
* information.
*
* The result is an object with the following properties:
*
* - `wallMs`: Wall execution time.
* - `kernelMs`: Kernel execution time, ignoring data transfer. If using the
* WebGL backend and the query timer extension is not available, this will
* return an error object.
* - On `WebGL` The following additional properties exist:
* - `uploadWaitMs`: CPU blocking time on texture uploads.
* - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels).
*
* ```js
* const x = tf.randomNormal([20, 20]);
* const time = await tf.time(() => x.matMul(x));
*
* console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`);
* ```
*
* @param f The function to execute and time.
*
* @doc {heading: 'Performance', subheading: 'Timing'}
*/
function time(f) {
return ENGINE.time(f);
}
/**
* Sets the backend (cpu, webgl, wasm, etc) responsible for creating tensors and
 * executing operations on those tensors. Returns a promise that resolves
 * to a boolean indicating whether the backend was successfully initialized.
*
* Note this disposes the current backend, if any, as well as any tensors
* associated with it. A new backend is initialized, even if it is of the
* same type as the previous one.
*
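 * A minimal sketch, assuming the bundled 'cpu' backend:
 *
 * ```js
 * await tf.setBackend('cpu');
 * await tf.ready();
 * console.log(tf.getBackend()); // 'cpu'
 * ```
 *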
* @param backendName The name of the backend. Currently supports
* `'webgl'|'cpu'` in the browser, `'tensorflow'` under node.js
* (requires tfjs-node), and `'wasm'` (requires tfjs-backend-wasm).
*
* @doc {heading: 'Backends'}
*/
function setBackend(backendName) {
return ENGINE.setBackend(backendName);
}
/**
* Returns a promise that resolves when the currently selected backend (or the
* highest priority one) has initialized. Await this promise when you are using
* a backend that has async initialization.
*
* @doc {heading: 'Backends'}
*/
function ready() {
return ENGINE.ready();
}
/**
* Returns the current backend name (cpu, webgl, etc). The backend is
* responsible for creating tensors and executing operations on those tensors.
*
* @doc {heading: 'Backends'}
*/
function getBackend() {
return ENGINE.backendName;
}
/**
* Removes a backend and the registered factory.
*
* @doc {heading: 'Backends'}
*/
function removeBackend(name) {
ENGINE.removeBackend(name);
}
/**
* Finds the backend registered under the provided name. Returns null if the
* name is not in the registry, or the registration hasn't finished yet.
*/
function findBackend(name) {
return ENGINE.findBackend(name);
}
/**
* Finds the backend factory registered under the provided name. Returns a
* function that produces a new backend when called. Returns null if the name
* is not in the registry.
*/
function findBackendFactory(name) {
return ENGINE.findBackendFactory(name);
}
/**
* Registers a global backend. The registration should happen when importing
* a module file (e.g. when importing `backend_webgl.ts`), and is used for
* modular builds (e.g. custom tfjs bundle with only webgl support).
*
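 * A sketch of what a registration call looks like; `MyBackend` is a
 * hypothetical class implementing the `KernelBackend` interface:
 *
 * ```js
 * // `MyBackend` is a placeholder for a custom KernelBackend implementation.
 * tf.registerBackend('my-backend', () => new MyBackend(), 1); // priority 1
 * ```
 *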
* @param factory The backend factory function. When called, it should
* return a backend instance, or a promise of an instance.
* @param priority The priority of the backend (higher = more important).
* In case multiple backends are registered, the priority is used to find
* the best backend. Defaults to 1.
* @return False if there is already a registered backend under this name, true
* if not.
*
* @doc {heading: 'Backends'}
*/
function registerBackend(name, factory, priority) {
if (priority === void 0) {
priority = 1;
}
return ENGINE.registerBackend(name, factory, priority);
}
/**
* Gets the current backend. If no backends have been initialized, this will
* attempt to initialize the best backend. Will throw an error if the highest
* priority backend has async initialization, in which case, you should call
* 'await tf.ready()' before running other code.
*
* @doc {heading: 'Backends'}
*/
function backend() {
return ENGINE.backend;
}
/**
* Sets the global platform.
*
* @param platformName The name of this platform.
* @param platform A platform implementation.
*/
function setPlatform(platformName, platform) {
env().setPlatform(platformName, platform);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting.
*
*
* ```js
* const a = tf.tensor1d([1, 2, 3, 4]);
* const b = tf.tensor1d([10, 20, 30, 40]);
*
* a.add(b).print(); // or tf.add(a, b)
* ```
*
* ```js
* // Broadcast add a with b.
* const a = tf.scalar(5);
* const b = tf.tensor1d([10, 20, 30, 40]);
*
* a.add(b).print(); // or tf.add(a, b)
* ```
* @param a The first `tf.Tensor` to add.
* @param b The second `tf.Tensor` to add. Must have the same type as `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function add_(a, b) {
var $a = convertToTensor(a, 'a', 'add');
var $b = convertToTensor(b, 'b', 'add');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Add, inputs);
}
var add$1 = op({
add_: add_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
 * The result is rounded down with the floor function.
*
*
* ```js
* const a = tf.tensor1d([1, 4, 9, 16]);
* const b = tf.tensor1d([1, 2, 3, 4]);
*
 * a.floorDiv(b).print(); // or tf.floorDiv(a, b)
* ```
*
* ```js
* // Broadcast div a with b.
* const a = tf.tensor1d([2, 4, 6, 8]);
* const b = tf.scalar(2);
*
* a.floorDiv(b).print(); // or tf.floorDiv(a, b)
* ```
*
* @param a The first tensor as the numerator.
* @param b The second tensor as the denominator. Must have the same dtype as
* `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function floorDiv_(a, b) {
var $a = convertToTensor(a, 'a', 'floorDiv');
var $b = convertToTensor(b, 'b', 'floorDiv');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(FloorDiv, inputs);
}
var floorDiv = op({
floorDiv_: floorDiv_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([1, 4, 9, 16]);
* const b = tf.tensor1d([1, 2, 3, 4]);
*
* a.div(b).print(); // or tf.div(a, b)
* ```
*
* ```js
* // Broadcast div a with b.
* const a = tf.tensor1d([2, 4, 6, 8]);
* const b = tf.scalar(2);
*
* a.div(b).print(); // or tf.div(a, b)
* ```
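 *
 * When both inputs are `int32` tensors, the division is floored (it falls
 * back to `tf.floorDiv`):
 *
 * ```js
 * const a = tf.tensor1d([7, 8], 'int32');
 * const b = tf.tensor1d([2, 3], 'int32');
 *
 * a.div(b).print(); // [3, 2]
 * ```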
*
* @param a The first tensor as the numerator.
* @param b The second tensor as the denominator. Must have the same dtype as
* `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function div_(a, b) {
var $a = convertToTensor(a, 'a', 'div');
var $b = convertToTensor(b, 'b', 'div');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
if ($a.dtype === 'int32' && $b.dtype === 'int32') {
return floorDiv($a, $b);
}
var inputs = {
a: $a,
b: $b
};
var attrs = {}; // tslint:disable-next-line: no-unnecessary-type-assertion
return ENGINE.runKernel(RealDiv, inputs, attrs);
}
var div = op({
div_: div_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting.
*
* We also expose `tf.mulStrict` which has the same signature as this op and
* asserts that `a` and `b` are the same shape (does not broadcast).
*
* ```js
* const a = tf.tensor1d([1, 2, 3, 4]);
* const b = tf.tensor1d([2, 3, 4, 5]);
*
* a.mul(b).print(); // or tf.mul(a, b)
* ```
*
* ```js
* // Broadcast mul a with b.
* const a = tf.tensor1d([1, 2, 3, 4]);
* const b = tf.scalar(5);
*
* a.mul(b).print(); // or tf.mul(a, b)
* ```
* @param a The first tensor to multiply.
* @param b The second tensor to multiply. Must have the same dtype as `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function mul_(a, b) {
var $a = convertToTensor(a, 'a', 'mul');
var $b = convertToTensor(b, 'b', 'mul');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Multiply, inputs);
}
var mul = op({
mul_: mul_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes absolute value element-wise: `abs(x)`
*
* ```js
* const x = tf.tensor1d([-1, 2, -3, 4]);
*
* x.abs().print(); // or tf.abs(x)
* ```
* @param x The input `tf.Tensor`.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function abs_(x) {
var $x = convertToTensor(x, 'x', 'abs');
if ($x.dtype === 'complex64') {
var inputs = {
x: $x
};
return ENGINE.runKernel(ComplexAbs, inputs);
} else {
var _inputs = {
x: $x
};
return ENGINE.runKernel(Abs, _inputs);
}
}
var abs$8 = op({
abs_: abs_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes acos of the input `tf.Tensor` element-wise: `acos(x)`
*
* ```js
* const x = tf.tensor1d([0, 1, -1, .7]);
*
* x.acos().print(); // or tf.acos(x)
* ```
* @param x The input tensor.
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function acos_(x) {
var $x = convertToTensor(x, 'x', 'acos');
var inputs = {
x: $x
};
return ENGINE.runKernel(Acos, inputs);
}
var acos = op({
acos_: acos_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise:
* `acosh(x)`
*
* ```js
* const x = tf.tensor1d([10, 1, 3, 5.7]);
*
* x.acosh().print(); // or tf.acosh(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function acosh_(x) {
var $x = convertToTensor(x, 'x', 'acosh');
var inputs = {
x: $x
};
return ENGINE.runKernel(Acosh, inputs);
}
var acosh = op({
acosh_: acosh_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Adds a list of `tf.Tensor`s element-wise, each with the same shape and dtype.
*
* ```js
* const a = tf.tensor1d([1, 2]);
* const b = tf.tensor1d([3, 4]);
* const c = tf.tensor1d([5, 6]);
*
* tf.addN([a, b, c]).print();
* ```
* @param tensors A list of tensors with the same shape and dtype.
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function addN_(tensors) {
assert(Array.isArray(tensors), function () {
return 'The argument passed to tf.addN() must be a list of tensors';
});
assert(tensors.length >= 1, function () {
return "Must pass at least one tensor to tf.addN(), but got " + ("" + tensors.length);
});
var $tensors = tensors.map(function (t, i) {
return convertToTensor(t, "tensors" + i, 'addN');
});
var firstTensor = $tensors[0];
$tensors.forEach(function (t) {
if (t.dtype !== firstTensor.dtype) {
throw new Error('All tensors passed to tf.addN() must have the same dtype');
}
});
$tensors.forEach(function (t) {
if (!arraysEqual(t.shape, firstTensor.shape)) {
throw new Error('All tensors passed to tf.addN() must have the same shape');
}
});
var inputs = $tensors;
return ENGINE.runKernel(AddN, inputs);
}
var addN = op({
addN_: addN_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the logical and of elements across dimensions of a `tf.Tensor`.
*
* Reduces the input along the dimensions given in `axes`. Unless `keepDims`
* is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
* `axes`. If `keepDims` is true, the reduced dimensions are retained with
 * length 1. If `axes` has no entries, all dimensions are reduced, and a
* `tf.Tensor` with a single element is returned.
*
* ```js
* const x = tf.tensor1d([1, 1, 1], 'bool');
*
* x.all().print(); // or tf.all(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
*
* const axis = 1;
* x.all(axis).print(); // or tf.all(x, axis)
* ```
*
* @param x The input tensor. Must be of dtype bool.
* @param axis The dimension(s) to reduce. By default it reduces
* all dimensions.
* @param keepDims If true, retains reduced dimensions with size 1.
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function all_(x, axis, keepDims) {
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
var $x = convertToTensor(x, 'x', 'all', 'bool');
var inputs = {
x: $x
};
var attrs = {
axis: axis,
keepDims: keepDims
};
return ENGINE.runKernel(All, inputs, attrs);
}
var all = op({
all_: all_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the logical or of elements across dimensions of a `tf.Tensor`.
*
* Reduces the input along the dimensions given in `axes`. Unless `keepDims`
* is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
* `axes`. If `keepDims` is true, the reduced dimensions are retained with
 * length 1. If `axes` has no entries, all dimensions are reduced, and a
* `tf.Tensor` with a single element is returned.
*
* ```js
* const x = tf.tensor1d([1, 1, 1], 'bool');
*
* x.any().print(); // or tf.any(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool');
*
* const axis = 1;
* x.any(axis).print(); // or tf.any(x, axis)
* ```
*
* @param x The input tensor. Must be of dtype bool.
* @param axis The dimension(s) to reduce. By default it reduces
* all dimensions.
* @param keepDims If true, retains reduced dimensions with size 1.
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function any_(x, axis, keepDims) {
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
var $x = convertToTensor(x, 'x', 'any', 'bool');
var inputs = {
x: $x
};
var attrs = {
axis: axis,
keepDims: keepDims
};
return ENGINE.runKernel(Any, inputs, attrs);
} // tslint:disable-next-line:variable-name
var any = op({
any_: any_
});
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the indices of the maximum values along an `axis`.
*
* The result has the same shape as `input` with the dimension along `axis`
* removed.
*
* ```js
* const x = tf.tensor1d([1, 2, 3]);
*
* x.argMax().print(); // or tf.argMax(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
*
* const axis = 1;
* x.argMax(axis).print(); // or tf.argMax(x, axis)
* ```
*
* @param x The input tensor.
* @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function argMax_(x, axis) {
if (axis === void 0) {
axis = 0;
}
var $x = convertToTensor(x, 'x', 'argMax');
var inputs = {
x: $x
};
var attrs = {
axis: axis
};
return ENGINE.runKernel(ArgMax, inputs, attrs);
}
var argMax = op({
argMax_: argMax_
});
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the indices of the minimum values along an `axis`.
*
* The result has the same shape as `input` with the dimension along `axis`
* removed.
*
* ```js
* const x = tf.tensor1d([1, 2, 3]);
*
* x.argMin().print(); // or tf.argMin(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 4, 3], [2, 2]);
*
* const axis = 1;
* x.argMin(axis).print(); // or tf.argMin(x, axis)
* ```
*
* @param x The input tensor.
* @param axis The dimension to reduce. Defaults to 0 (outer-most dimension).
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function argMin_(x, axis) {
if (axis === void 0) {
axis = 0;
}
var $x = convertToTensor(x, 'x', 'argMin');
var inputs = {
x: $x
};
var attrs = {
axis: axis
};
return ENGINE.runKernel(ArgMin, inputs, attrs);
}
var argMin = op({
argMin_: argMin_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes asin of the input `tf.Tensor` element-wise: `asin(x)`
*
* ```js
* const x = tf.tensor1d([0, 1, -1, .7]);
*
* x.asin().print(); // or tf.asin(x)
* ```
* @param x The input tensor.
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function asin_(x) {
var $x = convertToTensor(x, 'x', 'asin');
var inputs = {
x: $x
};
return ENGINE.runKernel(Asin, inputs);
}
var asin = op({
asin_: asin_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise:
* `asinh(x)`
*
* ```js
* const x = tf.tensor1d([0, 1, -1, .7]);
*
* x.asinh().print(); // or tf.asinh(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function asinh_(x) {
var $x = convertToTensor(x, 'x', 'asinh');
var inputs = {
x: $x
};
return ENGINE.runKernel(Asinh, inputs);
}
var asinh$1 = op({
asinh_: asinh_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes atan of the input `tf.Tensor` element-wise: `atan(x)`
*
* ```js
* const x = tf.tensor1d([0, 1, -1, .7]);
*
* x.atan().print(); // or tf.atan(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function atan_(x) {
var $x = convertToTensor(x, 'x', 'atan');
var inputs = {
x: $x
};
return ENGINE.runKernel(Atan, inputs);
}
var atan = op({
atan_: atan_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`.
* Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([1.0, 1.0, -1.0, .7]);
* const b = tf.tensor1d([2.0, 13.0, 3.5, .21]);
*
* tf.atan2(a, b).print()
* ```
*
* @param a The first tensor.
* @param b The second tensor. Must have the same dtype as `a`.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function atan2_(a, b) {
var $a = convertToTensor(a, 'a', 'atan2');
var $b = convertToTensor(b, 'b', 'atan2');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Atan2, inputs);
}
var atan2 = op({
atan2_: atan2_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise:
* `atanh(x)`
*
* ```js
* const x = tf.tensor1d([0, .1, -.1, .7]);
*
* x.atanh().print(); // or tf.atanh(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function atanh_(x) {
var $x = convertToTensor(x, 'x', 'atanh');
var inputs = {
x: $x
};
return ENGINE.runKernel(Atanh, inputs);
}
var atanh = op({
atanh_: atanh_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
*
* @param inputShape Input tensor shape is of the following dimensions:
* `[batch, height, width, inChannels]`.
* @param filterShape The filter shape is of the following dimensions:
* `[filterHeight, filterWidth, depth]`.
* @param strides The strides of the sliding window for each dimension of the
* input tensor: `[strideHeight, strideWidth]`.
* If `strides` is a single number,
* then `strideHeight == strideWidth`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
 *     than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dataFormat The data format of the input and output data.
* Defaults to 'NHWC'.
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`.
* Defaults to `[1, 1]`. If `dilations` is a single number, then
* `dilationHeight == dilationWidth`.
*/
function computeDilation2DInfo(inputShape, filterShape, strides, pad, dataFormat, dilations) {
if (dataFormat === void 0) {
dataFormat = 'NHWC';
}
  // `computeConv2DInfo` requires filterShape to be of the form
  // `[filterHeight, filterWidth, depth, outDepth]`. dilation2d has no outDepth,
  // so the filter should have the same depth as the input.
// Input shape: [batch, height, width, inChannels]
var inputChannels = inputShape[3];
var $filterShape = [].concat(filterShape, [inputChannels]);
var $dataFormat = convertConv2DDataFormat(dataFormat);
return computeConv2DInfo(inputShape, $filterShape, strides, dilations, pad, null
/* roundingMode */
, null
/* depthWise */
, $dataFormat);
}
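// Computes the information for a forward pass of a 2D pooling operation.
// A quick sketch (only a few of the returned fields are shown):
//   const info = computePool2DInfo([1, 4, 4, 1], 2, 2, 1, 'valid');
//   // info.outHeight === 2, info.outWidth === 2, info.outShape → [1, 2, 2, 1]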
function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
if (dataFormat === void 0) {
dataFormat = 'channelsLast';
}
var _parseTupleParam = parseTupleParam(filterSize),
filterHeight = _parseTupleParam[0],
filterWidth = _parseTupleParam[1];
var filterShape;
if (dataFormat === 'channelsLast') {
filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]];
} else if (dataFormat === 'channelsFirst') {
filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]];
} else {
throw new Error("Unknown dataFormat " + dataFormat);
}
return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat);
}
/**
* Computes the information for a forward pass of a pooling3D operation.
*/
function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) {
if (dataFormat === void 0) {
dataFormat = 'NDHWC';
}
var _parse3TupleParam = parse3TupleParam(filterSize),
filterDepth = _parse3TupleParam[0],
filterHeight = _parse3TupleParam[1],
filterWidth = _parse3TupleParam[2];
var filterShape;
var $dataFormat;
if (dataFormat === 'NDHWC') {
$dataFormat = 'channelsLast';
filterShape = [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]];
} else if (dataFormat === 'NCDHW') {
$dataFormat = 'channelsFirst';
filterShape = [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]];
} else {
throw new Error("Unknown dataFormat " + dataFormat);
}
return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode);
}
/**
* Computes the information for a forward pass of a convolution/pooling
* operation.
*/
function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise, dataFormat) {
if (depthwise === void 0) {
depthwise = false;
}
if (dataFormat === void 0) {
dataFormat = 'channelsLast';
}
var batchSize = -1,
inHeight = -1,
inWidth = -1,
inChannels = -1;
if (dataFormat === 'channelsLast') {
batchSize = inShape[0];
inHeight = inShape[1];
inWidth = inShape[2];
inChannels = inShape[3];
} else if (dataFormat === 'channelsFirst') {
batchSize = inShape[0];
inChannels = inShape[1];
inHeight = inShape[2];
inWidth = inShape[3];
} else {
throw new Error("Unknown dataFormat " + dataFormat);
}
var filterHeight = filterShape[0],
filterWidth = filterShape[1],
filterChannels = filterShape[3];
var _parseTupleParam2 = parseTupleParam(strides),
strideHeight = _parseTupleParam2[0],
strideWidth = _parseTupleParam2[1];
var _parseTupleParam3 = parseTupleParam(dilations),
dilationHeight = _parseTupleParam3[0],
dilationWidth = _parseTupleParam3[1];
var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
var _getPadAndOutInfo = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat),
padInfo = _getPadAndOutInfo.padInfo,
outHeight = _getPadAndOutInfo.outHeight,
outWidth = _getPadAndOutInfo.outWidth;
var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
var outShape;
if (dataFormat === 'channelsFirst') {
outShape = [batchSize, outChannels, outHeight, outWidth];
} else if (dataFormat === 'channelsLast') {
outShape = [batchSize, outHeight, outWidth, outChannels];
}
return {
batchSize: batchSize,
dataFormat: dataFormat,
inHeight: inHeight,
inWidth: inWidth,
inChannels: inChannels,
outHeight: outHeight,
outWidth: outWidth,
outChannels: outChannels,
padInfo: padInfo,
strideHeight: strideHeight,
strideWidth: strideWidth,
filterHeight: filterHeight,
filterWidth: filterWidth,
effectiveFilterHeight: effectiveFilterHeight,
effectiveFilterWidth: effectiveFilterWidth,
dilationHeight: dilationHeight,
dilationWidth: dilationWidth,
inShape: inShape,
outShape: outShape,
filterShape: filterShape
};
}
/**
* Computes the information for a forward pass of a 3D convolution/pooling
* operation.
*/
function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise, dataFormat, roundingMode) {
if (depthwise === void 0) {
depthwise = false;
}
if (dataFormat === void 0) {
dataFormat = 'channelsLast';
}
var batchSize = -1,
inDepth = -1,
inHeight = -1,
inWidth = -1,
inChannels = -1;
if (dataFormat === 'channelsLast') {
batchSize = inShape[0];
inDepth = inShape[1];
inHeight = inShape[2];
inWidth = inShape[3];
inChannels = inShape[4];
} else if (dataFormat === 'channelsFirst') {
batchSize = inShape[0];
inChannels = inShape[1];
inDepth = inShape[2];
inHeight = inShape[3];
inWidth = inShape[4];
} else {
throw new Error("Unknown dataFormat " + dataFormat);
}
var filterDepth = filterShape[0],
filterHeight = filterShape[1],
filterWidth = filterShape[2],
filterChannels = filterShape[4];
var _parse3TupleParam2 = parse3TupleParam(strides),
strideDepth = _parse3TupleParam2[0],
strideHeight = _parse3TupleParam2[1],
strideWidth = _parse3TupleParam2[2];
var _parse3TupleParam3 = parse3TupleParam(dilations),
dilationDepth = _parse3TupleParam3[0],
dilationHeight = _parse3TupleParam3[1],
dilationWidth = _parse3TupleParam3[2];
var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
var _get3DPadAndOutInfo = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode),
padInfo = _get3DPadAndOutInfo.padInfo,
outDepth = _get3DPadAndOutInfo.outDepth,
outHeight = _get3DPadAndOutInfo.outHeight,
outWidth = _get3DPadAndOutInfo.outWidth;
var outChannels = depthwise ? filterChannels * inChannels : filterChannels;
var outShape;
if (dataFormat === 'channelsFirst') {
outShape = [batchSize, outChannels, outDepth, outHeight, outWidth];
} else if (dataFormat === 'channelsLast') {
outShape = [batchSize, outDepth, outHeight, outWidth, outChannels];
}
return {
batchSize: batchSize,
dataFormat: dataFormat,
inDepth: inDepth,
inHeight: inHeight,
inWidth: inWidth,
inChannels: inChannels,
outDepth: outDepth,
outHeight: outHeight,
outWidth: outWidth,
outChannels: outChannels,
padInfo: padInfo,
strideDepth: strideDepth,
strideHeight: strideHeight,
strideWidth: strideWidth,
filterDepth: filterDepth,
filterHeight: filterHeight,
filterWidth: filterWidth,
effectiveFilterDepth: effectiveFilterDepth,
effectiveFilterHeight: effectiveFilterHeight,
effectiveFilterWidth: effectiveFilterWidth,
dilationDepth: dilationDepth,
dilationHeight: dilationHeight,
dilationWidth: dilationWidth,
inShape: inShape,
outShape: outShape,
filterShape: filterShape
};
}
function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
if (zeroPad == null) {
zeroPad = computeDefaultPad(inShape, fieldSize, stride);
}
var inputRows = inShape[0];
var inputCols = inShape[1];
var outputRows = round((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
var outputCols = round((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
return [outputRows, outputCols];
}
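/*
 * Editor's sketch of the formula in computeOutputShape2D, assuming a 5x5 input,
 * field size 3, stride 1 and an explicit zero padding of 1 (the default rounding
 * mode truncates the quotient):
 *
 * ```js
 * computeOutputShape2D([5, 5], 3, 1, 1);  // -> [5, 5]; (5 - 3 + 2 * 1) / 1 + 1 = 5
 * ```
 */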
function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) {
if (zeroPad == null) {
zeroPad = computeDefaultPad(inShape, fieldSize, stride);
}
var inputDepth = inShape[0];
var inputRows = inShape[1];
var inputCols = inShape[2];
var outputDepths = round((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
var outputRows = round((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
var outputCols = round((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
return [outputDepths, outputRows, outputCols, outChannels];
}
function computeDefaultPad(inputShape, fieldSize, stride, dilation) {
if (dilation === void 0) {
dilation = 1;
}
var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
}
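/*
 * Editor's sketch: computeDefaultPad picks the padding that keeps the output the
 * same size at the given stride. For a 5x5 input, a 3x3 field and stride 1:
 *
 * ```js
 * computeDefaultPad([5, 5], 3, 1);  // floor((5 * 0 - 1 + 3) / 2) = 1
 * ```
 */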
function parseTupleParam(param) {
if (typeof param === 'number') {
return [param, param, param];
}
if (param.length === 2) {
return [param[0], param[1], 1];
}
return param;
}
function parse3TupleParam(param) {
return typeof param === 'number' ? [param, param, param] : param;
}
/* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d
* Atrous convolution is equivalent to standard convolution with upsampled
* filters with effective_filter_height =
* filter_height + (filter_height - 1) * (dilation - 1)
* and effective_filter_width =
* filter_width + (filter_width - 1) * (dilation - 1),
* produced by inserting dilation - 1 zeros along consecutive elements across
* the filters' spatial dimensions.
* When there is a dilation, this converts a filter dimension to the
* effective filter dimension, so it can be used in a standard convolution.
*/
function getEffectiveFilterSize(filterSize, dilation) {
if (dilation <= 1) {
return filterSize;
}
return filterSize + (filterSize - 1) * (dilation - 1);
}
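/*
 * Editor's sketch of the effective-filter-size formula above: a 3-tap filter
 * with dilation 2 behaves like a 5-tap filter.
 *
 * ```js
 * getEffectiveFilterSize(3, 2);  // 3 + (3 - 1) * (2 - 1) = 5
 * ```
 */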
function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) {
var padInfo;
var outHeight;
var outWidth;
if (typeof pad === 'number') {
var padType = pad === 0 ? 'VALID' : 'NUMBER';
padInfo = {
top: pad,
bottom: pad,
left: pad,
right: pad,
type: padType
};
var outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
outHeight = outShape[0];
outWidth = outShape[1];
} else if (pad === 'same') {
outHeight = Math.ceil(inHeight / strideHeight);
outWidth = Math.ceil(inWidth / strideWidth);
var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
var top = Math.floor(padAlongHeight / 2);
var bottom = padAlongHeight - top;
var left = Math.floor(padAlongWidth / 2);
var right = padAlongWidth - left;
padInfo = {
top: top,
bottom: bottom,
left: left,
right: right,
type: 'SAME'
};
} else if (pad === 'valid') {
padInfo = {
top: 0,
bottom: 0,
left: 0,
right: 0,
type: 'VALID'
};
outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
} else if (typeof pad === 'object') {
var _top = dataFormat === 'channelsLast' ? pad[1][0] : pad[2][0];
var _bottom = dataFormat === 'channelsLast' ? pad[1][1] : pad[2][1];
var _left = dataFormat === 'channelsLast' ? pad[2][0] : pad[3][0];
var _right = dataFormat === 'channelsLast' ? pad[2][1] : pad[3][1];
var _padType = _top === 0 && _bottom === 0 && _left === 0 && _right === 0 ? 'VALID' : 'EXPLICIT';
padInfo = {
top: _top,
bottom: _bottom,
left: _left,
right: _right,
type: _padType
};
outHeight = round((inHeight - filterHeight + _top + _bottom) / strideHeight + 1, roundingMode);
outWidth = round((inWidth - filterWidth + _left + _right) / strideWidth + 1, roundingMode);
} else {
throw Error("Unknown padding parameter: " + pad);
}
return {
padInfo: padInfo,
outHeight: outHeight,
outWidth: outWidth
};
}
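/*
 * Editor's sketch of the 'same' branch above: a 4x4 input, stride 2 and an
 * effective 3x3 filter give a 2x2 output, and the single row/column of padding
 * needed is split with the extra pixel on the bottom/right.
 *
 * ```js
 * const { padInfo, outHeight, outWidth } =
 *     getPadAndOutInfo('same', 4, 4, 2, 2, 3, 3, undefined, 'channelsLast');
 * // outHeight === 2, outWidth === 2
 * // padInfo -> {top: 0, bottom: 1, left: 0, right: 1, type: 'SAME'}
 * ```
 */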
function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
var padInfo;
var outDepth;
var outHeight;
var outWidth;
if (typeof pad === 'number') {
var padType = pad === 0 ? 'VALID' : 'NUMBER';
padInfo = {
top: pad,
bottom: pad,
left: pad,
right: pad,
front: pad,
back: pad,
type: padType
};
var outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, roundingMode);
outDepth = outShape[0];
outHeight = outShape[1];
outWidth = outShape[2];
} else if (pad === 'same') {
outDepth = Math.ceil(inDepth / strideDepth);
outHeight = Math.ceil(inHeight / strideHeight);
outWidth = Math.ceil(inWidth / strideWidth);
var padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;
var padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight;
var padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
var front = Math.floor(padAlongDepth / 2);
var back = padAlongDepth - front;
var top = Math.floor(padAlongHeight / 2);
var bottom = padAlongHeight - top;
var left = Math.floor(padAlongWidth / 2);
var right = padAlongWidth - left;
padInfo = {
top: top,
bottom: bottom,
left: left,
right: right,
front: front,
back: back,
type: 'SAME'
};
} else if (pad === 'valid') {
padInfo = {
top: 0,
bottom: 0,
left: 0,
right: 0,
front: 0,
back: 0,
type: 'VALID'
};
outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth);
outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
} else {
throw Error("Unknown padding parameter: " + pad);
}
return {
padInfo: padInfo,
outDepth: outDepth,
outHeight: outHeight,
outWidth: outWidth
};
}
/**
* Rounds a value depending on the rounding mode
* @param value
* @param roundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*/
function round(value, roundingMode) {
if (!roundingMode) {
return Math.trunc(value);
}
switch (roundingMode) {
case 'round':
// used for Caffe Conv
return Math.round(value);
case 'ceil':
// used for Caffe Pool
return Math.ceil(value);
case 'floor':
return Math.floor(value);
default:
throw new Error("Unknown roundingMode " + roundingMode);
}
}
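/*
 * Editor's sketch of the rounding modes above, using the value 3.5:
 *
 * ```js
 * round(3.5);           // 3  (no mode: truncate)
 * round(3.5, 'round');  // 4
 * round(3.5, 'ceil');   // 4
 * round(3.5, 'floor');  // 3
 * ```
 */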
function tupleValuesAreOne(param) {
var _parseTupleParam4 = parseTupleParam(param),
dimA = _parseTupleParam4[0],
dimB = _parseTupleParam4[1],
dimC = _parseTupleParam4[2];
return dimA === 1 && dimB === 1 && dimC === 1;
}
function eitherStridesOrDilationsAreOne(strides, dilations) {
return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
}
/**
* Convert Conv2D dataFormat from 'NHWC'|'NCHW' to
* 'channelsLast'|'channelsFirst'
* @param dataFormat in 'NHWC'|'NCHW' mode
* @return dataFormat in 'channelsLast'|'channelsFirst' mode
* @throws unknown dataFormat
*/
function convertConv2DDataFormat(dataFormat) {
if (dataFormat === 'NHWC') {
return 'channelsLast';
} else if (dataFormat === 'NCHW') {
return 'channelsFirst';
} else {
throw new Error("Unknown dataFormat " + dataFormat);
}
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Reshapes a `tf.Tensor` to a given shape.
*
* Given an input tensor, returns a new tensor with the same values as the
* input tensor with shape `shape`.
*
* If one component of shape is the special value -1, the size of that
* dimension is computed so that the total size remains constant. In
* particular, a shape of [-1] flattens into 1-D. At most one component of
* shape can be -1.
*
 * If shape is 1-D or higher, then the operation returns a tensor with shape
 * `shape` filled with the values of tensor. In this case, the number of
* elements implied by shape must be the same as the number of elements in
* tensor.
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
* x.reshape([2, 2]).print();
* ```
*
* @param x The input tensor to be reshaped.
* @param shape An array of integers defining the output tensor shape.
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function reshape_(x, shape) {
var $x = convertToTensor(x, 'x', 'reshape', 'string_or_numeric');
var inputs = {
x: $x
};
var attrs = {
shape: shape
};
return ENGINE.runKernel(Reshape, inputs, attrs);
}
var reshape = op({
reshape_: reshape_
});
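/*
 * Editor's note: a small sketch of the -1 inference described above, assuming
 * the bundle's global `tf` namespace as in the other examples in this file.
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
 * x.reshape([-1]).print();     // flattened to shape [6]
 * x.reshape([3, -1]).print();  // the -1 is inferred as 2, giving shape [3, 2]
 * ```
 */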
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the 2D average pooling of an image.
*
* @param x The input tensor, of rank 4 or rank 3 of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
* @param filterSize The filter size: `[filterHeight, filterWidth]`. If
* `filterSize` is a single number, then `filterHeight == filterWidth`.
* @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
* `strides` is a single number, then `strideHeight == strideWidth`.
* @param pad The type of padding algorithm:
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*/
function avgPool_(x, filterSize, strides, pad, dimRoundingMode) {
var $x = convertToTensor(x, 'x', 'avgPool', 'float32');
var dilations = 1;
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in avgPool: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(x4D.rank === 4, function () {
return "Error in avgPool: x must be rank 4 but got rank " + x4D.rank + ".";
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in avgPool: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
x: x4D
};
var attrs = {
filterSize: filterSize,
strides: strides,
pad: pad,
dimRoundingMode: dimRoundingMode
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(AvgPool, inputs, attrs);
res = cast(res, $x.dtype);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var avgPool = op({
avgPool_: avgPool_
});
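/*
 * Editor's note: the avgPool doc above has no usage snippet, so here is a
 * minimal sketch (assuming the global `tf` namespace of this bundle).
 *
 * ```js
 * const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3, 1]);
 * tf.avgPool(x, 2, 1, 'valid').print();  // 2x2x1 window averages: 3, 4, 6, 7
 * ```
 */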
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the 3D average pooling.
*
* ```js
* const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
* const result = tf.avgPool3d(x, 2, 1, 'valid');
* result.print();
* ```
*
 * @param x The input tensor, of rank 5 or rank 4 of shape
 * `[batch, depth, height, width, inChannels]`. If rank 4, batch of 1 is assumed.
* @param filterSize The filter size:
* `[filterDepth, filterHeight, filterWidth]`.
* If `filterSize` is a single number,
* then `filterDepth == filterHeight == filterWidth`.
* @param strides The strides of the pooling:
* `[strideDepth, strideHeight, strideWidth]`.
* If `strides` is a single number,
* then `strideDepth == strideHeight == strideWidth`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
 * than 1x1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
* @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
* "NDHWC". Specify the data format of the input and output data. With the
* default format "NDHWC", the data is stored in the order of: [batch,
* depth, height, width, channels]. Only "NDHWC" is currently supported.
*
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function avgPool3d_(x, filterSize, strides, pad, dimRoundingMode, dataFormat) {
if (dataFormat === void 0) {
dataFormat = 'NDHWC';
}
var $x = convertToTensor(x, 'x', 'avgPool3d', 'float32');
var x5D = $x;
var reshapedTo5D = false;
if ($x.rank === 4) {
reshapedTo5D = true;
x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
}
assert(x5D.rank === 5, function () {
return "Error in avgPool3d: x must be rank 5 but got rank " + x5D.rank + ".";
});
assert(dataFormat === 'NDHWC', function () {
return "Error in avgPool3d: Only NDHWC is currently supported, " + ("but got dataFormat of " + dataFormat);
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in avgPool3d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
x: x5D
};
var attrs = {
filterSize: filterSize,
strides: strides,
pad: pad,
dimRoundingMode: dimRoundingMode,
dataFormat: dataFormat
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(AvgPool3D, inputs, attrs);
res = cast(res, x5D.dtype);
if (reshapedTo5D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
return res;
}
var avgPool3d = op({
avgPool3d_: avgPool3d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Concatenates a list of `tf.Tensor`s along a given axis.
*
 * The tensors' ranks and types must match, and their sizes must match in all
* dimensions except `axis`.
*
* Also available are stricter rank-specific methods that assert that
* `tensors` are of the given rank:
* - `tf.concat1d`
* - `tf.concat2d`
* - `tf.concat3d`
* - `tf.concat4d`
*
* Except `tf.concat1d` (which does not have axis param), all methods have
* same signature as this method.
*
* ```js
* const a = tf.tensor1d([1, 2]);
* const b = tf.tensor1d([3, 4]);
* a.concat(b).print(); // or a.concat(b)
* ```
*
* ```js
* const a = tf.tensor1d([1, 2]);
* const b = tf.tensor1d([3, 4]);
* const c = tf.tensor1d([5, 6]);
* tf.concat([a, b, c]).print();
* ```
*
* ```js
* const a = tf.tensor2d([[1, 2], [10, 20]]);
* const b = tf.tensor2d([[3, 4], [30, 40]]);
* const axis = 1;
* tf.concat([a, b], axis).print();
* ```
* @param tensors A list of tensors to concatenate.
 * @param axis The axis to concatenate along. Defaults to 0 (the first dim).
*
* @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
*/
function concat_(tensors, axis) {
if (axis === void 0) {
axis = 0;
}
assert(tensors.length >= 1, function () {
return 'Pass at least one tensor to concat';
});
var $tensors = convertToTensorArray(tensors, 'tensors', 'concat', 'string_or_numeric');
if ($tensors[0].dtype === 'complex64') {
$tensors.forEach(function (tensor) {
if (tensor.dtype !== 'complex64') {
throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype " + tensor.dtype + ". ");
}
});
}
if ($tensors.length === 1) {
return clone($tensors[0]);
}
var inputs = $tensors;
var attr = {
axis: axis
};
return ENGINE.runKernel(Concat, inputs, attr);
}
var concat = op({
concat_: concat_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes sigmoid element-wise, `1 / (1 + exp(-x))`
*
* ```js
* const x = tf.tensor1d([0, -1, 2, -3]);
*
* x.sigmoid().print(); // or tf.sigmoid(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function sigmoid_(x) {
var $x = convertToTensor(x, 'x', 'sigmoid');
var inputs = {
x: $x
};
return ENGINE.runKernel(Sigmoid, inputs);
}
var sigmoid = op({
sigmoid_: sigmoid_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Extracts a slice from a `tf.Tensor` starting at coordinates `begin`
* and is of size `size`.
*
* Also available are stricter rank-specific methods with the same signature
* as this method that assert that `x` is of the given rank:
* - `tf.slice1d`
* - `tf.slice2d`
* - `tf.slice3d`
* - `tf.slice4d`
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
*
* x.slice([1], [2]).print();
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* x.slice([1, 0], [1, 2]).print();
* ```
* @param x The input `tf.Tensor` to slice from.
* @param begin The coordinates to start the slice from. The length can be
* less than the rank of x - the rest of the axes will have implicit 0 as
* start. Can also be a single number, in which case it specifies the
* first axis.
* @param size The size of the slice. The length can be less than the rank of
* x - the rest of the axes will have implicit -1. A value of -1 requests
* the rest of the dimensions in the axis. Can also be a single number,
* in which case it specifies the size of the first axis.
*
* @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
*/
function slice_(x, begin, size) {
var $x = convertToTensor(x, 'x', 'slice', 'string_or_numeric');
if ($x.rank === 0) {
throw new Error('Slicing scalar is not possible');
}
var inputs = {
x: $x
};
var attrs = {
begin: begin,
size: size
};
return ENGINE.runKernel(Slice, inputs, attrs);
}
var slice$2 = op({
slice_: slice_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)`
*
* ```js
* const x = tf.tensor1d([0, 1, -1, 70]);
*
* x.tanh().print(); // or tf.tanh(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function tanh_(x) {
var $x = convertToTensor(x, 'x', 'tanh');
var inputs = {
x: $x
};
return ENGINE.runKernel(Tanh, inputs);
}
var tanh$1 = op({
tanh_: tanh_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the next state and output of a BasicLSTMCell.
*
* Returns `[newC, newH]`.
*
* Derived from tf.contrib.rnn.BasicLSTMCell.
*
* @param forgetBias Forget bias for the cell.
* @param lstmKernel The weights for the cell.
* @param lstmBias The bias for the cell.
* @param data The input to the cell.
* @param c Previous cell state.
* @param h Previous cell output.
*
* @doc {heading: 'Operations', subheading: 'RNN'}
*/
function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) {
var $forgetBias = convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell');
var $lstmKernel = convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell');
var $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell');
var $data = convertToTensor(data, 'data', 'basicLSTMCell');
var $c = convertToTensor(c, 'c', 'basicLSTMCell');
var $h = convertToTensor(h, 'h', 'basicLSTMCell');
var combined = concat([$data, $h], 1);
var weighted = matMul(combined, $lstmKernel);
var res = add$1(weighted, $lstmBias); // i = input_gate, j = new_input, f = forget_gate, o = output_gate
var batchSize = res.shape[0];
var sliceCols = res.shape[1] / 4;
var sliceSize = [batchSize, sliceCols];
var i = slice$2(res, [0, 0], sliceSize);
var j = slice$2(res, [0, sliceCols], sliceSize);
var f = slice$2(res, [0, sliceCols * 2], sliceSize);
var o = slice$2(res, [0, sliceCols * 3], sliceSize);
var newC = add$1(mul(sigmoid(i), tanh$1(j)), mul($c, sigmoid(add$1($forgetBias, f))));
var newH = mul(tanh$1(newC), sigmoid(o));
return [newC, newH];
}
var basicLSTMCell = op({
basicLSTMCell_: basicLSTMCell_
});
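/*
 * Editor's note: a minimal sketch of one basicLSTMCell step, assuming the
 * global `tf` namespace. Shapes follow the slicing above: with dataSize = 2 and
 * hiddenSize = 1, the kernel is [dataSize + hiddenSize, 4 * hiddenSize].
 *
 * ```js
 * const forgetBias = tf.scalar(1.0);
 * const lstmKernel = tf.randomNormal([3, 4]);
 * const lstmBias = tf.zeros([4]);
 * const data = tf.randomNormal([1, 2]);  // [batch, dataSize]
 * const c = tf.zeros([1, 1]);            // [batch, hiddenSize]
 * const h = tf.zeros([1, 1]);
 * const [newC, newH] = tf.basicLSTMCell(forgetBias, lstmKernel, lstmBias, data, c, h);
 * newH.print();
 * ```
 */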
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
* shape `blockShape + [batch]`, interleaves these blocks back into the grid
* defined by the spatial dimensions `[1, ..., M]`, to obtain a result with
* the same rank as the input. The spatial dimensions of this intermediate
* result are then optionally cropped according to `crops` to produce the
* output. This is the reverse of `tf.spaceToBatchND`. See below for a precise
* description.
*
* ```js
* const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]);
* const blockShape = [2, 2];
* const crops = [[0, 0], [0, 0]];
*
* x.batchToSpaceND(blockShape, crops).print();
* ```
*
* @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
* remainingShape`, where spatialShape has `M` dimensions.
* @param blockShape A 1-D array. Must have shape `[M]`, all values must
* be >= 1.
* @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0.
* `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input
* dimension `i + 1`, which corresponds to spatial dimension `i`. It is required
* that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]`
*
* This operation is equivalent to the following steps:
*
* 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ...,
* blockShape[M-1], batch / prod(blockShape), x.shape[1], ...,
* x.shape[N-1]]`
*
 * 2. Permute dimensions of `reshaped` to produce `permuted` of shape `[batch /
 * prod(blockShape), x.shape[1], blockShape[0], ..., x.shape[M],
 * blockShape[M-1], x.shape[M+1], ..., x.shape[N-1]]`
 *
 * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch /
 * prod(blockShape), x.shape[1] * blockShape[0], ..., x.shape[M] *
 * blockShape[M-1], x.shape[M+1], ..., x.shape[N-1]]`
 *
 * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted`
 * according to `crops` to produce the output of shape: `[batch /
 * prod(blockShape), x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1],
 * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] -
 * crops[M-1,1], x.shape[M+1], ..., x.shape[N-1]]`
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function batchToSpaceND_(x, blockShape, crops) {
var $x = convertToTensor(x, 'x', 'batchToSpaceND');
var prod = blockShape.reduce(function (a, b) {
return a * b;
});
assert($x.rank >= 1 + blockShape.length, function () {
return "input rank is " + $x.rank + " but should be > than blockShape.length " + blockShape.length;
});
assert(crops.length === blockShape.length, function () {
return "crops.length is " + crops.length + " but should be equal to blockShape.length " + blockShape.length;
});
assert($x.shape[0] % prod === 0, function () {
return "input tensor batch is " + $x.shape[0] + " but is not divisible by the product of " + ("the elements of blockShape " + blockShape.join(' * ') + " === " + prod);
});
var inputs = {
x: $x
};
var attrs = {
blockShape: blockShape,
crops: crops
};
return ENGINE.runKernel(BatchToSpaceND, inputs, attrs);
}
var batchToSpaceND = op({
batchToSpaceND_: batchToSpaceND_
});
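/**
 * Pads the rank of `x` up to 4 by prepending singleton dimensions:
 * rank 0/1 -> [1, 1, 1, size], rank 2 -> [1, 1, h, w], rank 3 -> [1, d0, d1, d2].
 * Used below so batchNorm can always run on a canonical 4D input.
 */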
function xAs4D(x) {
var x4D;
if (x.rank === 0 || x.rank === 1) {
x4D = reshape(x, [1, 1, 1, x.size]);
} else if (x.rank === 2) {
x4D = reshape(x, [1, 1, x.shape[0], x.shape[1]]);
} else if (x.rank === 3) {
x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
} else {
x4D = x;
}
return x4D;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Batch normalization.
*
* As described in
* [http://arxiv.org/abs/1502.03167](http://arxiv.org/abs/1502.03167).
*
* Mean, variance, scale, and offset can be of two shapes:
* - The same shape as the input.
* - In the common case, the depth dimension is the last dimension of x, so
 * the values would be a `tf.Tensor1D` of shape [depth].
*
* Also available are stricter rank-specific methods with the same signature
* as this method that assert that parameters passed are of given rank
* - `tf.batchNorm2d`
* - `tf.batchNorm3d`
* - `tf.batchNorm4d`
*
* @param x The input Tensor.
* @param mean A mean Tensor.
* @param variance A variance Tensor.
* @param offset An offset Tensor.
* @param scale A scale Tensor.
* @param varianceEpsilon A small float number to avoid dividing by 0.
*
* @doc {heading: 'Operations', subheading: 'Normalization'}
*/
function batchNorm_(x, mean, variance, offset, scale, varianceEpsilon) {
if (varianceEpsilon == null) {
varianceEpsilon = 0.001;
}
var $x = convertToTensor(x, 'x', 'batchNorm');
var $mean = convertToTensor(mean, 'mean', 'batchNorm');
var $variance = convertToTensor(variance, 'variance', 'batchNorm');
var $scale;
if (scale != null) {
$scale = convertToTensor(scale, 'scale', 'batchNorm');
}
var $offset;
if (offset != null) {
$offset = convertToTensor(offset, 'offset', 'batchNorm');
}
assert($mean.rank === $variance.rank, function () {
return 'Batch normalization gradient requires mean and variance to have ' + 'equal ranks.';
});
assert($offset == null || $mean.rank === $offset.rank, function () {
return 'Batch normalization gradient requires mean and offset to have ' + 'equal ranks.';
});
assert($scale == null || $mean.rank === $scale.rank, function () {
return 'Batch normalization gradient requires mean and scale to have ' + 'equal ranks.';
});
var x4D = xAs4D($x);
var inputs = {
x: x4D,
scale: $scale,
offset: $offset,
mean: $mean,
variance: $variance
};
var attrs = {
varianceEpsilon: varianceEpsilon
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(FusedBatchNorm, inputs, attrs);
return reshape(res, $x.shape);
}
var batchNorm = op({
batchNorm_: batchNorm_
});
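/*
 * Editor's note: a minimal batchNorm sketch (assuming the global `tf`
 * namespace), using the common case where mean and variance are 1D over the
 * last (depth) dimension.
 *
 * ```js
 * const x = tf.tensor2d([[2, 4, 6], [1, 3, 5]]);
 * const mean = tf.tensor1d([1, 2, 3]);
 * const variance = tf.tensor1d([1, 1, 1]);
 * tf.batchNorm(x, mean, variance).print();  // ~ (x - mean) / sqrt(variance + 0.001)
 * ```
 */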
/**
* Batch normalization, strictly for 2D. For the more relaxed version, see
* `tf.batchNorm`.
*
* @param x The input Tensor.
* @param mean A mean Tensor.
* @param variance A variance Tensor.
* @param offset An offset Tensor.
* @param scale A scale Tensor.
* @param varianceEpsilon A small float number to avoid dividing by 0.
*/
function batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon) {
var $x = convertToTensor(x, 'x', 'batchNorm');
var $mean = convertToTensor(mean, 'mean', 'batchNorm');
var $variance = convertToTensor(variance, 'variance', 'batchNorm');
var $scale;
if (scale != null) {
$scale = convertToTensor(scale, 'scale', 'batchNorm');
}
var $offset;
if (offset != null) {
$offset = convertToTensor(offset, 'offset', 'batchNorm');
}
assert($x.rank === 2, function () {
return "Error in batchNorm2D: x must be rank 2 but got rank " + ($x.rank + ".");
});
assert($mean.rank === 2 || $mean.rank === 1, function () {
return "Error in batchNorm2D: mean must be rank 2 or rank 1 but " + ("got rank " + $mean.rank + ".");
});
assert($variance.rank === 2 || $variance.rank === 1, function () {
return "Error in batchNorm2D: variance must be rank 2 or rank 1 " + ("but got rank " + $variance.rank + ".");
});
if ($scale != null) {
assert($scale.rank === 2 || $scale.rank === 1, function () {
return "Error in batchNorm2D: scale must be rank 2 or rank 1 " + ("but got rank " + $scale.rank + ".");
});
}
if ($offset != null) {
assert($offset.rank === 2 || $offset.rank === 1, function () {
return "Error in batchNorm2D: offset must be rank 2 or rank 1 " + ("but got rank " + $offset.rank + ".");
});
}
return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
var batchNorm2d = op({
batchNorm2d_: batchNorm2d_
});
/**
* Batch normalization, strictly for 3D. For the more relaxed version, see
* `tf.batchNorm`.
*
* @param x The input Tensor.
* @param mean A mean Tensor.
* @param variance A variance Tensor.
* @param offset An offset Tensor.
* @param scale A scale Tensor.
* @param varianceEpsilon A small float number to avoid dividing by 0.
*/
function batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon) {
var $x = convertToTensor(x, 'x', 'batchNorm');
var $mean = convertToTensor(mean, 'mean', 'batchNorm');
var $variance = convertToTensor(variance, 'variance', 'batchNorm');
var $scale;
if (scale != null) {
$scale = convertToTensor(scale, 'scale', 'batchNorm');
}
var $offset;
if (offset != null) {
$offset = convertToTensor(offset, 'offset', 'batchNorm');
}
assert($x.rank === 3, function () {
return "Error in batchNorm3D: x must be rank 3 but got rank " + ($x.rank + ".");
});
assert($mean.rank === 3 || $mean.rank === 1, function () {
return "Error in batchNorm3D: mean must be rank 3 or rank 1 but " + ("got rank " + $mean.rank + ".");
});
assert($variance.rank === 3 || $variance.rank === 1, function () {
return "Error in batchNorm3D: variance must be rank 3 or rank 1 " + ("but got rank " + $variance.rank + ".");
});
if ($scale != null) {
assert($scale.rank === 3 || $scale.rank === 1, function () {
return "Error in batchNorm3D: scale must be rank 3 or rank 1 " + ("but got rank " + $scale.rank + ".");
});
}
if ($offset != null) {
assert($offset.rank === 3 || $offset.rank === 1, function () {
return "Error in batchNorm3D: offset must be rank 3 or rank 1 " + ("but got rank " + $offset.rank + ".");
});
}
return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
var batchNorm3d = op({
batchNorm3d_: batchNorm3d_
});
/**
* Batch normalization, strictly for 4D. For the more relaxed version, see
* `tf.batchNorm`.
*
* @param x The input Tensor.
* @param mean A mean Tensor.
* @param variance A variance Tensor.
* @param offset An offset Tensor.
* @param scale A scale Tensor.
* @param varianceEpsilon A small float number to avoid dividing by 0.
*/
function batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon) {
var $x = convertToTensor(x, 'x', 'batchNorm');
var $mean = convertToTensor(mean, 'mean', 'batchNorm');
var $variance = convertToTensor(variance, 'variance', 'batchNorm');
var $scale;
if (scale != null) {
$scale = convertToTensor(scale, 'scale', 'batchNorm');
}
var $offset;
if (offset != null) {
$offset = convertToTensor(offset, 'offset', 'batchNorm');
}
assert($x.rank === 4, function () {
return "Error in batchNorm4D: x must be rank 4 but got rank " + ($x.rank + ".");
});
assert($mean.rank === 4 || $mean.rank === 1, function () {
return "Error in batchNorm4D: mean must be rank 4 or rank 1 but " + ("got rank " + $mean.rank + ".");
});
assert($variance.rank === 4 || $variance.rank === 1, function () {
return "Error in batchNorm4D: variance must be rank 4 or rank 1 " + ("but got rank " + $variance.rank + ".");
});
if ($scale != null) {
assert($scale.rank === 4 || $scale.rank === 1, function () {
return "Error in batchNorm4D: scale must be rank 4 or rank 1 " + ("but got rank " + $scale.rank + ".");
});
}
if ($offset != null) {
assert($offset.rank === 4 || $offset.rank === 1, function () {
return "Error in batchNorm4D: offset must be rank 4 or rank 1 " + ("but got rank " + $offset.rank + ".");
});
}
return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}
var batchNorm4d = op({
batchNorm4d_: batchNorm4d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Outputs a vector with length `size` and the same dtype as `weights`.
*
* If `weights` are empty, then index `i` stores the number of times the value
* `i` is counted in `x`. If `weights` are non-empty, then index `i` stores the
* sum of the value in `weights` at each index where the corresponding value in
* `x` is `i`.
*
* Values in `x` outside of the range [0, size) are ignored.
*
* @param x The input int tensor, rank 1.
* @param weights The weights tensor, must have the same shape as x, or a
* length-0 Tensor, in which case it acts as all weights equal to 1.
* @param size Non-negative integer.
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function bincount_(x, weights, size) {
var $x = convertToTensor(x, 'x', 'bincount');
var $weights = convertToTensor(weights, 'weights', 'bincount');
assert($x.dtype === 'int32', function () {
return "Error in bincount: input " + ("dtype must be int32, but got " + $x.dtype);
});
assert(size >= 0, function () {
return "size must be non-negative, but got " + size + ".";
});
assert($weights.size === $x.size || $weights.size === 0, function () {
return "Error in bincount: weights must have the same size as input or" + ("0-length, but got input shape: " + $x.shape + ", weights shape: ") + ($weights.shape + ".");
});
var inputs = {
x: $x,
weights: $weights
};
var attrs = {
size: size
};
return ENGINE.runKernel(Bincount, inputs, attrs);
}
var bincount = op({
bincount_: bincount_
});
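/*
 * Editor's note: a minimal bincount sketch (assuming the global `tf`
 * namespace). With empty weights, slot i simply counts occurrences of i in x.
 *
 * ```js
 * const x = tf.tensor1d([1, 1, 2, 3, 3, 3], 'int32');
 * const weights = tf.tensor1d([]);
 * tf.bincount(x, weights, 4).print();  // [0, 2, 1, 3]
 * ```
 */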
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Return the shape of s0 op s1 with broadcast.
*
 * Computes r0, the broadcast shape, as a tensor.
 * s0, s1, and r0 are all integer vectors.
*
* This function returns the shape of the result of an operation between
* two tensors of size s0 and s1 performed with broadcast.
*
* @param s0 A tensor representing a shape
* @param s1 A tensor representing a shape
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function broadcastArgs_(s0, s1) {
var shape1Input = convertToTensor(s0, 's0', 'broadcastArgs', 'int32');
var shape2Input = convertToTensor(s1, 's1', 'broadcastArgs', 'int32');
if (shape1Input.rank !== 1) {
throw new Error('broadcastArgs(): first input must be a vector (rank=1). ' + ("Has rank " + shape1Input.rank));
}
if (shape2Input.rank !== 1) {
throw new Error('broadcastArgs(): second input must be a vector (rank=1). ' + ("Has rank " + shape2Input.rank));
}
var inputs = {
s0: shape1Input,
s1: shape2Input
};
return ENGINE.runKernel(BroadcastArgs, inputs);
}
var broadcastArgs = op({
broadcastArgs_: broadcastArgs_
});
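/*
 * Editor's note: a minimal broadcastArgs sketch (assuming the global `tf`
 * namespace); it returns the NumPy-style broadcast of the two shape vectors.
 *
 * ```js
 * const s0 = tf.tensor1d([2, 1, 3], 'int32');
 * const s1 = tf.tensor1d([1, 4, 3], 'int32');
 * tf.broadcastArgs(s0, s1).print();  // [2, 4, 3]
 * ```
 */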
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Broadcast an array to a compatible shape NumPy-style.
*
* The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until it has the same length as
* the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
* already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
* the input tensor is tiled N times along that axis (using tf.tile).
*
* @param input The tensor that is to be broadcasted.
* @param shape The input is to be broadcast to this shape.
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function broadcastTo_(x, shape) {
  var input = convertToTensor(x, 'x', 'broadcastTo');
var xShape = input.shape;
if (shape.some(function (d) {
return !(d > 0) || d % 1 !== 0;
})) {
throw new Error("broadcastTo(): Invalid broadcast shape [" + shape + "].");
}
if (shape.length < input.rank) {
throw new Error("broadcastTo(): shape.length=" + shape.length + " < input.rank=" + input.rank + ".");
}
if (shape.length > input.rank) {
var newShape = input.shape.slice();
while (newShape.length < shape.length) {
newShape.unshift(1);
}
input = reshape(input, newShape);
}
var inputShape = input.shape;
var reps = Array.from(shape);
for (var i = shape.length - 1; i >= 0; i--) {
if (inputShape[i] === shape[i]) {
reps[i] = 1;
} else if (input.shape[i] !== 1) {
throw new Error("broadcastTo(): [" + xShape + "] cannot be broadcast to [" + shape + "].");
}
}
var axes = reps.map(function (n, i) {
return n > 1 ? i : -1;
}).filter(function (i) {
return i >= 0;
});
if (axes.length === 0) {
return clone(input);
  } // TODO call broadcastTo kernel directly once backends implement broadcastTo
var inputs = {
x: input
};
var attrs = {
reps: reps
};
return ENGINE.runKernel(Tile, inputs, attrs);
}
var broadcastTo = op({
broadcastTo_: broadcastTo_
});
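/*
 * Editor's note: a minimal broadcastTo sketch (assuming the global `tf`
 * namespace); the length-1 leading axis is prepended and then tiled.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3]);
 * tf.broadcastTo(x, [2, 3]).print();  // [[1, 2, 3], [1, 2, 3]]
 * ```
 */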
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)`
*
* ```js
* const x = tf.tensor1d([.6, 1.1, -3.3]);
*
* x.ceil().print(); // or tf.ceil(x)
* ```
* @param x The input Tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function ceil_(x) {
var $x = convertToTensor(x, 'x', 'ceil');
var inputs = {
x: $x
};
return ENGINE.runKernel(Ceil, inputs);
}
var ceil$3 = op({
ceil_: ceil_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)`
*
* ```js
* const x = tf.tensor1d([-1, 2, -3, 4]);
*
* x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3)
* ```
* @param x The input tensor.
* @param clipValueMin Lower-bound of range to be clipped to.
* @param clipValueMax Upper-bound of range to be clipped to.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function clipByValue_(x, clipValueMin, clipValueMax) {
var $x = convertToTensor(x, 'x', 'clipByValue');
assert(clipValueMin <= clipValueMax, function () {
return "Error in clip: min (" + clipValueMin + ") must be " + ("less than or equal to max (" + clipValueMax + ").");
});
var inputs = {
x: $x
};
var attrs = {
clipValueMin: clipValueMin,
clipValueMax: clipValueMax
};
return ENGINE.runKernel(ClipByValue, inputs, attrs);
}
var clipByValue = op({
clipByValue_: clipByValue_
});
/**
 * Concatenates a list of `tf.Tensor1D`s along an axis. See `concat` for details.
*
* For example, if:
* A: shape(3) = |r1, g1, b1|
* B: shape(2) = |r2, g2|
* C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2|
*
 * @param tensors A list of `tf.Tensor`s to concatenate.
* @return The concatenated array.
*/
function concat1d_(tensors) {
return concat(tensors, 0
/* axis */
);
}
var concat1d = op({
concat1d_: concat1d_
});
/**
 * Concatenates a list of `tf.Tensor2D`s along an axis. See `concat` for details.
*
* For example, if:
* A: shape(2, 3) = | r1, g1, b1 |
* | r2, g2, b2 |
*
* B: shape(2, 3) = | r3, g3, b3 |
* | r4, g4, b4 |
*
* C = tf.concat2d([A, B], axis)
*
* if axis = 0:
* C: shape(4, 3) = | r1, g1, b1 |
* | r2, g2, b2 |
* | r3, g3, b3 |
* | r4, g4, b4 |
*
* if axis = 1:
* C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 |
* | r2, g2, b2, r4, g4, b4 |
*
*
* @param tensors A list of `tf.Tensor`s to concatenate.
* @param axis The axis to concatenate along.
* @return The concatenated array.
*/
function concat2d_(tensors, axis) {
return concat(tensors, axis);
}
var concat2d = op({
concat2d_: concat2d_
});
/**
* Concatenates a list of `tf.Tensor3D`s along an axis.
* See `concat` for details.
*
* For example, if:
* A: shape(2, 1, 3) = | r1, g1, b1 |
* | r2, g2, b2 |
*
* B: shape(2, 1, 3) = | r3, g3, b3 |
* | r4, g4, b4 |
*
* C = tf.concat3d([A, B], axis)
*
* if axis = 0:
* C: shape(4, 1, 3) = | r1, g1, b1 |
* | r2, g2, b2 |
* | r3, g3, b3 |
* | r4, g4, b4 |
*
* if axis = 1:
* C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 |
* | r2, g2, b2, r4, g4, b4 |
*
* if axis = 2:
* C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 |
* | r2, g2, b2, r4, g4, b4 |
*
 * @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
* @return The concatenated array.
*/
function concat3d_(tensors, axis) {
return concat(tensors, axis);
}
var concat3d = op({
concat3d_: concat3d_
});
/**
* Concatenates a list of `tf.Tensor4D`s along an axis.
* See `concat` for details.
*
* @param tensors A list of `tf.Tensor`s to concatenate.
 * @param axis The axis to concatenate along.
* @return The concatenated array.
*/
function concat4d_(tensors, axis) {
return concat(tensors, axis);
}
var concat4d = op({
concat4d_: concat4d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes a 2D convolution over the input x.
*
* @param x The input tensor, of rank 4 or rank 3, of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
* assumed.
* @param filter The filter, rank 4, of shape
* `[filterHeight, filterWidth, inDepth, outDepth]`.
* @param strides The strides of the convolution: `[strideHeight,
* strideWidth]`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
* "NHWC". Specify the data format of the input and output data. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels].
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
* in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
* number, then `dilationHeight == dilationWidth`. If it is greater than
* 1, then all values of `strides` must be 1.
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function conv2d_(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
if (dataFormat === void 0) {
dataFormat = 'NHWC';
}
if (dilations === void 0) {
dilations = [1, 1];
}
var $x = convertToTensor(x, 'x', 'conv2d');
var $filter = convertToTensor(filter, 'filter', 'conv2d');
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(x4D.rank === 4, function () {
return "Error in conv2d: input must be rank 4, but got rank " + x4D.rank + ".";
});
assert($filter.rank === 4, function () {
return "Error in conv2d: filter must be rank 4, but got rank " + ($filter.rank + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in conv2d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
assert(inDepth === $filter.shape[2], function () {
return "Error in conv2d: depth of input (" + inDepth + ") must match " + ("input depth for filter " + $filter.shape[2] + ".");
});
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in conv2D: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var inputs = {
x: x4D,
filter: $filter
};
var attrs = {
strides: strides,
pad: pad,
dataFormat: dataFormat,
dilations: dilations,
dimRoundingMode: dimRoundingMode
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(Conv2D, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var conv2d = op({
conv2d_: conv2d_
});
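/*
 * Editor's note: a minimal conv2d sketch (assuming the global `tf` namespace).
 * A 4x4 all-ones image convolved with a 2x2 all-ones kernel, stride 1, 'valid'
 * padding, yields a 3x3 map of window sums.
 *
 * ```js
 * const x = tf.ones([1, 4, 4, 1]);       // [batch, height, width, inChannels]
 * const filter = tf.ones([2, 2, 1, 1]);  // [filterHeight, filterWidth, inDepth, outDepth]
 * tf.conv2d(x, filter, 1, 'valid').print();  // every output value is 4
 * ```
 */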
/**
* Computes a 1D convolution over the input x.
*
* @param x The input tensor, of rank 3 or rank 2, of shape
* `[batch, width, inChannels]`. If rank 2, batch of 1 is assumed.
* @param filter The filter, rank 3, of shape
* `[filterWidth, inDepth, outDepth]`.
* @param stride The number of entries by which the filter is moved right at
* each step.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
* the data is stored in the order of [batch, in_width, in_channels]. Only
* "NWC" is currently supported.
* @param dilation The dilation rate in which we sample input values in
* atrous convolution. Defaults to `1`. If it is greater than 1, then
* stride must be `1`.
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function conv1d_(x, filter, stride, pad, dataFormat, dilation, dimRoundingMode) {
if (dataFormat === void 0) {
dataFormat = 'NWC';
}
if (dilation === void 0) {
dilation = 1;
}
var $x = convertToTensor(x, 'x', 'conv1d');
var $filter = convertToTensor(filter, 'filter', 'conv1d');
var x3D = $x;
var reshapedTo3D = false;
if ($x.rank === 2) {
reshapedTo3D = true;
x3D = reshape($x, [1, $x.shape[0], $x.shape[1]]);
}
assert(x3D.rank === 3, function () {
return "Error in conv1d: input must be rank 3, but got rank " + x3D.rank + ".";
});
assert($filter.rank === 3, function () {
return "Error in conv1d: filter must be rank 3, but got rank " + ($filter.rank + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in conv1d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
assert(x3D.shape[2] === $filter.shape[1], function () {
return "Error in conv1d: depth of input (" + x3D.shape[2] + ") must match " + ("input depth for filter " + $filter.shape[1] + ".");
});
assert(eitherStridesOrDilationsAreOne(stride, dilation), function () {
return 'Error in conv1D: Either stride or dilation must be 1. ' + ("Got stride " + stride + " and dilation '" + dilation + "'");
});
assert(dataFormat === 'NWC', function () {
return "Error in conv1d: got dataFormat of " + dataFormat + " but only NWC is currently supported.";
});
var filter4D = reshape($filter, [1, $filter.shape[0], $filter.shape[1], $filter.shape[2]]);
var input4D = reshape(x3D, [x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]]);
var strides = [1, stride];
var dilations = [1, dilation];
var conv2dDataFormat = 'NHWC';
var res = conv2d(input4D, filter4D, strides, pad, conv2dDataFormat, dilations, dimRoundingMode);
if (reshapedTo3D) {
return reshape(res, [res.shape[2], res.shape[3]]);
}
return reshape(res, [res.shape[0], res.shape[2], res.shape[3]]);
}
var conv1d = op({
conv1d_: conv1d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the derivative of the input of a 2D convolution.
*
* @param xShape The shape of the input: [batch, height, width, inDepth].
* If length of 3, batch of 1 is assumed.
* @param dy The derivative of the output, of rank 4 or rank 3 of shape
* `[batch, outHeight, outWidth, outDepth]`. If rank 3, batch of 1 is
* assumed.
* @param filter The filter, rank 4, of shape
* `[filterHeight, filterWidth, inDepth, outDepth]`.
* @param strides The strides of the convolution: `[strideHeight,
* strideWidth]`.
* @param pad The type of padding algorithm used:
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
* "NHWC". Specify the data format of the input and output data. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels].
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
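 *
 * A minimal, illustrative sketch calling the `conv2DBackpropInput` op defined
 * below directly (whether it is re-exported on the public `tf` namespace can
 * vary by build):
 *
 * ```js
 * // dy: [batch, outHeight, outWidth, outDepth] = [1, 2, 2, 1]
 * const dy = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * // filter: [filterHeight, filterWidth, inDepth, outDepth] = [1, 1, 1, 1]
 * const filter = tf.tensor4d([2], [1, 1, 1, 1]);
 *
 * // Gradient w.r.t. an input of shape [1, 2, 2, 1]; with a 1x1 filter of
 * // value 2 this is simply 2 * dy.
 * conv2DBackpropInput([1, 2, 2, 1], dy, filter, [1, 1], 'same').print();
 * ```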
*/
function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat, dimRoundingMode) {
if (dataFormat === void 0) {
dataFormat = 'NHWC';
}
assert(xShape.length === dy.rank, function () {
return "Length of inShape " + ("(" + xShape.length + ") and rank of dy (" + dy.rank + ") must match");
});
var xShape4D = xShape;
var dy4D = dy;
var reshapedTo4D = false;
if (dy.rank === 3) {
reshapedTo4D = true;
dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
xShape4D = [1, xShape[0], xShape[1], xShape[2]];
}
assert(xShape4D.length === 4, function () {
return "Error in conv2dDerInput: inShape must be length 4, but got length " + (xShape4D.length + ".");
});
assert(dy4D.rank === 4, function () {
return "Error in conv2dDerInput: dy must be rank 4, but got " + ("rank " + dy4D.rank);
});
assert(filter.rank === 4, function () {
return "Error in conv2dDerInput: filter must be rank 4, but got " + ("rank " + filter.rank);
});
var inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1];
var outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
assert(inDepth === filter.shape[2], function () {
return "Error in conv2dDerInput: depth of input (" + inDepth + ") must " + ("match input depth for filter " + filter.shape[2] + ".");
});
assert(outDepth === filter.shape[3], function () {
return "Error in conv2dDerInput: depth of output (" + outDepth + ") must " + ("match output depth for filter " + filter.shape[3] + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in conv2dDerInput: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
dy: dy4D,
filter: filter
};
var attrs = {
strides: strides,
pad: pad,
dataFormat: dataFormat,
dimRoundingMode: dimRoundingMode,
inputShape: xShape4D
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(Conv2DBackpropInput, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var conv2DBackpropInput = op({
conv2DBackpropInput_: conv2DBackpropInput_
});
/**
* Computes the transposed 2D convolution of an image, also known as a
* deconvolution.
*
* @param x The input image, of rank 4 or rank 3, of shape
* `[batch, height, width, inDepth]`. If rank 3, batch of 1 is assumed.
* @param filter The filter, rank 4, of shape
* `[filterHeight, filterWidth, outDepth, inDepth]`.
* `inDepth` must match `inDepth` in `x`.
* @param outputShape Output shape, of rank 4 or rank 3:
* `[batch, height, width, outDepth]`. If rank 3, batch of 1 is assumed.
* @param strides The strides of the original convolution:
* `[strideHeight, strideWidth]`.
* @param pad The type of padding algorithm used in the non-transpose version
* of the op.
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*
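 * A minimal usage sketch (illustrative values; upsamples a 2x2 image to 4x4):
 *
 * ```js
 * // x: [batch, height, width, inDepth] = [1, 2, 2, 1]
 * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * // filter: [filterHeight, filterWidth, outDepth, inDepth] = [2, 2, 1, 1]
 * const filter = tf.ones([2, 2, 1, 1]);
 *
 * // With stride 2 and 'valid' padding, each input value is spread over its
 * // own 2x2 block of the 4x4 output.
 * tf.conv2dTranspose(x, filter, [1, 4, 4, 1], 2, 'valid').print();
 * ```
 *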
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function conv2dTranspose_(x, filter, outputShape, strides, pad, dimRoundingMode) {
var $x = convertToTensor(x, 'x', 'conv2dTranspose');
var $filter = convertToTensor(filter, 'filter', 'conv2dTranspose');
return conv2DBackpropInput(outputShape, $x, $filter, strides, pad, 'NHWC', dimRoundingMode);
}
var conv2dTranspose = op({
conv2dTranspose_: conv2dTranspose_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes a 3D convolution over the input x.
*
* @param x The input tensor, of rank 5 or rank 4, of shape
* `[batch, depth, height, width, channels]`. If rank 4,
* batch of 1 is assumed.
* @param filter The filter, rank 5, of shape
* `[filterDepth, filterHeight, filterWidth, inChannels, outChannels]`.
* inChannels must match between input and filter.
* @param strides The strides of the convolution: `[strideDepth, strideHeight,
* strideWidth]`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dataFormat: An optional string from: "NDHWC", "NCDHW". Defaults to
* "NDHWC". Specify the data format of the input and output data. With the
* default format "NDHWC", the data is stored in the order of: [batch,
* depth, height, width, channels]. Only "NDHWC" is currently supported.
* @param dilations The dilation rates: `[dilationDepth, dilationHeight,
* dilationWidth]` in which we sample input values across the height
* and width dimensions in atrous convolution. Defaults to `[1, 1, 1]`.
* If `dilations` is a single number, then
* `dilationDepth == dilationHeight == dilationWidth`. If it is greater
* than 1, then all values of `strides` must be 1.
*
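 * A minimal usage sketch (values chosen purely for illustration):
 *
 * ```js
 * // x: [batch, depth, height, width, channels] = [1, 2, 2, 2, 1]
 * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
 * // filter: [filterDepth, filterHeight, filterWidth, inChannels, outChannels]
 * const filter = tf.ones([2, 2, 2, 1, 1]);
 *
 * // A single 2x2x2 window covers the whole volume, so the output is the sum
 * // of all eight values: 36.
 * tf.conv3d(x, filter, 1, 'valid').print();
 * ```
 *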
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function conv3d_(x, filter, strides, pad, dataFormat, dilations) {
if (dataFormat === void 0) {
dataFormat = 'NDHWC';
}
if (dilations === void 0) {
dilations = [1, 1, 1];
}
var $x = convertToTensor(x, 'x', 'conv3d');
var $filter = convertToTensor(filter, 'filter', 'conv3d');
var x5D = $x;
var reshapedTo5D = false;
if ($x.rank === 4) {
reshapedTo5D = true;
x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
}
assert(x5D.rank === 5, function () {
return "Error in conv3d: input must be rank 5, but got rank " + x5D.rank + ".";
});
assert($filter.rank === 5, function () {
return "Error in conv3d: filter must be rank 5, but got rank " + ($filter.rank + ".");
});
assert(x5D.shape[4] === $filter.shape[3], function () {
return "Error in conv3d: depth of input (" + x5D.shape[4] + ") must match " + ("input depth for filter " + $filter.shape[3] + ".");
});
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in conv3D: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
assert(dataFormat === 'NDHWC', function () {
return "Error in conv3d: got dataFormat of " + dataFormat + " but only NDHWC is currently supported.";
});
var inputs = {
x: x5D,
filter: $filter
};
var attrs = {
strides: strides,
pad: pad,
dataFormat: dataFormat,
dilations: dilations
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(Conv3D, inputs, attrs);
if (reshapedTo5D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
return res;
}
var conv3d = op({
conv3d_: conv3d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the derivative of the input of a 3D convolution.
*
* @param xShape The shape of the input: [batch, depth, height, width,
* in_channels]. If length of 4, batch of 1 is assumed.
* @param dy The derivative of the output, of rank 5 or rank 4 of shape
 *     `[batch, outDepth, outHeight, outWidth, outChannels]`.
* If rank 4, batch of 1 is assumed.
* @param filter The filter, rank 5, of shape
* `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`.
* @param strides The strides of the convolution: `[strideDepth, strideHeight,
* strideWidth]`.
* @param pad The type of padding algorithm used:
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
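 *
 * A minimal, illustrative sketch calling the `conv3DBackpropInput` op defined
 * below directly (whether it is re-exported on the public `tf` namespace can
 * vary by build):
 *
 * ```js
 * // dy: [batch, outDepth, outHeight, outWidth, outChannels] = [1, 1, 2, 2, 1]
 * const dy = tf.tensor5d([1, 2, 3, 4], [1, 1, 2, 2, 1]);
 * // filter: [filterDepth, filterHeight, filterWidth, inDepth, outDepth]
 * const filter = tf.tensor5d([2], [1, 1, 1, 1, 1]);
 *
 * // Gradient w.r.t. an input of shape [1, 1, 2, 2, 1]; with a 1x1x1 filter
 * // of value 2 this is simply 2 * dy.
 * conv3DBackpropInput([1, 1, 2, 2, 1], dy, filter, [1, 1, 1], 'same').print();
 * ```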
*/
function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
assert(xShape.length === dy.rank, function () {
return "Length of inShape " + ("(" + xShape.length + ") and rank of dy (" + dy.rank + ") must match");
});
var xShape5D = xShape;
var dy5D = dy;
var reshapedTo5D = false;
if (dy.rank === 4) {
reshapedTo5D = true;
dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
}
var inDepth = xShape5D[4];
var outDepth = dy5D.shape[4];
assert(xShape5D.length === 5, function () {
return "Error in conv3dDerInput: inShape must be length 5, but got length " + (xShape5D.length + ".");
});
assert(dy5D.rank === 5, function () {
return "Error in conv3dDerInput: dy must be rank 5, but got " + ("rank " + dy5D.rank);
});
assert(filter.rank === 5, function () {
return "Error in conv3dDerInput: filter must be rank 5, but got " + ("rank " + filter.rank);
});
assert(inDepth === filter.shape[3], function () {
return "Error in conv3dDerInput: depth of input (" + inDepth + ") must " + ("match input depth for filter " + filter.shape[3] + ".");
});
assert(outDepth === filter.shape[4], function () {
return "Error in conv3dDerInput: depth of output (" + outDepth + ") must " + ("match output depth for filter " + filter.shape[4] + ".");
});
var inputs = {
dy: dy5D,
filter: filter
};
var attrs = {
pad: pad,
strides: strides,
inputShape: xShape5D
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs);
if (reshapedTo5D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
return res;
}
var conv3DBackpropInput = op({
conv3DBackpropInput_: conv3DBackpropInput_
});
/**
* Computes the transposed 3D convolution of a volume, also known as a
* deconvolution.
*
* @param x The input image, of rank 5 or rank 4, of shape
* `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed.
 * @param filter The filter, rank 5, of shape
 *     `[filterDepth, filterHeight, filterWidth, outDepth, inDepth]`.
 *     `inDepth` must match `inDepth` in `x`.
 * @param outputShape Output shape, of rank 5 or rank 4:
 *     `[batch, depth, height, width, outDepth]`. If rank 4, batch of 1 is
 *     assumed.
* @param strides The strides of the original convolution:
* `[strideDepth, strideHeight, strideWidth]`.
* @param pad The type of padding algorithm used in the non-transpose version
* of the op.
*
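 * A minimal usage sketch (illustrative values; upsamples a 1x2x2 volume):
 *
 * ```js
 * // x: [batch, depth, height, width, inDepth] = [1, 1, 2, 2, 1]
 * const x = tf.tensor5d([1, 2, 3, 4], [1, 1, 2, 2, 1]);
 * // filter: [filterDepth, filterHeight, filterWidth, outDepth, inDepth]
 * const filter = tf.ones([1, 2, 2, 1, 1]);
 *
 * // With strides [1, 2, 2] and 'valid' padding, each input value is spread
 * // over its own 2x2 patch of the 1x4x4 output.
 * tf.conv3dTranspose(x, filter, [1, 1, 4, 4, 1], [1, 2, 2], 'valid').print();
 * ```
 *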
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function conv3dTranspose_(x, filter, outputShape, strides, pad) {
var $x = convertToTensor(x, 'x', 'conv3dTranspose');
var $filter = convertToTensor(filter, 'filter', 'conv3dTranspose');
return conv3DBackpropInput(outputShape, $x, $filter, strides, pad);
}
var conv3dTranspose = op({
conv3dTranspose_: conv3dTranspose_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes cos of the input `tf.Tensor` element-wise: `cos(x)`
*
* ```js
* const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
*
* x.cos().print(); // or tf.cos(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function cos_(x) {
var $x = convertToTensor(x, 'x', 'cos');
var inputs = {
x: $x
};
return ENGINE.runKernel(Cos, inputs);
}
var cos = op({
cos_: cos_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)`
*
* ```js
* const x = tf.tensor1d([0, 1, -1, .7]);
*
* x.cosh().print(); // or tf.cosh(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function cosh_(x) {
var $x = convertToTensor(x, 'x', 'cosh');
var inputs = {
x: $x
};
return ENGINE.runKernel(Cosh, inputs);
}
var cosh = op({
cosh_: cosh_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the cumulative sum of a `tf.Tensor` along `axis`.
*
* ```js
* const x = tf.tensor([1, 2, 3, 4]);
* x.cumsum().print();
* ```
* ```js
* const x = tf.tensor([[1, 2], [3, 4]]);
* x.cumsum().print();
* ```
*
* @param x The input tensor to be summed.
* @param axis The axis along which to sum. Optional. Defaults to 0.
* @param exclusive Whether to perform exclusive cumulative sum. Optional.
* Defaults to false. If set to true then the sum of each tensor entry
* does not include its own value, but only the values previous to it
* along the specified axis.
* @param reverse Whether to sum in the opposite direction. Optional.
* Defaults to false.
*
* @doc {heading: 'Operations', subheading: 'Scan'}
*/
function cumsum_(x, axis, exclusive, reverse) {
if (axis === void 0) {
axis = 0;
}
if (exclusive === void 0) {
exclusive = false;
}
if (reverse === void 0) {
reverse = false;
}
var $x = convertToTensor(x, 'x', 'cumsum');
var inputs = {
x: $x
};
var attrs = {
axis: axis,
exclusive: exclusive,
reverse: reverse
};
return ENGINE.runKernel(Cumsum, inputs, attrs);
}
var cumsum = op({
cumsum_: cumsum_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Outputs a vector with length `size` and the same dtype as `weights`.
*
* If `weights` are empty, then index `i` stores the number of times the value
* `i` is counted in `x`. If `weights` are non-empty, then index `i` stores the
* sum of the value in `weights` at each index where the corresponding value in
* `x` is `i`.
*
* Values in `x` outside of the range [0, size) are ignored.
*
* @param x The input int tensor, rank 1 or rank 2.
* @param weights The weights tensor, must have the same shape as x, or a
* length-0 Tensor, in which case it acts as all weights equal to 1.
* @param size Non-negative integer.
 * @param binaryOutput Optional. Whether the kernel should only record whether
 *     each value appears (an output of 0 or 1 per bin) rather than the number
 *     of occurrences. Defaults to false.
*
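 * A minimal usage sketch (values chosen purely for illustration):
 *
 * ```js
 * const x = tf.tensor1d([1, 1, 2, 3], 'int32');
 * // Zero-length weights: every occurrence counts as 1.
 * const weights = tf.tensor1d([]);
 *
 * // Counts for the values 0..4: [0, 2, 1, 1, 0]
 * tf.denseBincount(x, weights, 5).print();
 * ```
 *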
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function denseBincount_(x, weights, size, binaryOutput) {
if (binaryOutput === void 0) {
binaryOutput = false;
}
var $x = convertToTensor(x, 'x', 'denseBincount');
var $weights = convertToTensor(weights, 'weights', 'denseBincount');
assert($x.dtype === 'int32', function () {
return "Error in denseBincount: input " + ("dtype must be int32, but got " + $x.dtype);
});
assert($x.rank <= 2, function () {
return "Error in denseBincount: input must be at most rank 2, but got " + ("rank " + $x.rank + ".");
});
assert(size >= 0, function () {
return "size must be non-negative, but got " + size + ".";
});
assert($weights.size === $x.size || $weights.size === 0, function () {
return "Error in denseBincount: weights must have the same shape as x or " + ("0-length, but got x shape: " + $x.shape + ", weights shape: ") + ($weights.shape + ".");
});
var inputs = {
x: $x,
weights: $weights
};
var attrs = {
size: size,
binaryOutput: binaryOutput
};
return ENGINE.runKernel(DenseBincount, inputs, attrs);
}
var denseBincount = op({
denseBincount_: denseBincount_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Rearranges data from depth into blocks of spatial data. More specifically,
* this op outputs a copy of the input tensor where values from the `depth`
* dimension are moved in spatial blocks to the `height` and `width` dimensions.
* The attr `blockSize` indicates the input block size and how the data is
* moved.
*
* - Chunks of data of size `blockSize * blockSize` from depth are rearranged
* into non-overlapping blocks of size `blockSize x blockSize`
*
 * - The width of the output tensor is `inputWidth * blockSize`, whereas the
* height is `inputHeight * blockSize`
*
* - The Y, X coordinates within each block of the output image are determined
* by the high order component of the input channel index
*
* - The depth of the input tensor must be divisible by `blockSize *
* blockSize`
*
* The `dataFormat` attr specifies the layout of the input and output tensors
* with the following options: "NHWC": [ `batch, height, width, channels` ]
* "NCHW": [ `batch, channels, height, width` ]
*
* ```js
* const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);
* const blockSize = 2;
* const dataFormat = "NHWC";
*
* tf.depthToSpace(x, blockSize, dataFormat).print();
* ```
*
* @param x The input tensor of rank 4
 * @param blockSize An `int` that is `>= 2`. The size of the spatial block
* @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC"
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function depthToSpace_(x, blockSize, dataFormat) {
if (dataFormat === void 0) {
dataFormat = 'NHWC';
}
var $x = convertToTensor(x, 'x', 'depthToSpace');
var inputHeight = dataFormat === 'NHWC' ? $x.shape[1] : $x.shape[2];
var inputWidth = dataFormat === 'NHWC' ? $x.shape[2] : $x.shape[3];
var inputDepth = dataFormat === 'NHWC' ? $x.shape[3] : $x.shape[1];
assert(inputHeight * blockSize >= 0, function () {
return "Negative dimension size caused by overflow when multiplying\n " + inputHeight + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape;
});
assert(inputWidth * blockSize >= 0, function () {
return "Negative dimension size caused by overflow when multiplying\n " + inputWidth + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape;
});
assert(inputDepth % (blockSize * blockSize) === 0, function () {
return "Dimension size must be evenly divisible by " + blockSize * blockSize + " but is " + inputDepth + " for depthToSpace with input shape " + $x.shape;
});
var inputs = {
x: $x
};
var attrs = {
blockSize: blockSize,
dataFormat: dataFormat
};
return ENGINE.runKernel(DepthToSpace, inputs, attrs);
}
var depthToSpace = op({
depthToSpace_: depthToSpace_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Depthwise 2D convolution.
*
* Given a 4D `input` array and a `filter` array of shape
* `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
* `inChannels` convolutional filters of depth 1, this op applies a
* different filter to each input channel (expanding from 1 channel to
* `channelMultiplier` channels for each), then concatenates the results
* together. The output has `inChannels * channelMultiplier` channels.
*
* See
* [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
* https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
* for more details.
*
* @param x The input tensor, of rank 4 or rank 3, of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
* assumed.
* @param filter The filter tensor, rank 4, of shape
* `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
* @param strides The strides of the convolution: `[strideHeight,
* strideWidth]`. If strides is a single number, then `strideHeight ==
* strideWidth`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
 *     in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
* number, then `dilationHeight == dilationWidth`. If it is greater than
* 1, then all values of `strides` must be 1.
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
* "NHWC". Specify the data format of the input and output data. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels]. Only "NHWC" is currently supported.
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*
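 * A minimal usage sketch (illustrative values; note that each channel is
 * filtered independently):
 *
 * ```js
 * // x: [batch, height, width, inChannels] = [1, 2, 2, 2]
 * const x = tf.tensor4d([1, 10, 2, 20, 3, 30, 4, 40], [1, 2, 2, 2]);
 * // filter: [filterHeight, filterWidth, inChannels, channelMultiplier]
 * const filter = tf.ones([2, 2, 2, 1]);
 *
 * // Output shape is [1, 1, 1, 2]: channel 0 sums to 10, channel 1 to 100.
 * tf.depthwiseConv2d(x, filter, 1, 'valid').print();
 * ```
 *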
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function depthwiseConv2d_(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
if (dataFormat === void 0) {
dataFormat = 'NHWC';
}
if (dilations === void 0) {
dilations = [1, 1];
}
var $x = convertToTensor(x, 'x', 'depthwiseConv2d');
var $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d');
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(x4D.rank === 4, function () {
return "Error in depthwiseConv2d: input must be rank 4, but got " + ("rank " + x4D.rank + ".");
});
assert($filter.rank === 4, function () {
return "Error in depthwiseConv2d: filter must be rank 4, but got rank " + ($filter.rank + ".");
});
assert(x4D.shape[3] === $filter.shape[2], function () {
return "Error in depthwiseConv2d: number of input channels " + ("(" + x4D.shape[3] + ") must match the inChannels dimension in ") + ("filter " + $filter.shape[2] + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in depthwiseConv2d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
x: x4D,
filter: $filter
};
var attrs = {
strides: strides,
pad: pad,
dataFormat: dataFormat,
dilations: dilations,
dimRoundingMode: dimRoundingMode
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(DepthwiseConv2dNative, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var depthwiseConv2d = op({
depthwiseConv2d_: depthwiseConv2d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns a diagonal tensor with the given diagonal values.
*
* Given a diagonal, this operation returns a tensor with the diagonal and
* everything else padded with zeros.
*
* Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor
* of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
*
* tf.diag(x).print()
* ```
* ```js
 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 6, 8], [4, 2])
*
* tf.diag(x).print()
* ```
* @param x The input tensor.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function diag_(x) {
var $x = convertToTensor(x, 'x', 'diag');
var inputs = {
x: $x
};
return ENGINE.runKernel(Diag, inputs);
}
var diag = op({
diag_: diag_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the grayscale dilation over the input `x`.
*
* @param x The input tensor, rank 3 or rank 4 of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
* @param filter The filter tensor, rank 3, of shape
* `[filterHeight, filterWidth, depth]`.
* @param strides The strides of the sliding window for each dimension of the
* input tensor: `[strideHeight, strideWidth]`.
* If `strides` is a single number,
* then `strideHeight == strideWidth`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
 *     than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dataFormat Specify the data format of the input and output data.
* Defaults to 'NHWC'. Only 'NHWC' is currently supported. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels].
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
* for atrous morphological dilation. Defaults to `[1, 1]`. If `dilations`
* is a single number, then `dilationHeight == dilationWidth`. If it is
* greater than 1, then all values of `strides` must be 1.
*
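 * A minimal usage sketch (illustrative values; with an all-zero filter the
 * result reduces to a sliding-window maximum):
 *
 * ```js
 * // x: [batch, height, width, inChannels] = [1, 2, 2, 1]
 * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * // filter: [filterHeight, filterWidth, depth] = [2, 2, 1]
 * const filter = tf.zeros([2, 2, 1]);
 *
 * // Maximum over the single 2x2 window: 4.
 * tf.dilation2d(x, filter, 1, 'valid').print();
 * ```
 *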
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function dilation2d_(x, filter, strides, pad, dilations, dataFormat) {
if (dilations === void 0) {
dilations = [1, 1];
}
if (dataFormat === void 0) {
dataFormat = 'NHWC';
}
var $x = convertToTensor(x, 'x', 'dilation2d');
var $filter = convertToTensor(filter, 'filter', 'dilation2d');
assert($x.rank === 3 || $x.rank === 4, function () {
return "Error in dilation2d: input must be rank 3 or 4, but got rank " + ($x.rank + ".");
});
assert($filter.rank === 3, function () {
return "Error in dilation2d: filter must be rank 3, but got rank " + ($filter.rank + ".");
});
assert(dataFormat === 'NHWC', function () {
return "Error in dilation2d: Only NHWC is currently supported, " + ("but got dataFormat of " + dataFormat);
});
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
reshapedTo4D = true;
}
var inputs = {
x: x4D,
filter: $filter
};
var attrs = {
strides: strides,
pad: pad,
dilations: dilations
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(Dilation2D, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var dilation2d = op({
dilation2d_: dilation2d_
});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the dimensions in the input shape that are broadcasted to
* produce the provided output shape.
*
* The returned dimensions are 0-indexed and sorted. An example:
* inShape = [4, 1, 3]
* outShape = [5, 4, 3, 3]
* result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3.
*/
function getBroadcastDims(inShape, outShape) {
var inRank = inShape.length;
var dims = [];
for (var i = 0; i < inRank; i++) {
var dim = inRank - 1 - i;
var a = inShape[dim] || 1;
var b = outShape[outShape.length - 1 - i] || 1;
if (b > 1 && a === 1) {
dims.unshift(dim);
}
}
return dims;
}
/**
* Returns the axes in the output space that should be reduced to produce
* the input space.
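 *
 * An example (result derived by hand from the loop below):
 *   inShape = [4, 1, 3]
 *   outShape = [5, 4, 3, 3]
 *   result = [0, 2]. Axis 0 has no counterpart in the input and axis 2 was
 *   broadcast from size 1, so both are reduced when mapping a tensor of
 *   outShape back onto inShape.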
*/
function getReductionAxes(inShape, outShape) {
var result = [];
for (var i = 0; i < outShape.length; i++) {
var inDim = inShape[inShape.length - i - 1];
var outAxis = outShape.length - i - 1;
var outDim = outShape[outAxis];
if (inDim == null || inDim === 1 && outDim > 1) {
result.unshift(outAxis);
}
}
return result;
}
function assertAndGetBroadcastShape(shapeA, shapeB) {
var result = [];
var l = Math.max(shapeA.length, shapeB.length);
for (var i = 0; i < l; i++) {
var a = shapeA[shapeA.length - i - 1];
if (a == null) {
a = 1;
}
var b = shapeB[shapeB.length - i - 1];
if (b == null) {
b = 1;
}
if (a === 1) {
result.unshift(b);
} else if (b === 1) {
result.unshift(a);
} else if (a !== b) {
var errMsg = "Operands could not be broadcast together with shapes " + (shapeA + " and " + shapeB + ".");
throw Error(errMsg);
} else {
result.unshift(a);
}
}
return result;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of (a == b) element-wise. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
* const b = tf.tensor1d([2, 2, 2]);
*
* a.equal(b).print();
* ```
*
* @param a The first input tensor.
* @param b The second input tensor. Must have the same dtype as `a`.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function equal_(a, b) {
var $a = convertToTensor(a, 'a', 'equal', 'string_or_numeric');
var $b = convertToTensor(b, 'b', 'equal', 'string_or_numeric');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Equal, inputs);
}
var equal = op({
equal_: equal_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the elements, either `a` or `b` depending on the `condition`.
*
* If the condition is true, select from `a`, otherwise select from `b`.
*
* ```js
* const cond = tf.tensor1d([false, false, true], 'bool');
* const a = tf.tensor1d([1 , 2, 3]);
* const b = tf.tensor1d([-1, -2, -3]);
*
* a.where(cond, b).print();
* ```
*
* @param condition The input condition. Must be of dtype bool.
* @param a If `condition` is rank 1, `a` may have a higher rank but
* its first dimension must match the size of `condition`.
* @param b A tensor with the same dtype as `a` and with shape that is
* compatible with `a`.
* @return A tensor with same dtype as `a` and `b`, and shape that is
* broadcastable from `a` and `b`.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function where_(condition, a, b) {
var $a = convertToTensor(a, 'a', 'where');
var $b = convertToTensor(b, 'b', 'where');
var $condition = convertToTensor(condition, 'condition', 'where', 'bool'); // TODO: move this logic to forward function when the broadcastTo op is
// implemented in WASM.
// Find the broadcastable shape for $condition, $a, and $b.
var broadcastShape = assertAndGetBroadcastShape(assertAndGetBroadcastShape($condition.shape, $a.shape), $b.shape);
var $broadcastedCondition = broadcastTo($condition, broadcastShape);
var $broadcastedA = broadcastTo($a, broadcastShape);
var $broadcastedB = broadcastTo($b, broadcastShape);
var inputs = {
condition: $broadcastedCondition,
t: $broadcastedA,
e: $broadcastedB
};
return ENGINE.runKernel(Select, inputs);
}
var where = op({
where_: where_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with all elements set to 0 with the same shape as the
* given tensor.
*
* ```js
* const x = tf.tensor([1, 2]);
* tf.zerosLike(x).print();
* ```
*
* @param x The tensor of required shape.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function zerosLike_(x) {
var $x = convertToTensor(x, 'x', 'zerosLike');
var inputs = {
x: $x
};
return ENGINE.runKernel(ZerosLike, inputs);
}
var zerosLike = op({
zerosLike_: zerosLike_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Returns
 * 0 if the denominator is 0.
*
*
* ```js
* const a = tf.tensor1d([1, 4, 9, 16]);
* const b = tf.tensor1d([1, 2, 3, 4]);
* const c = tf.tensor1d([0, 0, 0, 0]);
*
* a.divNoNan(b).print(); // or tf.divNoNan(a, b)
* a.divNoNan(c).print(); // or tf.divNoNan(a, c)
* ```
*
* ```js
* // Broadcast div a with b.
* const a = tf.tensor1d([2, 4, 6, 8]);
* const b = tf.scalar(2);
* const c = tf.scalar(0);
*
* a.divNoNan(b).print(); // or tf.divNoNan(a, b)
* a.divNoNan(c).print(); // or tf.divNoNan(a, c)
* ```
*
* @param a The first tensor as the numerator.
* @param b The second tensor as the denominator. Must have the same dtype as
* `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function divNoNan_(a, b) {
// TODO: Make this into its own kernel.
var $a = convertToTensor(a, 'a', 'div');
var $b = convertToTensor(b, 'b', 'div');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
var divResult = div($a, $b);
var zeros = zerosLike(divResult);
var bEqualsZero = equal($b, zeros);
return where(bEqualsZero, zeros, divResult);
}
var divNoNan = op({
divNoNan_: divNoNan_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the dot product of two matrices and/or vectors, `t1` and `t2`.
*
* ```js
* const a = tf.tensor1d([1, 2]);
* const b = tf.tensor2d([[1, 2], [3, 4]]);
* const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
*
* a.dot(b).print(); // or tf.dot(a, b)
* b.dot(a).print();
* b.dot(c).print();
* ```
* @param t1 The first tensor in the dot operation.
* @param t2 The second tensor in the dot operation.
*
* @doc {heading: 'Operations', subheading: 'Matrices'}
*/
function dot_(t1, t2) {
var $t1 = convertToTensor(t1, 't1', 'dot');
var $t2 = convertToTensor(t2, 't2', 'dot');
assert(($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2), function () {
return "Error in dot: inputs must all be rank 1 or 2, but got ranks " + ($t1.rank + " and " + $t2.rank + ".");
});
var t1Inner = $t1.rank === 1 ? $t1.size : $t1.shape[1];
var t2Inner = $t2.rank === 1 ? $t2.size : $t2.shape[0];
assert(t1Inner === t2Inner, function () {
return "Error in dot: inner dimensions of inputs must match, but got " + (t1Inner + " and " + t2Inner + ".");
});
if ($t1.rank === 1 && $t2.rank === 1) {
var t12D = reshape($t1, [1, -1]);
var t22D = reshape($t2, [-1, 1]);
var t1t2 = matMul(t12D, t22D);
return reshape(t1t2, []);
} else if ($t1.rank === 1 && $t2.rank === 2) {
var _t12D = reshape($t1, [1, -1]);
var _t22D = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
var _t1t = matMul(_t12D, _t22D);
return reshape(_t1t, [_t1t.size]);
} else if ($t1.rank === 2 && $t2.rank === 1) {
var _t22D2 = reshape($t2, [-1, 1]);
var _t1t2 = matMul($t1, _t22D2);
return reshape(_t1t2, [_t1t2.size]);
} else {
var _t22D3 = reshape($t2, [$t2.shape[0], $t2.shape[1]]);
var _t1t3 = matMul($t1, _t22D3);
return _t1t3;
}
}
var dot = op({
dot_: dot_
});
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Tensor contraction over specified indices and outer product.
*
* `einsum` allows defining Tensors by defining their element-wise computation.
* This computation is based on
* [Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation).
*
* Some special cases include:
*
* Matrix multiplication:
* ```js
* const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
* const y = tf.tensor2d([[0, 1], [2, 3], [4, 5]]);
* x.print();
* y.print();
* tf.einsum('ij,jk->ik', x, y).print();
* ```
*
* Dot product:
* ```js
* const x = tf.tensor1d([1, 2, 3]);
* const y = tf.tensor1d([0, 1, 2]);
* x.print();
* y.print();
* tf.einsum('i,i->', x, y).print();
* ```
*
* Batch dot product:
* ```js
* const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
* const y = tf.tensor2d([[0, 1, 2], [3, 4, 5]]);
* x.print();
* y.print();
* tf.einsum('bi,bi->b', x, y).print();
* ```
*
 * Outer product:
* ```js
* const x = tf.tensor1d([1, 3, 5]);
* const y = tf.tensor1d([2, 4, 6]);
* x.print();
* y.print();
* tf.einsum('i,j->ij', x, y).print();
* ```
*
* Matrix transpose:
* ```js
* const x = tf.tensor2d([[1, 2], [3, 4]]);
* x.print();
* tf.einsum('ij->ji', x).print();
* ```
*
* Batch matrix transpose:
* ```js
* const x = tf.tensor3d([[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]]);
* x.print();
* tf.einsum('bij->bji', x).print();
* ```
*
* Limitations:
*
* This implementation of einsum has the following limitations:
*
* - Does not support >2 input tensors.
* - Does not support duplicate axes for any given input tensor. E.g., equation
 *   'ii->' is not supported.
* - The `...` notation is not supported.
*
* @param equation a string describing the contraction, in the same format as
* [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html).
* @param tensors the input(s) to contract (each one a Tensor), whose shapes
* should be consistent with equation.
* @returns The output tensor.
*
* @doc {heading: 'Tensors', subheading: 'Matrices'}
*/
function einsum_(equation) {
for (var _len = arguments.length, tensors = new Array(_len > 1 ? _len - 1 : 0), _key = 1; _key < _len; _key++) {
tensors[_key - 1] = arguments[_key];
}
var $tensors = tensors.map(function (t, i) {
return convertToTensor(t, "tensors" + i, 'einsum');
});
var attrs = {
equation: equation
};
return ENGINE.runKernel(Einsum, $tensors, attrs);
}
var einsum = op({
einsum_: einsum_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes exponential linear element-wise: `x > 0 ? x : (e ^ x) - 1`.
*
* ```js
* const x = tf.tensor1d([-1, 1, -3, 2]);
*
* x.elu().print(); // or tf.elu(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function elu_(x) {
var $x = convertToTensor(x, 'x', 'elu');
var inputs = {
x: $x
};
return ENGINE.runKernel(Elu, inputs);
}
var elu = op({
elu_: elu_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Computes the Gauss error function of the input `tf.Tensor` element-wise:
* `erf(x)`
*
* ```js
* const x = tf.tensor1d([0, .1, -.1, .7]);
*
* x.erf().print(); // or tf.erf(x);
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function erf_(x) {
var $x = convertToTensor(x, 'x', 'erf');
assert($x.dtype === 'int32' || $x.dtype === 'float32', function () {
return 'Input dtype must be `int32` or `float32`.';
});
if ($x.dtype === 'int32') {
$x = cast($x, 'float32');
}
var inputs = {
x: $x
};
return ENGINE.runKernel(Erf, inputs);
}
var erf = op({
erf_: erf_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes exponential of the input `tf.Tensor` element-wise. `e ^ x`
*
* ```js
* const x = tf.tensor1d([1, 2, -3]);
*
* x.exp().print(); // or tf.exp(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function exp_(x) {
var $x = convertToTensor(x, 'x', 'exp');
var inputs = {
x: $x
};
return ENGINE.runKernel(Exp, inputs);
}
var exp$3 = op({
exp_: exp_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns a `tf.Tensor` that has expanded rank, by inserting a dimension
* into the tensor's shape.
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
* const axis = 1;
* x.expandDims(axis).print();
* ```
*
* @param x The input tensor whose dimensions to be expanded.
* @param axis The dimension index at which to insert shape of `1`. Defaults
* to 0 (the first dimension).
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function expandDims_(x, axis) {
if (axis === void 0) {
axis = 0;
}
var $x = convertToTensor(x, 'x', 'expandDims', 'string_or_numeric');
assert(axis <= $x.rank, function () {
return 'Axis must be <= rank of the tensor';
});
var inputs = {
input: $x
};
var attrs = {
dim: axis
};
return ENGINE.runKernel(ExpandDims, inputs, attrs);
}
var expandDims = op({
expandDims_: expandDims_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes exponential of the input `tf.Tensor` minus one element-wise.
* `e ^ x - 1`
*
* ```js
* const x = tf.tensor1d([1, 2, -3]);
*
* x.expm1().print(); // or tf.expm1(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function expm1_(x) {
var $x = convertToTensor(x, 'x', 'expm1');
var inputs = {
x: $x
};
return ENGINE.runKernel(Expm1, inputs);
}
var expm1 = op({
expm1_: expm1_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Construct a tensor by repeating it the number of times given by reps.
*
* This operation creates a new tensor by replicating `input` `reps`
* times. The output tensor's i'th dimension has `input.shape[i] *
* reps[i]` elements, and the values of `input` are replicated
* `reps[i]` times along the i'th dimension. For example, tiling
* `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`.
*
* ```js
* const a = tf.tensor1d([1, 2]);
*
* a.tile([2]).print(); // or a.tile([2])
* ```
*
* ```js
* const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* a.tile([1, 2]).print(); // or a.tile([1, 2])
* ```
* @param x The tensor to tile.
* @param reps Determines the number of replications per dimension.
*
* @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
*/
function tile_(x, reps) {
var $x = convertToTensor(x, 'x', 'tile', 'string_or_numeric');
assert($x.rank === reps.length, function () {
return "Error in transpose: rank of input " + $x.rank + " " + ("must match length of reps " + reps + ".");
});
var inputs = {
x: $x
};
var attrs = {
reps: reps
};
return ENGINE.runKernel(Tile, inputs, attrs);
}
var tile = op({
tile_: tile_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Create an identity matrix.
*
* @param numRows Number of rows.
* @param numColumns Number of columns. Defaults to `numRows`.
* @param batchShape If provided, will add the batch shape to the beginning
* of the shape of the returned `tf.Tensor` by repeating the identity
* matrix.
* @param dtype Data type.
* @returns Identity matrix of the specified size and data type, possibly
* with batch repetition if `batchShape` is specified.
*
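 * A minimal usage sketch:
 *
 * ```js
 * // A 2x2 identity matrix.
 * tf.eye(2).print();
 *
 * // A batch of three 2x2 identity matrices, shape [3, 2, 2].
 * tf.eye(2, 2, [3]).print();
 * ```
 *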
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function eye_(numRows, numColumns, batchShape, dtype) {
if (dtype === void 0) {
dtype = 'float32';
}
if (numColumns == null) {
numColumns = numRows;
}
var buff = buffer([numRows, numColumns], dtype);
var n = numRows <= numColumns ? numRows : numColumns;
for (var i = 0; i < n; ++i) {
buff.set(1, i, i);
}
var out = reshape(buff.toTensor(), [numRows, numColumns]);
if (batchShape == null) {
return out;
} else {
if (batchShape.length === 1) {
return tile(expandDims(out, 0), [batchShape[0], 1, 1]);
} else if (batchShape.length === 2) {
// tslint:disable-next-line:no-unnecessary-type-assertion
return tile(expandDims(expandDims(out, 0), 0), [batchShape[0], batchShape[1], 1, 1]);
} else if (batchShape.length === 3) {
// tslint:disable-next-line:no-unnecessary-type-assertion
return tile(expandDims(expandDims(expandDims(out, 0), 0), 0), [batchShape[0], batchShape[1], batchShape[2], 1, 1]);
} else {
throw new Error("eye() currently supports only 1D and 2D " + ( // tslint:disable-next-line:no-any
"batchShapes, but received " + batchShape.length + "D."));
}
}
}
var eye = op({
eye_: eye_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` filled with a scalar value.
*
* ```js
* tf.fill([2, 2], 4).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param value The scalar value to fill the tensor with.
* @param dtype The type of an element in the resulting tensor. Defaults to
 * 'float32'.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function fill(shape, value, dtype) {
var attrs = {
shape: shape,
value: value,
dtype: dtype
};
return ENGINE.runKernel(Fill, {}, attrs);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes floor of input `tf.Tensor` element-wise: `floor(x)`.
*
* ```js
* const x = tf.tensor1d([.6, 1.1, -3.3]);
*
* x.floor().print(); // or tf.floor(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function floor_(x) {
var $x = convertToTensor(x, 'x', 'floor');
var inputs = {
x: $x
};
return ENGINE.runKernel(Floor, inputs);
}
var floor$a = op({
floor_: floor_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Gather slices from tensor `x`'s axis `axis` according to `indices`.
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
* const indices = tf.tensor1d([1, 3, 3], 'int32');
*
* x.gather(indices).print();
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
* const indices = tf.tensor1d([1, 1, 0], 'int32');
*
* x.gather(indices).print();
* ```
* @param x The input tensor whose slices to be gathered.
* @param indices The indices of the values to extract.
* @param axis The axis over which to select values. Defaults to 0.
* @param batchDims Optional. The number of batch dimensions. It must be less
* than or equal to rank(indices). Defaults to 0.
* The output tensor will have shape of
* `x.shape[:axis] + indices.shape[batchDims:] + x.shape[axis + 1:]`
*
* @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
*/
function gather_(x, indices, axis, batchDims) {
if (axis === void 0) {
axis = 0;
}
if (batchDims === void 0) {
batchDims = 0;
}
var $x = convertToTensor(x, 'x', 'gather');
var $indices = convertToTensor(indices, 'indices', 'gather', 'int32');
var inputs = {
x: $x,
indices: $indices
};
var attrs = {
axis: axis,
batchDims: batchDims
};
return ENGINE.runKernel(GatherV2, inputs, attrs);
}
var gather = op({
gather_: gather_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of (a > b) element-wise. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
* const b = tf.tensor1d([2, 2, 2]);
*
* a.greater(b).print();
* ```
*
* @param a The first input tensor.
* @param b The second input tensor. Must have the same dtype as `a`.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function greater_(a, b) {
var $a = convertToTensor(a, 'a', 'greater', 'string_or_numeric');
var $b = convertToTensor(b, 'b', 'greater', 'string_or_numeric');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Greater, inputs);
}
var greater = op({
greater_: greater_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of (a >= b) element-wise. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
* const b = tf.tensor1d([2, 2, 2]);
*
* a.greaterEqual(b).print();
* ```
*
* @param a The first input tensor.
* @param b The second input tensor. Must have the same dtype as `a`.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function greaterEqual_(a, b) {
var $a = convertToTensor(a, 'a', 'greaterEqual', 'string_or_numeric');
var $b = convertToTensor(b, 'b', 'greaterEqual', 'string_or_numeric');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(GreaterEqual, inputs);
}
var greaterEqual = op({
greaterEqual_: greaterEqual_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the imaginary part of a complex (or real) tensor.
*
* Given a tensor input, this operation returns a tensor of type float that is
* the imaginary part of each element in input considered as a complex number.
* If input is real, a tensor of all zeros is returned.
*
* ```js
* const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
* tf.imag(x).print();
* ```
*
 * @param input The input tensor.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function imag_(input) {
var $input = convertToTensor(input, 'input', 'imag');
var inputs = {
input: $input
};
return ENGINE.runKernel(Imag, inputs);
}
var imag = op({
imag_: imag_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns which elements of x are finite.
*
* ```js
* const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
*
 * x.isFinite().print(); // or tf.isFinite(x)
* ```
* @param x The input Tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function isFinite_(x) {
var $x = convertToTensor(x, 'x', 'isFinite');
var inputs = {
x: $x
};
return ENGINE.runKernel(IsFinite, inputs);
}
var isFinite$1 = op({
isFinite_: isFinite_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns which elements of x are Infinity or -Infinity.
*
* ```js
* const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
*
 * x.isInf().print(); // or tf.isInf(x)
* ```
* @param x The input Tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function isInf_(x) {
var $x = convertToTensor(x, 'x', 'isInf');
var inputs = {
x: $x
};
return ENGINE.runKernel(IsInf, inputs);
}
var isInf = op({
isInf_: isInf_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Returns which elements of x are NaN.
*
* ```js
* const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]);
*
* x.isNaN().print(); // or tf.isNaN(x)
* ```
* @param x The input Tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function isNaN_(x) {
var $x = convertToTensor(x, 'x', 'isNaN');
var inputs = {
x: $x
};
return ENGINE.runKernel(IsNan, inputs);
}
var isNaN$1 = op({
isNaN_: isNaN_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes leaky rectified linear element-wise.
*
* See
* [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf](
* http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf)
*
* ```js
* const x = tf.tensor1d([-1, 2, -3, 4]);
*
* x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1)
* ```
* @param x The input tensor.
* @param alpha The scaling factor for negative values, defaults to 0.2.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function leakyRelu_(x, alpha) {
if (alpha === void 0) {
alpha = 0.2;
}
var $x = convertToTensor(x, 'x', 'leakyRelu');
var inputs = {
x: $x
};
var attrs = {
alpha: alpha
};
return ENGINE.runKernel(LeakyRelu, inputs, attrs);
}
var leakyRelu = op({
leakyRelu_: leakyRelu_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of (a < b) element-wise. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
* const b = tf.tensor1d([2, 2, 2]);
*
* a.less(b).print();
* ```
* @param a The first input tensor.
* @param b The second input tensor. Must have the same dtype as `a`.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function less_(a, b) {
var $a = convertToTensor(a, 'a', 'less', 'string_or_numeric');
var $b = convertToTensor(b, 'b', 'less', 'string_or_numeric');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Less, inputs);
}
var less = op({
less_: less_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of (a <= b) element-wise. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
* const b = tf.tensor1d([2, 2, 2]);
*
* a.lessEqual(b).print();
* ```
*
* @param a The first input tensor.
* @param b The second input tensor. Must have the same dtype as `a`.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function lessEqual_(a, b) {
var $a = convertToTensor(a, 'a', 'lessEqual', 'string_or_numeric');
var $b = convertToTensor(b, 'b', 'lessEqual', 'string_or_numeric');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(LessEqual, inputs);
}
var lessEqual = op({
lessEqual_: lessEqual_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Return an evenly spaced sequence of numbers over the given interval.
*
* ```js
* tf.linspace(0, 9, 10).print();
* ```
* @param start The start value of the sequence.
* @param stop The end value of the sequence.
* @param num The number of values to generate.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function linspace(start, stop, num) {
if (num <= 0) {
throw new Error('The number of values should be positive.');
}
var attrs = {
start: start,
stop: stop,
num: num
};
return ENGINE.runKernel(LinSpace, {}, attrs);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Normalizes the activation of a local neighborhood across or within
* channels.
*
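 * A minimal usage sketch (assuming a 4-D activation tensor of shape
 * `[batch, height, width, channels]` and a depth radius of 2; the other
 * parameters keep their defaults):
 *
 * ```js
 * const x = tf.randomNormal([1, 4, 4, 8]);
 * tf.localResponseNormalization(x, 2).print();
 * ```
 *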
* @param x The input tensor. The 4-D input tensor is treated as a 3-D array
* of 1D vectors (along the last dimension), and each vector is
* normalized independently.
* @param depthRadius The number of adjacent channels in the 1D normalization
* window.
 * @param bias A constant bias term (usually positive, to avoid dividing by zero).
* @param alpha A scale factor, usually positive.
* @param beta An exponent.
*
* @doc {heading: 'Operations', subheading: 'Normalization'}
*/
function localResponseNormalization_(x, depthRadius, bias, alpha, beta) {
if (depthRadius === void 0) {
depthRadius = 5;
}
if (bias === void 0) {
bias = 1;
}
if (alpha === void 0) {
alpha = 1;
}
if (beta === void 0) {
beta = 0.5;
}
var $x = convertToTensor(x, 'x', 'localResponseNormalization');
assert($x.rank === 4 || $x.rank === 3, function () {
return "Error in localResponseNormalization: x must be rank 3 or 4 but got\n rank " + $x.rank + ".";
});
assert(isInt(depthRadius), function () {
return "Error in localResponseNormalization: depthRadius must be an " + ("integer but got depthRadius " + depthRadius + ".");
});
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
var inputs = {
x: x4D
};
var attrs = {
depthRadius: depthRadius,
bias: bias,
alpha: alpha,
beta: beta
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(LRN, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
} else {
return res;
}
}
var localResponseNormalization = op({
localResponseNormalization_: localResponseNormalization_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)`
*
* ```js
* const x = tf.tensor1d([1, 2, Math.E]);
*
* x.log().print(); // or tf.log(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function log_(x) {
var $x = convertToTensor(x, 'x', 'log');
var inputs = {
x: $x
};
return ENGINE.runKernel(Log, inputs);
}
var log$a = op({
log_: log_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes natural logarithm of the input `tf.Tensor` plus one
* element-wise: `ln(1 + x)`
*
* ```js
* const x = tf.tensor1d([1, 2, Math.E - 1]);
*
* x.log1p().print(); // or tf.log1p(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function log1p_(x) {
var $x = convertToTensor(x, 'x', 'log1p');
var inputs = {
x: $x
};
return ENGINE.runKernel(Log1p, inputs);
}
var log1p = op({
log1p_: log1p_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Provided `f(x)`, returns another function `g(x, dy?)`, which gives the
* gradient of `f(x)` with respect to `x`.
*
* If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to
* `x` is computed instead. `f(x)` must take a single tensor `x` and return a
* single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead.
*
* ```js
* // f(x) = x ^ 2
* const f = x => x.square();
* // f'(x) = 2x
* const g = tf.grad(f);
*
* const x = tf.tensor1d([2, 3]);
* g(x).print();
* ```
*
* ```js
* // f(x) = x ^ 3
* const f = x => x.pow(tf.scalar(3, 'int32'));
* // f'(x) = 3x ^ 2
* const g = tf.grad(f);
* // f''(x) = 6x
* const gg = tf.grad(g);
*
* const x = tf.tensor1d([2, 3]);
* gg(x).print();
* ```
*
* @param f The function f(x), to compute gradient for.
*
* @doc {heading: 'Training', subheading: 'Gradients'}
*/
function grad(f) {
assert(isFunction(f), function () {
return 'The f passed in grad(f) must be a function';
});
return function (x, dy) {
// x can be of any dtype, thus null as the last argument.
var $x = convertToTensor(x, 'x', 'tf.grad', 'string_or_numeric');
var $dy = dy != null ? convertToTensor(dy, 'dy', 'tf.grad') : null;
return ENGINE.tidy(function () {
var _ENGINE$gradients = ENGINE.gradients(function () {
return f($x);
}, [$x], $dy),
value = _ENGINE$gradients.value,
grads = _ENGINE$gradients.grads;
if ($dy != null) {
assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' + 'returned by f(x)');
}
checkGrads(grads);
return grads[0];
});
};
}
/**
* Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`,
* which gives an array of gradients of `f()` with respect to each input
* [`x1`,`x2`,...].
*
* If `dy` is passed when calling `g()`, the gradient of
* `f(x1,...).mul(dy).sum()` with respect to each input is computed instead.
* The provided `f` must take one or more tensors and return a single tensor
* `y`. If `f()` takes a single input, we recommend using `tf.grad` instead.
*
* ```js
* // f(a, b) = a * b
* const f = (a, b) => a.mul(b);
* // df / da = b, df / db = a
* const g = tf.grads(f);
*
* const a = tf.tensor1d([2, 3]);
* const b = tf.tensor1d([-2, -3]);
* const [da, db] = g([a, b]);
* console.log('da');
* da.print();
* console.log('db');
* db.print();
* ```
*
* @param f The function `f(x1, x2,...)` to compute gradients for.
*
* @doc {heading: 'Training', subheading: 'Gradients'}
*/
function grads(f) {
assert(isFunction(f), function () {
return 'The f passed in grads(f) must be a function';
});
return function (args, dy) {
assert(Array.isArray(args), function () {
return 'The args passed in grads(f)(args) must be an array ' + 'of `Tensor`s or `TensorLike`s';
}); // args can be of any dtype, thus null as the last argument.
var $args = convertToTensorArray(args, 'args', 'tf.grads', 'string_or_numeric');
var $dy = dy != null ? convertToTensor(dy, 'dy', 'tf.grads') : null;
return ENGINE.tidy(function () {
var _ENGINE$gradients2 = ENGINE.gradients(function () {
return f.apply(void 0, $args);
}, $args, $dy),
value = _ENGINE$gradients2.value,
grads = _ENGINE$gradients2.grads;
if ($dy != null) {
assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' + 'match the shape returned by f([x1,...])');
}
checkGrads(grads);
return grads;
});
};
}
/**
* Like `tf.grad`, but also returns the value of `f()`. Useful when `f()`
* returns a metric you want to show.
*
* The result is a rich object with the following properties:
* - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`).
* - value: The value returned by `f(x)`.
*
* ```js
* // f(x) = x ^ 2
* const f = x => x.square();
* // f'(x) = 2x
* const g = tf.valueAndGrad(f);
*
* const x = tf.tensor1d([2, 3]);
* const {value, grad} = g(x);
*
* console.log('value');
* value.print();
* console.log('grad');
* grad.print();
* ```
*
* @doc {heading: 'Training', subheading: 'Gradients'}
*/
function valueAndGrad(f) {
assert(isFunction(f), function () {
return 'The f passed in valueAndGrad(f) must be a function';
});
return function (x, dy) {
assert(x instanceof Tensor, function () {
return 'The x passed in valueAndGrad(f)(x) must be a tensor';
});
assert(dy == null || dy instanceof Tensor, function () {
return 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor';
});
var _ENGINE$gradients3 = ENGINE.gradients(function () {
return f(x);
}, [x], dy),
grads = _ENGINE$gradients3.grads,
value = _ENGINE$gradients3.value;
checkGrads(grads);
return {
grad: grads[0],
value: value
};
};
}
/**
* Like `tf.grads`, but returns also the value of `f()`. Useful when `f()`
* returns a metric you want to show.
*
* The result is a rich object with the following properties:
* - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`).
* - value: The value returned by `f(x)`.
*
* ```js
* // f(a, b) = a * b
* const f = (a, b) => a.mul(b);
* // df/da = b, df/db = a
* const g = tf.valueAndGrads(f);
*
* const a = tf.tensor1d([2, 3]);
* const b = tf.tensor1d([-2, -3]);
* const {value, grads} = g([a, b]);
*
* const [da, db] = grads;
*
* console.log('value');
* value.print();
*
* console.log('da');
* da.print();
* console.log('db');
* db.print();
* ```
*
* @doc {heading: 'Training', subheading: 'Gradients'}
*/
function valueAndGrads(f) {
assert(isFunction(f), function () {
return 'The f passed in valueAndGrads(f) must be a function';
});
return function (args, dy) {
assert(Array.isArray(args) && args.every(function (arg) {
return arg instanceof Tensor;
}), function () {
return 'The args passed in valueAndGrads(f)(args) must be array of ' + 'tensors';
});
assert(dy == null || dy instanceof Tensor, function () {
return 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor';
});
var res = ENGINE.gradients(function () {
return f.apply(void 0, args);
}, args, dy);
if (dy != null) {
assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' + 'match the shape returned by f([x1,...])');
}
checkGrads(res.grads);
return res;
};
}
/**
* Computes and returns the gradient of f(x) with respect to the list of
* trainable variables provided by `varList`. If no list is provided, it
* defaults to all trainable variables.
*
* ```js
* const a = tf.variable(tf.tensor1d([3, 4]));
* const b = tf.variable(tf.tensor1d([5, 6]));
* const x = tf.tensor1d([1, 2]);
*
* // f(a, b) = a * x ^ 2 + b * x
* const f = () => a.mul(x.square()).add(b.mul(x)).sum();
* // df/da = x ^ 2, df/db = x
* const {value, grads} = tf.variableGrads(f);
*
* Object.keys(grads).forEach(varName => grads[varName].print());
* ```
*
* @param f The function to execute. f() should return a scalar.
* @param varList The list of variables to compute the gradients with respect
* to. Defaults to all trainable variables.
* @returns An object with the following keys and values:
* - `value`: The value of the function `f`.
* - `grads`: A map from the names of the variables to the gradients.
* If the `varList` argument is provided explicitly and contains a subset of
* non-trainable variables, this map in the return value will contain keys
* that map the names of the non-trainable variables to `null`.
*
* @doc {heading: 'Training', subheading: 'Gradients'}
*/
function variableGrads(f, varList) {
assert(isFunction(f), function () {
return 'The f passed in variableGrads(f) must be a function';
});
assert(varList == null || Array.isArray(varList) && varList.every(function (v) {
return v instanceof Variable;
}), function () {
return 'The varList passed in variableGrads(f, varList) must be an array ' + 'of variables';
});
var specifiedVarList = varList != null;
if (!specifiedVarList) {
// Get all of the trainable variables.
varList = [];
for (var varName in ENGINE.registeredVariables) {
varList.push(ENGINE.registeredVariables[varName]);
}
}
var specifiedNonTrainable = specifiedVarList ? varList.filter(function (variable) {
return !variable.trainable;
}) : null; // Prune non-trainable variables.
var originalVarCount = varList.length;
varList = varList.filter(function (variable) {
return variable.trainable;
});
assert(varList.length > 0, function () {
return "variableGrads() expects at least one of the input variables to " + ("be trainable, but none of the " + originalVarCount + " variables is ") + "trainable.";
});
var allowNoGradients = true;
var _ENGINE$gradients4 = ENGINE.gradients(f, varList, null, allowNoGradients),
value = _ENGINE$gradients4.value,
grads = _ENGINE$gradients4.grads;
assert(grads.some(function (g) {
return g != null;
}), function () {
return 'Cannot find a connection between any variable and the result of ' + 'the loss function y=f(x). Please make sure the operations that ' + 'use variables are inside the function f passed to minimize().';
});
assert(value.rank === 0, function () {
return "The f passed in variableGrads(f) must return a scalar, but it " + ("returned a rank-" + value.rank + " tensor");
});
var namedGrads = {};
varList.forEach(function (v, i) {
if (grads[i] != null) {
namedGrads[v.name] = grads[i];
}
});
if (specifiedNonTrainable != null) {
// If varList is explicitly provided and contains non-trainable values,
// add them to the returned gradients with `null` values.
specifiedNonTrainable.forEach(function (v) {
return namedGrads[v.name] = null;
});
}
return {
value: value,
grads: namedGrads
};
}
/**
* Overrides the gradient computation of a function `f`.
*
* Takes a function
* `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}`
* and returns another function `g(...inputs)` which takes the same inputs as
* `f`. When called, `g` returns `f().value`. In backward mode, custom gradients
* with respect to each input of `f` are computed using `f().gradFunc`.
*
 * The `save` function passed to `f` should be used for saving tensors needed
 * in the gradient. The `saved` argument passed to `gradFunc` is a
 * `NamedTensorMap` containing those saved tensors.
*
* ```js
* const customOp = tf.customGrad((x, save) => {
* // Save x to make sure it's available later for the gradient.
* save([x]);
* // Override gradient of our custom x ^ 2 op to be dy * abs(x);
* return {
* value: x.square(),
* // Note `saved.x` which points to the `x` we saved earlier.
* gradFunc: (dy, saved) => [dy.mul(saved[0].abs())]
* };
* });
*
* const x = tf.tensor1d([-1, -2, 3]);
* const dx = tf.grad(x => customOp(x));
*
* console.log(`f(x):`);
* customOp(x).print();
* console.log(`f'(x):`);
* dx(x).print();
* ```
*
* @param f The function to evaluate in forward mode, which should return
* `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc`
* returns the custom gradients of `f` with respect to its inputs.
*
* @doc {heading: 'Training', subheading: 'Gradients'}
*/
function customGrad(f) {
return ENGINE.customGrad(f);
}
function checkGrads(grads) {
var numNullGradients = grads.filter(function (g) {
return g == null;
}).length;
if (numNullGradients > 0) {
throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y.");
}
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes `-1 * x` element-wise.
*
* ```js
* const x = tf.tensor2d([1, 2, -2, 0], [2, 2]);
*
* x.neg().print(); // or tf.neg(x)
* ```
*
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function neg_(x) {
var $x = convertToTensor(x, 'x', 'neg');
var inputs = {
x: $x
};
return ENGINE.runKernel(Neg, inputs);
}
var neg = op({
neg_: neg_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)`
*
* ```js
* const x = tf.tensor1d([0, 1, -1, .7]);
*
* x.softplus().print(); // or tf.softplus(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function softplus_(x) {
var $x = convertToTensor(x, 'x', 'softplus');
var inputs = {
x: $x
};
return ENGINE.runKernel(Softplus, inputs);
}
var softplus = op({
softplus_: softplus_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes log sigmoid of the input `tf.Tensor` element-wise:
* `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`.
*
* ```js
* const x = tf.tensor1d([0, 1, -1, .7]);
*
* x.logSigmoid().print(); // or tf.logSigmoid(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function logSigmoid_(x) {
var $x = convertToTensor(x, 'x', 'logSigmoid'); // Use a custom gradient to maintain previous implementation.
// There is no LogSigmoid kernel in TF so we can't use engine.runKernel
// directly
var customOp = customGrad(function (x) {
// TODO(yassogba) we can remove the chained softplus call here only
    // after backends have modularized softplus, at which point we can call
    // engine runKernel(..., Softplus, ...) directly.
var value = neg(softplus(neg(x)));
var gradFunc = function gradFunc(dy) {
var derX = mul(dy, sigmoid(neg(x)));
return derX;
};
return {
value: value,
gradFunc: gradFunc
};
});
return customOp($x);
}
var logSigmoid = op({
logSigmoid_: logSigmoid_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the maximum of elements across dimensions of a `tf.Tensor`.
*
* Reduces the input along the dimensions given in `axes`. Unless `keepDims`
* is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
* `axes`. If `keepDims` is true, the reduced dimensions are retained with
 * length 1. If `axes` has no entries, all dimensions are reduced, and a
* `tf.Tensor` with a single element is returned.
*
* ```js
* const x = tf.tensor1d([1, 2, 3]);
*
* x.max().print(); // or tf.max(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* const axis = 1;
* x.max(axis).print(); // or tf.max(x, axis)
* ```
*
* @param x The input tensor.
* @param axis The dimension(s) to reduce. By default it reduces
* all dimensions.
* @param keepDims If true, retains reduced dimensions with size 1.
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function max_(x, axis, keepDims) {
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
var $x = convertToTensor(x, 'x', 'max');
var inputs = {
x: $x
};
var attrs = {
reductionIndices: axis,
keepDims: keepDims
};
return ENGINE.runKernel(Max, inputs, attrs);
}
var max$5 = op({
max_: max_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([10, 20, 30, 40]);
* const b = tf.tensor1d([1, 2, 3, 4]);
*
* a.sub(b).print(); // or tf.sub(a, b)
* ```
*
* ```js
* // Broadcast subtract a with b.
* const a = tf.tensor1d([10, 20, 30, 40]);
* const b = tf.scalar(5);
*
* a.sub(b).print(); // or tf.sub(a, b)
* ```
* @param a The first `tf.Tensor` to subtract from.
* @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as
* `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function sub_(a, b) {
var $a = convertToTensor(a, 'a', 'sub');
var $b = convertToTensor(b, 'b', 'sub');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Sub, inputs);
}
var sub = op({
sub_: sub_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the sum of elements across dimensions of a `tf.Tensor`.
*
* Reduces the input along the dimensions given in `axes`. Unless `keepDims`
* is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
* `axes`. If `keepDims` is true, the reduced dimensions are retained with
* length 1. If axes has no entries, all dimensions are reduced, and a
* `tf.Tensor` with a single element is returned.
*
* ```js
* const x = tf.tensor1d([1, 2, 3]);
*
* x.sum().print(); // or tf.sum(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* const axis = 1;
* x.sum(axis).print(); // or tf.sum(x, axis)
* ```
*
* @param x The input tensor to compute the sum over. If the dtype is `bool`
* it will be converted to `int32` and the output dtype will be `int32`.
* @param axis The dimension(s) to reduce. By default it reduces
* all dimensions.
* @param keepDims If true, retains reduced dimensions with size 1.
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function sum_(x, axis, keepDims) {
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
var $x = convertToTensor(x, 'x', 'sum');
if ($x.dtype === 'bool') {
$x = cast($x, 'int32');
}
var inputs = {
x: $x
};
var attrs = {
axis: axis,
keepDims: keepDims
};
return ENGINE.runKernel(Sum, inputs, attrs);
}
var sum$1 = op({
sum_: sum_
});
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the log softmax.
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
*
* a.logSoftmax().print(); // or tf.logSoftmax(a)
* ```
*
* ```js
* const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
*
* a.logSoftmax().print(); // or tf.logSoftmax(a)
* ```
*
* @param logits The logits array.
* @param axis The dimension softmax would be performed on. Defaults to `-1`
* which indicates the last dimension.
*
* @doc {heading: 'Operations', subheading: 'Normalization'}
*/
function logSoftmax_(logits, axis) {
if (axis === void 0) {
axis = -1;
}
var $logits = convertToTensor(logits, 'logits', 'logSoftmax');
if (axis === -1) {
axis = $logits.rank - 1;
}
if (axis !== $logits.rank - 1) {
throw Error('Log Softmax along a non-last dimension is not yet supported. ' + ("Logits was rank " + $logits.rank + " and axis was " + axis));
} // const forward: ForwardFunc<Tensor> = (backend, save) => {
// const keepDims = true;
// const xMax = max(logits, axis, true);
// const shifted = sub(logits, xMax);
// const value =
// sub(cast(shifted, 'float32'), log(sum(exp(shifted), axis,
// keepDims)));
// save([value]);
// return value;
// };
// Use a custom gradient for numerical stability.
var customOp = customGrad(function (logits, save) {
var keepDims = true;
var xMax = max$5(logits, axis, true);
var shifted = sub(logits, xMax);
var value = sub(cast(shifted, 'float32'), log$a(sum$1(exp$3(shifted), axis, keepDims)));
save([value]);
var gradFunc = function gradFunc(dy, saved) {
var value = saved[0];
var keepDims = true;
var softmax = exp$3(value);
return sub(dy, mul(sum$1(dy, axis, keepDims), softmax));
};
return {
value: value,
gradFunc: gradFunc
};
});
return customOp($logits); // TODO Use Engine.runKernel when CPU/WebGL/WASM backends implement this.
// const inputs: LogSoftmaxInputs = {logits: $logits};
// const attrs: LogSoftmaxAttrs = {axis};
// return ENGINE.runKernel(
// LogSoftmax, inputs as {} as NamedTensorMap,
// attrs as {} as NamedAttrMap);
}
var logSoftmax = op({
logSoftmax_: logSoftmax_
});
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns true if the axis specifies the inner most dimensions of the
* array.
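 *
 * For example, axesAreInnerMostDims([1, 2], 3) is true, while
 * axesAreInnerMostDims([0], 3) is false.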
*/
function axesAreInnerMostDims(axes, rank) {
for (var i = 0; i < axes.length; ++i) {
if (axes[axes.length - i - 1] !== rank - 1 - i) {
return false;
}
}
return true;
}
function combineLocations(outputLoc, reduceLoc, axes) {
var rank = outputLoc.length + reduceLoc.length;
var loc = [];
var outIdx = 0;
var reduceIdx = 0;
for (var dim = 0; dim < rank; dim++) {
if (axes.indexOf(dim) === -1) {
loc.push(outputLoc[outIdx++]);
} else {
loc.push(reduceLoc[reduceIdx++]);
}
}
return loc;
}
function computeOutAndReduceShapes(aShape, axes) {
var outShape = [];
var rank = aShape.length;
for (var dim = 0; dim < rank; dim++) {
if (axes.indexOf(dim) === -1) {
outShape.push(aShape[dim]);
}
}
var reduceShape = axes.map(function (dim) {
return aShape[dim];
});
return [outShape, reduceShape];
}
function expandShapeToKeepDim(shape, axes) {
var reduceSubShape = axes.map(function (x) {
return 1;
});
return combineLocations(shape, reduceSubShape, axes);
}
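// For example, computeOutAndReduceShapes([2, 3, 4], [1]) returns
// [[2, 4], [3]], and expandShapeToKeepDim([2, 4], [1]) returns [2, 1, 4].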
function assertAxesAreInnerMostDims(msg, axes, rank) {
assert(axesAreInnerMostDims(axes, rank), function () {
return msg + " supports only inner-most axes for now. " + ("Got axes " + axes + " and rank-" + rank + " input.");
});
}
/**
* Returns the axes permutation to be used with `tf.transpose`, if such
* permutation is necessary. Otherwise it returns null. This method is used by
* operations that operate only on inner-most axes.
*/
function getAxesPermutation(axes, rank) {
if (axesAreInnerMostDims(axes, rank)) {
return null;
}
var result = [];
for (var i = 0; i < rank; ++i) {
if (axes.indexOf(i) === -1) {
result.push(i);
}
}
axes.forEach(function (axis) {
return result.push(axis);
});
return result;
}
/** Returns the axes permutation that undoes the original permutation. */
function getUndoAxesPermutation(axes) {
return axes.map(function (axis, i) {
return [i, axis];
}).sort(function (a, b) {
return a[1] - b[1];
}).map(function (x) {
return x[0];
});
}
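// For example, reducing axis 0 of a rank-3 tensor gives
// getAxesPermutation([0], 3) === [1, 2, 0] (the reduced axis moves to the
// inner-most position), and getUndoAxesPermutation([1, 2, 0]) === [2, 0, 1],
// which restores the original axis order afterwards.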
function getInnerMostAxes(numAxes, rank) {
var res = [];
for (var i = rank - numAxes; i < rank; ++i) {
res.push(i);
}
return res;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Computes the log(sum(exp(elements across the reduction dimensions))).
*
* Reduces the input along the dimensions given in `axis`. Unless `keepDims`
* is true, the rank of the array is reduced by 1 for each entry in `axis`.
* If `keepDims` is true, the reduced dimensions are retained with length 1.
* If `axis` has no entries, all dimensions are reduced, and an array with a
* single element is returned.
*
* ```js
* const x = tf.tensor1d([1, 2, 3]);
*
* x.logSumExp().print(); // or tf.logSumExp(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* const axis = 1;
 * x.logSumExp(axis).print(); // or tf.logSumExp(x, axis)
* ```
* @param x The input tensor.
* @param axis The dimension(s) to reduce. If null (the default),
* reduces all dimensions.
* @param keepDims If true, retains reduced dimensions with length
* of 1. Defaults to false.
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function logSumExp_(x, axis, keepDims) {
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
var $x = convertToTensor(x, 'x', 'logSumExp');
var axes = parseAxisParam(axis, $x.shape);
var xMax = max$5($x, axes, true
/* keepDims */
);
var a = sub($x, xMax);
var b = exp$3(a);
var c = sum$1(b, axes);
var d = log$a(c);
var res = add$1(reshape(xMax, d.shape), d);
if (keepDims) {
var newShape = expandShapeToKeepDim(res.shape, axes);
return reshape(res, newShape);
}
return res;
}
var logSumExp = op({
logSumExp_: logSumExp_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of `a AND b` element-wise. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([false, false, true, true], 'bool');
* const b = tf.tensor1d([false, true, false, true], 'bool');
*
* a.logicalAnd(b).print();
* ```
*
* @param a The first input tensor. Must be of dtype bool.
* @param b The second input tensor. Must be of dtype bool.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function logicalAnd_(a, b) {
var $a = convertToTensor(a, 'a', 'logicalAnd', 'bool');
var $b = convertToTensor(b, 'b', 'logicalAnd', 'bool');
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(LogicalAnd, inputs);
}
var logicalAnd = op({
logicalAnd_: logicalAnd_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of `NOT x` element-wise.
*
* ```js
* const a = tf.tensor1d([false, true], 'bool');
*
* a.logicalNot().print();
* ```
*
* @param x The input tensor. Must be of dtype 'bool'.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function logicalNot_(x) {
var $x = convertToTensor(x, 'x', 'logicalNot', 'bool');
var inputs = {
x: $x
};
return ENGINE.runKernel(LogicalNot, inputs);
}
var logicalNot = op({
logicalNot_: logicalNot_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of `a OR b` element-wise. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([false, false, true, true], 'bool');
* const b = tf.tensor1d([false, true, false, true], 'bool');
*
* a.logicalOr(b).print();
* ```
* @param a The first input tensor. Must be of dtype bool.
* @param b The second input tensor. Must be of dtype bool.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function logicalOr_(a, b) {
var $a = convertToTensor(a, 'a', 'logicalOr', 'bool');
var $b = convertToTensor(b, 'b', 'logicalOr', 'bool');
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(LogicalOr, inputs);
}
var logicalOr = op({
logicalOr_: logicalOr_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of `a XOR b` element-wise. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([false, false, true, true], 'bool');
* const b = tf.tensor1d([false, true, false, true], 'bool');
*
* a.logicalXor(b).print();
* ```
*
* @param a The first input tensor. Must be of dtype bool.
* @param b The second input tensor. Must be of dtype bool.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function logicalXor_(a, b) {
var $a = convertToTensor(a, 'a', 'logicalXor', 'bool');
var $b = convertToTensor(b, 'b', 'logicalXor', 'bool');
assertAndGetBroadcastShape($a.shape, $b.shape); // x ^ y = (x | y) & ~(x & y)
return logicalAnd(logicalOr(a, b), logicalNot(logicalAnd(a, b)));
}
var logicalXor = op({
logicalXor_: logicalXor_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the 2D max pooling of an image.
*
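 * A minimal usage sketch (the 3x3, single-channel input is illustrative):
 *
 * ```js
 * const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3, 1]);
 *
 * tf.maxPool(x, 2, 1, 'valid').print();
 * ```
 *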
* @param x The input tensor, of rank 4 or rank 3 of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
* @param filterSize The filter size: `[filterHeight, filterWidth]`. If
* `filterSize` is a single number, then `filterHeight == filterWidth`.
* @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
* `strides` is a single number, then `strideHeight == strideWidth`.
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
* in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single
* number, then `dilationHeight == dilationWidth`. If it is greater than
* 1, then all values of `strides` must be 1.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*/
function maxPool_(x, filterSize, strides, pad, dimRoundingMode) {
var $x = convertToTensor(x, 'x', 'maxPool');
var dilations = 1;
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(x4D.rank === 4, function () {
return "Error in maxPool: input must be rank 4 but got rank " + x4D.rank + ".";
});
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in maxPool: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in maxPool: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
x: x4D
};
var attrs = {
filterSize: filterSize,
strides: strides,
pad: pad,
dimRoundingMode: dimRoundingMode
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(MaxPool, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var maxPool = op({
maxPool_: maxPool_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the 3D max pooling.
*
* ```js
* const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]);
* const result = tf.maxPool3d(x, 2, 1, 'valid');
* result.print();
* ```
*
* @param x The input tensor, of rank 5 or rank 4 of shape
* `[batch, depth, height, width, inChannels]`.
* @param filterSize The filter size:
* `[filterDepth, filterHeight, filterWidth]`.
* If `filterSize` is a single number,
* then `filterDepth == filterHeight == filterWidth`.
* @param strides The strides of the pooling:
* `[strideDepth, strideHeight, strideWidth]`.
* If `strides` is a single number,
* then `strideDepth == strideHeight == strideWidth`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
 * than 1x1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
* @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to
* "NDHWC". Specify the data format of the input and output data. With the
* default format "NDHWC", the data is stored in the order of: [batch,
* depth, height, width, channels]. Only "NDHWC" is currently supported.
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function maxPool3d_(x, filterSize, strides, pad, dimRoundingMode, dataFormat) {
if (filterSize === void 0) {
filterSize = [1, 1, 1];
}
if (dataFormat === void 0) {
dataFormat = 'NDHWC';
}
var $x = convertToTensor(x, 'x', 'maxPool3d');
var x5D = $x;
var reshapedTo5D = false;
if ($x.rank === 4) {
reshapedTo5D = true;
x5D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]]);
}
assert(x5D.rank === 5, function () {
return "Error in maxPool3d: x must be rank 5 but got rank " + x5D.rank + ".";
});
assert(dataFormat === 'NDHWC', function () {
return "Error in maxPool3d: Only NDHWC is currently supported, " + ("but got dataFormat of " + dataFormat);
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in maxPool3d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
x: x5D
};
var attrs = {
filterSize: filterSize,
strides: strides,
pad: pad,
dimRoundingMode: dimRoundingMode,
dataFormat: dataFormat
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(MaxPool3D, inputs, attrs);
if (reshapedTo5D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
return res;
}
var maxPool3d = op({
maxPool3d_: maxPool3d_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the 2D max pooling of an image with Argmax index.
 * The indices in argmax are flattened, so that a maximum value at position
 * `[b, y, x, c]` becomes flattened index: `(y * width + x) * channels + c` if
 * `includeBatchInIndex` is false; `((b * height + y) * width + x) * channels
 * + c` if `includeBatchInIndex` is true.
*
* The indices returned are always in `[0, height) x [0, width)` before
* flattening.
*
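 * A minimal usage sketch (the 2x2 input and 'same' padding are illustrative):
 *
 * ```js
 * const x = tf.tensor4d([7, 2, 5, 8], [1, 2, 2, 1]);
 *
 * const {result, indexes} = tf.maxPoolWithArgmax(x, 2, 2, 'same');
 * result.print();  // the maximum of each 2x2 window
 * indexes.print(); // the flattened index of each maximum
 * ```
 *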
* @param x The input tensor, of rank 4 or rank 3 of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
* @param filterSize The filter size: `[filterHeight, filterWidth]`. If
* `filterSize` is a single number, then `filterHeight == filterWidth`.
* @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
* `strides` is a single number, then `strideHeight == strideWidth`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
 * @param includeBatchInIndex Defaults to false. Whether to include the batch
 * dimension in the flattened index of `argmax`.
*
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function maxPoolWithArgmax_(x, filterSize, strides, pad, includeBatchInIndex) {
if (includeBatchInIndex === void 0) {
includeBatchInIndex = false;
}
var $x = convertToTensor(x, 'x', 'maxPoolWithArgmax');
var inputs = {
x: $x
};
var attrs = {
filterSize: filterSize,
strides: strides,
pad: pad,
includeBatchInIndex: includeBatchInIndex
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var result = ENGINE.runKernel(MaxPoolWithArgmax, inputs, attrs);
return {
result: result[0],
indexes: result[1]
};
}
var maxPoolWithArgmax = op({
maxPoolWithArgmax_: maxPoolWithArgmax_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the max of a and b (`a > b ? a : b`) element-wise.
* Supports broadcasting.
*
* We also expose `tf.maximumStrict` which has the same signature as this op and
* asserts that `a` and `b` are the same shape (does not broadcast).
*
* ```js
* const a = tf.tensor1d([1, 4, 3, 16]);
* const b = tf.tensor1d([1, 2, 9, 4]);
*
* a.maximum(b).print(); // or tf.maximum(a, b)
* ```
*
* ```js
* // Broadcast maximum a with b.
* const a = tf.tensor1d([2, 4, 6, 8]);
* const b = tf.scalar(5);
*
* a.maximum(b).print(); // or tf.maximum(a, b)
* ```
*
* @param a The first tensor.
* @param b The second tensor. Must have the same type as `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function maximum_(a, b) {
var $a = convertToTensor(a, 'a', 'maximum');
var $b = convertToTensor(b, 'b', 'maximum');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
if ($a.dtype === 'bool') {
$a = cast($a, 'int32');
$b = cast($b, 'int32');
}
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Maximum, inputs);
}
var maximum = op({
maximum_: maximum_
});
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the mean of elements across dimensions of a `tf.Tensor`.
*
* Reduces `x` along the dimensions given in `axis`. Unless `keepDims` is
* true, the rank of the `tf.Tensor` is reduced by 1 for each entry in `axis`.
* If `keepDims` is true, the reduced dimensions are retained with length 1.
* If `axis` has no entries, all dimensions are reduced, and a `tf.Tensor` with
* a single element is returned.
*
* ```js
* const x = tf.tensor1d([1, 2, 3]);
*
* x.mean().print(); // or tf.mean(a)
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* const axis = 1;
* x.mean(axis).print(); // or tf.mean(x, axis)
* ```
*
* @param x The input tensor.
* @param axis The dimension(s) to reduce. By default it reduces
* all dimensions.
* @param keepDims If true, retains reduced dimensions with size 1.
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function mean_(x, axis, keepDims) {
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
var $x = convertToTensor(x, 'x', 'mean');
var inputs = {
x: $x
};
var attrs = {
axis: axis,
keepDims: keepDims
};
return ENGINE.runKernel(Mean, inputs, attrs);
}
var mean = op({
mean_: mean_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with all elements set to 0.
*
* ```js
* tf.zeros([2, 2]).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param dtype The type of an element in the resulting tensor. Can
 * be 'float32', 'int32' or 'bool'. Defaults to 'float32'.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function zeros(shape, dtype) {
if (dtype === void 0) {
dtype = 'float32';
}
if (dtype === 'complex64') {
var real = zeros(shape, 'float32');
var imag = zeros(shape, 'float32');
return complex(real, imag);
}
var values = makeZerosTypedArray(sizeFromShape(shape), dtype);
return ENGINE.makeTensor(values, shape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with all elements set to 1.
*
* ```js
* tf.ones([2, 2]).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param dtype The type of an element in the resulting tensor. Defaults to
 * 'float32'.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function ones$1(shape, dtype) {
if (dtype === void 0) {
dtype = 'float32';
}
if (dtype === 'complex64') {
var real = ones$1(shape, 'float32');
var imag = zeros(shape, 'float32');
return complex(real, imag);
}
var values = makeOnesTypedArray(sizeFromShape(shape), dtype);
return ENGINE.makeTensor(values, shape, dtype);
}
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Broadcasts parameters for evaluation on an N-D grid.
*
* Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
* of N-D coordinate arrays for evaluating expressions on an N-D grid.
*
* Notes:
* `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
* When the `indexing` argument is set to 'xy' (the default), the broadcasting
* instructions for the first two dimensions are swapped.
* Examples:
* Calling `const [X, Y] = meshgrid(x, y)` with the tensors
*
* ```javascript
* const x = [1, 2, 3];
* const y = [4, 5, 6];
* const [X, Y] = tf.meshgrid(x, y);
* // X = [[1, 2, 3],
* // [1, 2, 3],
* // [1, 2, 3]]
* // Y = [[4, 4, 4],
* // [5, 5, 5],
* // [6, 6, 6]]
* ```
*
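 * A sketch of the same call with matrix ('ij') indexing (the commented
 * outputs follow the 'ij' branch of the implementation):
 *
 * ```js
 * const [Xi, Yi] = tf.meshgrid([1, 2, 3], [4, 5, 6], {indexing: 'ij'});
 * // Xi = [[1, 1, 1],
 * //       [2, 2, 2],
 * //       [3, 3, 3]]
 * // Yi = [[4, 5, 6],
 * //       [4, 5, 6],
 * //       [4, 5, 6]]
 * ```
 *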
 * @param x Tensor with rank >= 1.
 * @param y Tensor with rank >= 1.
 * @param indexing Indexing convention: 'xy' (cartesian, the default) or 'ij' (matrix).
*
* @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
*/
function meshgrid(x, y, _temp) {
var _ref = _temp === void 0 ? {} : _temp,
_ref$indexing = _ref.indexing,
indexing = _ref$indexing === void 0 ? 'xy' : _ref$indexing;
if (indexing !== 'xy' && indexing !== 'ij') {
throw new TypeError(indexing + " is not a valid third argument to meshgrid");
}
if (x === undefined) {
return [];
}
var $x = convertToTensor(x, 'x', 'meshgrid', x instanceof Tensor ? x.dtype : 'float32');
if (y === undefined) {
return [$x];
}
var $y = convertToTensor(y, 'y', 'meshgrid', y instanceof Tensor ? y.dtype : 'float32');
var w = sizeFromShape($x.shape);
var h = sizeFromShape($y.shape);
if (indexing === 'xy') {
$x = reshape($x, [1, -1]);
$y = reshape($y, [-1, 1]);
return [matMul(ones$1([h, 1], $x.dtype), $x), matMul($y, ones$1([1, w], $y.dtype))];
}
$x = reshape($x, [-1, 1]);
$y = reshape($y, [1, -1]);
return [matMul($x, ones$1([1, h], $x.dtype)), matMul(ones$1([w, 1], $y.dtype), $y)];
}
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the minimum value from the input.
*
* Reduces the input along the dimensions given in `axes`. Unless `keepDims`
* is true, the rank of the array is reduced by 1 for each entry in `axes`.
* If `keepDims` is true, the reduced dimensions are retained with length 1.
* If `axes` has no entries, all dimensions are reduced, and an array with a
* single element is returned.
*
* ```js
* const x = tf.tensor1d([1, 2, 3]);
*
* x.min().print(); // or tf.min(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* const axis = 1;
* x.min(axis).print(); // or tf.min(x, axis)
* ```
*
* @param x The input Tensor.
* @param axis The dimension(s) to reduce. By default it reduces
* all dimensions.
* @param keepDims If true, retains reduced dimensions with size 1.
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function min_(x, axis, keepDims) {
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
var $x = convertToTensor(x, 'x', 'min');
var inputs = {
x: $x
};
var attrs = {
axis: axis,
keepDims: keepDims
}; // tslint:disable-next-line: no-unnecessary-type-assertion
return ENGINE.runKernel(Min, inputs, attrs);
}
var min$9 = op({
min_: min_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the min of a and b (`a < b ? a : b`) element-wise.
* Supports broadcasting.
*
* We also expose `minimumStrict` which has the same signature as this op and
* asserts that `a` and `b` are the same shape (does not broadcast).
*
* ```js
* const a = tf.tensor1d([1, 4, 3, 16]);
* const b = tf.tensor1d([1, 2, 9, 4]);
*
* a.minimum(b).print(); // or tf.minimum(a, b)
* ```
*
* ```js
* // Broadcast minimum a with b.
* const a = tf.tensor1d([2, 4, 6, 8]);
* const b = tf.scalar(5);
*
* a.minimum(b).print(); // or tf.minimum(a, b)
* ```
*
* @param a The first tensor.
* @param b The second tensor. Must have the same type as `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function minimum_(a, b) {
var $a = convertToTensor(a, 'a', 'minimum');
var $b = convertToTensor(b, 'b', 'minimum');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
if ($a.dtype === 'bool') {
$a = cast($a, 'int32');
$b = cast($b, 'int32');
}
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Minimum, inputs);
}
var minimum = op({
minimum_: minimum_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Pads a `tf.Tensor` using mirror padding.
*
* This operation implements the `REFLECT` and `SYMMETRIC` modes of pad.
*
* ```js
* const x = tf.range(0, 9).reshape([1, 1, 3, 3]);
* x.mirrorPad([[0, 0], [0, 0], [2, 2], [2, 2]], 'reflect').print();
* ```
* @param x The tensor to pad.
* @param paddings An array of length `R` (the rank of the tensor), where
* each element is a length-2 tuple of ints `[padBefore, padAfter]`,
* specifying how much to pad along each dimension of the tensor.
* In "reflect" mode, the padded regions do not include the borders,
* while in "symmetric" mode the padded regions do include the borders.
* For example, if the input is `[1, 2, 3]` and paddings is `[0, 2]`,
* then the output is `[1, 2, 3, 2, 1]` in "reflect" mode, and
* `[1, 2, 3, 3, 2]` in "symmetric" mode.
* If `mode` is "reflect" then both `paddings[D, 0]` and `paddings[D, 1]`
* must be no greater than `x.shape[D] - 1`. If mode is "symmetric"
* then both `paddings[D, 0]` and `paddings[D, 1]` must be no greater than
* `x.shape[D]`
* @param mode String to specify padding mode. Can be `'reflect' | 'symmetric'`
*/
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function mirrorPad_(x, paddings, mode) {
assert(mode === 'reflect' || mode === 'symmetric', function () {
return "Invalid mode. Mode must be either reflect or symmetric. " + ("Got " + mode + ".");
});
var $x = convertToTensor(x, 'x', 'mirrorPad');
if ($x.rank === 0) {
throw new Error('mirrorPad(scalar) is not defined. ' + 'Pass non-scalar to mirrorPad');
}
assert(paddings.length === $x.rank, function () {
return "Padding doesn't match input. Must be " + $x.rank + ". " + ("Got " + paddings.length + ".");
});
var shapeOffset = mode === 'reflect' ? 1 : 0;
var _loop = function _loop(i) {
assert(paddings[i].length === 2, function () {
return "Invalid number of paddings. Must be length of 2 each.";
});
assert(paddings[i][0] >= 0 && paddings[i][0] <= $x.shape[i] - shapeOffset && paddings[i][1] >= 0 && paddings[i][1] <= $x.shape[i] - shapeOffset, function () {
return "Padding in dimension " + i + " cannot be greater than or equal " + ("to " + ($x.shape[i] - shapeOffset) + " or less than 0 for input of ") + ("shape " + $x.shape);
});
};
for (var i = 0; i < $x.rank; i++) {
_loop(i);
}
var attrs = {
paddings: paddings,
mode: mode
};
var inputs = {
x: $x
};
return ENGINE.runKernel(MirrorPad, inputs, attrs);
}
var mirrorPad = op({
mirrorPad_: mirrorPad_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the mod of a and b element-wise.
* `floor(x / y) * y + mod(x, y) = x`
* Supports broadcasting.
*
* We also expose `tf.modStrict` which has the same signature as this op and
* asserts that `a` and `b` are the same shape (does not broadcast).
*
* ```js
* const a = tf.tensor1d([1, 4, 3, 16]);
* const b = tf.tensor1d([1, 2, 9, 4]);
*
* a.mod(b).print(); // or tf.mod(a, b)
* ```
*
* ```js
* // Broadcast a mod b.
* const a = tf.tensor1d([2, 4, 6, 8]);
* const b = tf.scalar(5);
*
* a.mod(b).print(); // or tf.mod(a, b)
* ```
*
* @param a The first tensor.
* @param b The second tensor. Must have the same type as `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function mod_(a, b) {
var $a = convertToTensor(a, 'a', 'mod');
var $b = convertToTensor(b, 'b', 'mod');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(Mod, inputs);
}
var mod = op({
mod_: mod_
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes square of `x` element-wise: `x ^ 2`
*
* ```js
* const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]);
*
* x.square().print(); // or tf.square(x)
* ```
* @param x The input Tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function square_(x) {
var $x = convertToTensor(x, 'x', 'square');
var attrs = {};
return ENGINE.runKernel('Square', {
x: $x
}, attrs);
}
var square = op({
square_: square_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Calculates the mean and variance of `x`. The mean and variance are
* calculated by aggregating the contents of `x` across `axes`. If `x` is
* 1-D and `axes = [0]` this is just the mean and variance of a vector.
*
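 * A minimal usage sketch:
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4]);
 *
 * const {mean, variance} = tf.moments(x);
 * mean.print();      // 2.5
 * variance.print();  // 1.25
 * ```
 *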
* @param x The input tensor.
* @param axis The dimension(s) along with to compute mean and
* variance. By default it reduces all dimensions.
* @param keepDims If true, the moments have the same dimensionality as the
* input.
* @return An object with two keys: `mean` and `variance`.
*
* @doc {heading: 'Operations', subheading: 'Normalization'}
*/
function moments_(x, axis, keepDims) {
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
x = convertToTensor(x, 'x', 'moments');
var axes = parseAxisParam(axis, x.shape);
var xMean = mean(x, axes, keepDims);
var keepDimsShape = xMean.shape;
if (!keepDims) {
keepDimsShape = expandShapeToKeepDim(xMean.shape, axes);
}
var devSquared = square(sub(cast(x, 'float32'), reshape(xMean, keepDimsShape)));
var variance = mean(devSquared, axes, keepDims);
return {
mean: xMean,
variance: variance
};
}
var moments = op({
moments_: moments_
});
/**
* Computes the next states and outputs of a stack of LSTMCells.
*
* Each cell output is used as input to the next cell.
*
* Returns `[cellState, cellOutput]`.
*
 * Derived from tf.contrib.rnn.MultiRNNCell.
*
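 * A sketch with two cells built from `tf.basicLSTMCell` (hidden size 2; the
 * random kernels and zero biases stand in for trained weights):
 *
 * ```js
 * const forgetBias = tf.scalar(1.0);
 * const kernel1 = tf.randomNormal([4, 8]);  // [inputSize + hidden, 4 * hidden]
 * const bias1 = tf.zeros([8]);
 * const kernel2 = tf.randomNormal([4, 8]);
 * const bias2 = tf.zeros([8]);
 * const cell1 = (data, c, h) =>
 *     tf.basicLSTMCell(forgetBias, kernel1, bias1, data, c, h);
 * const cell2 = (data, c, h) =>
 *     tf.basicLSTMCell(forgetBias, kernel2, bias2, data, c, h);
 *
 * const data = tf.randomNormal([1, 2]);                // [batch, inputSize]
 * const initC = [tf.zeros([1, 2]), tf.zeros([1, 2])];  // one state per cell
 * const initH = [tf.zeros([1, 2]), tf.zeros([1, 2])];
 * const [newC, newH] = tf.multiRNNCell([cell1, cell2], data, initC, initH);
 * newH[1].print();
 * ```
 *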
* @param lstmCells Array of LSTMCell functions.
* @param data The input to the cell.
* @param c Array of previous cell states.
* @param h Array of previous cell outputs.
*
* @doc {heading: 'Operations', subheading: 'RNN'}
*/
function multiRNNCell_(lstmCells, data, c, h) {
var $data = convertToTensor(data, 'data', 'multiRNNCell');
var $c = convertToTensorArray(c, 'c', 'multiRNNCell');
var $h = convertToTensorArray(h, 'h', 'multiRNNCell');
var input = $data;
var newStates = [];
for (var i = 0; i < lstmCells.length; i++) {
var output = lstmCells[i](input, $c[i], $h[i]);
newStates.push(output[0]);
newStates.push(output[1]);
input = output[1];
}
var newC = [];
var newH = [];
for (var _i = 0; _i < newStates.length; _i += 2) {
newC.push(newStates[_i]);
newH.push(newStates[_i + 1]);
}
return [newC, newH];
}
var multiRNNCell = op({
multiRNNCell_: multiRNNCell_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with values drawn from a multinomial distribution.
*
* ```js
* const probs = tf.tensor([.75, .25]);
* tf.multinomial(probs, 3).print();
* ```
*
* @param logits 1D array with unnormalized log-probabilities, or
* 2D array of shape `[batchSize, numOutcomes]`. See the `normalized`
* parameter.
* @param numSamples Number of samples to draw for each row slice.
* @param seed The seed number.
* @param normalized Whether the provided `logits` are normalized true
* probabilities (sum to 1). Defaults to false.
* @return 1D array of shape `[numSamples]`, or 2D array of shape
* `[batchSize, numSamples]`, depending on the rank of the input.
*
* @doc {heading: 'Tensors', subheading: 'Random'}
*/
function multinomial_(logits, numSamples, seed, normalized) {
if (normalized === void 0) {
normalized = false;
}
var $logits = convertToTensor(logits, 'logits', 'multinomial');
var numOutcomes = $logits.size;
var origRank = $logits.rank;
if (numOutcomes < 2) {
throw new Error("Error in multinomial: you need at least 2 outcomes, but got " + (numOutcomes + "."));
}
if (origRank > 2) {
throw new Error("Rank of probabilities must be 1 or 2, but is " + origRank);
  } // TODO(lina128): Investigate correct seed behavior. The code does not
  // seem to allow setting seed to 0.
seed = seed || Math.random(); // The kernel only accepts (and returns) rank 2 tensors.
var logits2D = origRank === 1 ? reshape($logits, [1, -1]) : $logits;
var inputs = {
logits: logits2D
};
var attrs = {
numSamples: numSamples,
seed: seed,
normalized: normalized
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(Multinomial, inputs, attrs); // tslint:disable-next-line:no-unnecessary-type-assertion
return origRank === 1 ? reshape(res, [res.size]) : res;
}
var multinomial = op({
multinomial_: multinomial_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the truth value of (a != b) element-wise. Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
* const b = tf.tensor1d([0, 2, 3]);
*
* a.notEqual(b).print();
* ```
* @param a The first input tensor.
* @param b The second input tensor. Must have the same dtype as `a`.
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function notEqual_(a, b) {
var $a = convertToTensor(a, 'a', 'notEqual', 'string_or_numeric');
var $b = convertToTensor(b, 'b', 'notEqual', 'string_or_numeric');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
return ENGINE.runKernel(NotEqual, inputs);
}
var notEqual = op({
notEqual_: notEqual_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with all elements set to 1 with the same shape as the
* given tensor.
*
* ```js
* const x = tf.tensor([1, 2]);
* tf.onesLike(x).print();
* ```
* @param x A tensor.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function onesLike_(x) {
var $x = convertToTensor(x, 'x', 'onesLike');
var inputs = {
x: $x
};
return ENGINE.runKernel(OnesLike, inputs);
}
var onesLike = op({
onesLike_: onesLike_
});
/**
* Computes the outer product of two vectors, `v1` and `v2`.
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
* const b = tf.tensor1d([3, 4, 5]);
*
* tf.outerProduct(a, b).print();
* ```
* @param v1 The first vector in the outer product operation.
* @param v2 The second vector in the outer product operation.
*
* @doc {heading: 'Operations', subheading: 'Matrices'}
*/
function outerProduct_(v1, v2) {
var $v1 = convertToTensor(v1, 'v1', 'outerProduct');
var $v2 = convertToTensor(v2, 'v2', 'outerProduct');
assert($v1.rank === 1 && $v2.rank === 1, function () {
return "Error in outerProduct: inputs must be rank 1, but got ranks " + ($v1.rank + " and " + $v2.rank + ".");
});
var v12D = reshape($v1, [-1, 1]);
var v22D = reshape($v2, [1, -1]);
return matMul(v12D, v22D);
}
var outerProduct = op({
outerProduct_: outerProduct_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Pads a `tf.Tensor` with a given value and paddings.
*
* This operation implements `CONSTANT` mode. For `REFLECT` and `SYMMETRIC`,
* refer to `tf.mirrorPad`
*
* Also available are stricter rank-specific methods with the same signature
* as this method that assert that `paddings` is of given length.
* - `tf.pad1d`
* - `tf.pad2d`
* - `tf.pad3d`
* - `tf.pad4d`
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
* x.pad([[1, 2]]).print();
* ```
* @param x The tensor to pad.
* @param paddings An array of length `R` (the rank of the tensor), where
* each element is a length-2 tuple of ints `[padBefore, padAfter]`,
* specifying how much to pad along each dimension of the tensor.
* @param constantValue The pad value to use. Defaults to 0.
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function pad_(x, paddings, constantValue) {
if (constantValue === void 0) {
constantValue = 0;
}
var $x = convertToTensor(x, 'x', 'pad');
if ($x.rank === 0) {
throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
}
var attrs = {
paddings: paddings,
constantValue: constantValue
};
var inputs = {
x: $x
};
return ENGINE.runKernel(PadV2, inputs, attrs);
}
var pad = op({
pad_: pad_
});
/**
* Pads a `tf.Tensor1D` with a given value and paddings. See `pad` for details.
*/
function pad1d_(x, paddings, constantValue) {
if (constantValue === void 0) {
constantValue = 0;
}
assert(paddings.length === 2, function () {
return 'Invalid number of paddings. Must be length of 2.';
});
return pad(x, [paddings], constantValue);
}
var pad1d = op({
pad1d_: pad1d_
});
/**
* Pads a `tf.Tensor2D` with a given value and paddings. See `pad` for details.
*/
function pad2d_(x, paddings, constantValue) {
if (constantValue === void 0) {
constantValue = 0;
}
assert(paddings.length === 2 && paddings[0].length === 2 && paddings[1].length === 2, function () {
return 'Invalid number of paddings. Must be length of 2 each.';
});
return pad(x, paddings, constantValue);
}
var pad2d = op({
pad2d_: pad2d_
});
/**
* Pads a `tf.Tensor3D` with a given value and paddings. See `pad` for details.
*/
function pad3d_(x, paddings, constantValue) {
if (constantValue === void 0) {
constantValue = 0;
}
assert(paddings.length === 3 && paddings[0].length === 2 && paddings[1].length === 2 && paddings[2].length === 2, function () {
return 'Invalid number of paddings. Must be length of 2 each.';
});
return pad(x, paddings, constantValue);
}
var pad3d = op({
pad3d_: pad3d_
});
/**
* Pads a `tf.Tensor4D` with a given value and paddings. See `pad` for details.
*/
function pad4d_(x, paddings, constantValue) {
if (constantValue === void 0) {
constantValue = 0;
}
assert(paddings.length === 4 && paddings[0].length === 2 && paddings[1].length === 2 && paddings[2].length === 2 && paddings[3].length === 2, function () {
return 'Invalid number of paddings. Must be length of 2 each.';
});
return pad(x, paddings, constantValue);
}
var pad4d = op({
pad4d_: pad4d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* This operation divides "spatial" dimensions `[1, ..., M]` of the input into
* a grid of blocks of shape `blockShape`, and interleaves these blocks with
* the "batch" dimension (0) such that in the output, the spatial
* dimensions `[1, ..., M]` correspond to the position within the grid,
* and the batch dimension combines both the position within a spatial block
* and the original batch position. Prior to division into blocks,
* the spatial dimensions of the input are optionally zero padded
* according to `paddings`. See below for a precise description.
*
* ```js
* const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
* const blockShape = [2, 2];
* const paddings = [[0, 0], [0, 0]];
*
* x.spaceToBatchND(blockShape, paddings).print();
* ```
*
* @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape +
* remainingShape`, where spatialShape has `M` dimensions.
* @param blockShape A 1-D array. Must have shape `[M]`, all values must
* be >= 1.
* @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >=
* 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad
* from input dimension `i + 1`, which corresponds to spatial dimension `i`. It
* is required that
* `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0`
*
* This operation is equivalent to the following steps:
*
* 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input
* according to `paddings` to produce `padded` of shape paddedShape.
*
* 2. Reshape `padded` to `reshapedPadded` of shape:
* `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ...,
* paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape`
*
* 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded`
* of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ...,
* paddedShape[M] / blockShape[M-1]] + remainingShape`
*
* 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the
* batch dimension, producing an output tensor of shape:
* `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ...,
* paddedShape[M] / blockShape[M-1]] + remainingShape`
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function spaceToBatchND_(x, blockShape, paddings) {
var $x = convertToTensor(x, 'x', 'spaceToBatchND');
assert($x.rank >= 1 + blockShape.length, function () {
return "input rank " + $x.rank + " should be > than [blockShape] " + blockShape.length;
});
assert(paddings.length === blockShape.length, function () {
return "paddings.shape[0] " + paddings.length + " must be equal to [blockShape] " + blockShape.length;
});
assert($x.shape.reduce(function (a, b, i) {
if (i > 0 && i <= blockShape.length) {
return a && (b + paddings[i - 1][0] + paddings[i - 1][1]) % blockShape[i - 1] === 0;
}
return a;
}, true), function () {
return "input spatial dimensions " + $x.shape.slice(1) + " with paddings " + paddings.toString() + " must be divisible by blockShapes " + blockShape.toString();
});
var inputs = {
x: $x
};
var attrs = {
blockShape: blockShape,
paddings: paddings
};
return ENGINE.runKernel(SpaceToBatchND, inputs, attrs);
}
var spaceToBatchND = op({
spaceToBatchND_: spaceToBatchND_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Performs an N-D pooling operation.
*
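 * A minimal usage sketch (the 3x3, single-channel input is illustrative):
 *
 * ```js
 * const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3, 1]);
 *
 * tf.pool(x, 2, 'max', 'valid').print();
 * ```
 *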
* @param input The input tensor, of rank 4 or rank 3 of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
* @param windowShape The filter size: `[filterHeight, filterWidth]`. If
* `filterSize` is a single number, then `filterHeight == filterWidth`.
* @param poolingType The type of pooling, either 'max' or 'avg'.
* @param pad The type of padding algorithm:
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
* in dilated pooling. Defaults to `[1, 1]`. If `dilationRate` is a single
* number, then `dilationHeight == dilationWidth`. If it is greater than
* 1, then all values of `strides` must be 1.
* @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
* `strides` is a single number, then `strideHeight == strideWidth`.
*
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function pool_(input, windowShape, poolingType, pad, dilations, strides) {
if (dilations == null) {
dilations = [1, 1];
}
if (strides == null) {
strides = 1;
}
if (pad === 0) {
pad = 'valid';
}
var $x = convertToTensor(input, 'x', 'maxPool');
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in pool: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var convInfo = computePool2DInfo(x4D.shape, windowShape, strides, dilations, pad);
var dilation = [convInfo.dilationHeight, convInfo.dilationWidth]; // The following implementation does batchToSpace(pool(spaceToBatch(x)))
// whenever dilation > 1 since the TF kernels do not support dilation > 1.
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L1037
var basePadding;
if (pad === 'same') {
basePadding = withSpaceToBatchBasePaddings([convInfo.filterHeight, convInfo.filterWidth], dilation);
} else {
basePadding = [[0, 0], [0, 0]];
}
var isDilationOne = dilation[0] === 1 && dilation[1] === 1;
var _requiredSpaceToBatch = requiredSpaceToBatchPaddings([convInfo.inHeight, convInfo.inWidth], dilation, basePadding),
adjustedPadding = _requiredSpaceToBatch[0],
adjustedCrops = _requiredSpaceToBatch[1];
var convertedPad = isDilationOne ? pad : 'valid';
var convertedX = isDilationOne ? x4D : spaceToBatchND(x4D, dilation, adjustedPadding);
var forwardOp = poolingType === 'avg' ? function () {
return avgPool(convertedX, windowShape, strides, convertedPad);
} : function () {
return maxPool(convertedX, windowShape, strides, convertedPad);
};
var y = forwardOp();
var res = isDilationOne ? y : batchToSpaceND(y, dilation, adjustedCrops);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
} // Helper function to compute crops and paddings for pool with dilation > 1.
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/array_ops.py#L2184
function requiredSpaceToBatchPaddings(inputShape, blockShape, basePadding) {
var padStart = basePadding.map(function (b) {
return b[0];
});
var origPadEnd = basePadding.map(function (b) {
return b[1];
});
var fullInputShape = inputShape.concat(padStart, origPadEnd);
var padEndExtra = blockShape.map(function (b, i) {
return (b - fullInputShape[i] % b) % b;
});
var padEnd = origPadEnd.map(function (s, i) {
return s + padEndExtra[i];
});
var paddings = blockShape.map(function (_, i) {
return [padStart[i], padEnd[i]];
});
var crops = blockShape.map(function (_, i) {
return [0, padEndExtra[i]];
});
return [paddings, crops];
} // Helper function to compute base paddings for pool with dilation > 1.
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L524
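  // As an illustrative walk-through of the code below: with filterShape [3, 3]
  // and dilation [2, 2], dilatedFilterShape is [5, 5], padExtraShape is [4, 4],
  // and the returned base paddings are [[2, 2], [2, 2]].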
function withSpaceToBatchBasePaddings(filterShape, dilation) {
// Spatial dimensions of the filters and the upsampled filters in which we
// introduce (rate - 1) zeros between consecutive filter values.
var dilatedFilterShape = filterShape.map(function (s, i) {
return s + (s - 1) * (dilation[i] - 1);
});
var padExtraShape = dilatedFilterShape.map(function (s) {
return s - 1;
}); // When padding is odd, we pad more at end, following the same
// convention as conv2d.
var padExtraStart = padExtraShape.map(function (s) {
return Math.floor(s / 2);
});
var padExtraEnd = padExtraShape.map(function (s, i) {
return s - padExtraStart[i];
});
return padExtraShape.map(function (_, i) {
return [padExtraStart[i], padExtraEnd[i]];
});
}
var pool = op({
pool_: pool_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the power of one `tf.Tensor` to another. Supports broadcasting.
*
* Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for
* corresponding elements in x and y. The result's dtype will be the upcasted
* type of the `base` and `exp` dtypes.
*
* ```js
* const a = tf.tensor([[2, 3], [4, 5]])
* const b = tf.tensor([[1, 2], [3, 0]]).toInt();
*
* a.pow(b).print(); // or tf.pow(a, b)
* ```
*
* ```js
* const a = tf.tensor([[1, 2], [3, 4]])
* const b = tf.tensor(2).toInt();
*
* a.pow(b).print(); // or tf.pow(a, b)
* ```
* We also expose `powStrict` which has the same signature as this op and
* asserts that `base` and `exp` are the same shape (does not broadcast).
*
* @param base The base `tf.Tensor` to pow element-wise.
* @param exp The exponent `tf.Tensor` to pow element-wise.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function pow_(base, exp) {
var $base = convertToTensor(base, 'base', 'pow');
var $exp = convertToTensor(exp, 'exp', 'pow');
var _makeTypesMatch = makeTypesMatch($base, $exp);
$base = _makeTypesMatch[0];
$exp = _makeTypesMatch[1];
var inputs = {
a: $base,
b: $exp
};
return ENGINE.runKernel(Pow, inputs);
}
var pow$5 = op({
pow_: pow_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes leaky rectified linear element-wise with parametric alphas.
*
 * `f(x) = x < 0 ? alpha * x : x`
*
* ```js
* const x = tf.tensor1d([-1, 2, -3, 4]);
* const alpha = tf.scalar(0.1);
*
* x.prelu(alpha).print(); // or tf.prelu(x, alpha)
* ```
* @param x The input tensor.
* @param alpha Scaling factor for negative values.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function prelu_(x, alpha) {
var $x = convertToTensor(x, 'x', 'prelu');
var $alpha = convertToTensor(alpha, 'alpha', 'prelu');
var inputs = {
x: $x,
alpha: $alpha
};
return ENGINE.runKernel(Prelu, inputs);
}
var prelu = op({
prelu_: prelu_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the product of elements across dimensions of a `tf.Tensor`.
*
* Reduces the input along the dimensions given in `axes`. Unless `keepDims`
* is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in
* `axes`. If `keepDims` is true, the reduced dimensions are retained with
* length 1. If `axes` has no entries, all dimensions are reduced, and a
* `tf.Tensor` with a single element is returned.
*
* ```js
* const x = tf.tensor1d([1, 2, 3]);
*
* x.prod().print(); // or tf.prod(x)
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* const axis = 1;
* x.prod(axis).print(); // or tf.prod(x, axis)
* ```
*
* @param x The input tensor to compute the product over. If the dtype is `bool`
* it will be converted to `int32` and the output dtype will be `int32`.
* @param axis The dimension(s) to reduce. By default it reduces
* all dimensions.
* @param keepDims If true, retains reduced dimensions with size 1.
*
* @doc {heading: 'Operations', subheading: 'Reduction'}
*/
function prod_(x, axis, keepDims) {
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
var $x = convertToTensor(x, 'x', 'prod');
if ($x.dtype === 'bool') {
// bool is not an allowed type for the underlying kernel.
$x = cast($x, 'int32');
}
var inputs = {
x: $x
};
var attrs = {
axis: axis,
keepDims: keepDims
};
return ENGINE.runKernel(Prod, inputs, attrs);
}
var prod = op({
prod_: prod_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with values sampled from a random number generator
* function defined by the user.
*
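 * A minimal usage sketch (`Math.random` here is just an example generator):
 *
 * ```js
 * tf.rand([2, 2], () => Math.random()).print();
 * ```
 *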
* @param shape An array of integers defining the output tensor shape.
* @param randFunction A random number generator function which is called
* for each element in the output tensor.
* @param dtype The data type of the output tensor. Defaults to 'float32'.
*
* @doc {heading: 'Tensors', subheading: 'Random'}
*/
function rand_(shape, randFunction, dtype) {
var size = sizeFromShape(shape);
var values = null;
if (dtype == null || dtype === 'float32') {
values = new Float32Array(size);
} else if (dtype === 'int32') {
values = new Int32Array(size);
} else if (dtype === 'bool') {
values = new Uint8Array(size);
} else {
throw new Error("Unknown data type " + dtype);
}
for (var i = 0; i < size; i++) {
values[i] = randFunction();
}
return ENGINE.makeTensor(values, shape, dtype);
}
var rand = op({
rand_: rand_
});
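/**
 * Illustrative usage sketch (not part of the library source): `tf.rand` calls
 * the supplied generator once per output element; `Math.random` is assumed
 * here as the generator, and the dtype defaults to 'float32'.
 *
 * ```js
 * const t = tf.rand([2, 2], () => Math.random());
 * t.print(); // four uniform samples in [0, 1)
 * ```
 */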
var alea = createCommonjsModule(function (module) {
// A port of an algorithm by Johannes Baagøe <baagoe@baagoe.com>, 2010
// http://baagoe.com/en/RandomMusings/javascript/
// https://github.com/nquinlan/better-random-numbers-for-javascript-mirror
// Original work is under MIT license -
// Copyright (C) 2010 by Johannes Baagøe <baagoe@baagoe.org>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
(function (global, module, define) {
function Alea(seed) {
var me = this,
mash = Mash();
me.next = function () {
var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32
me.s0 = me.s1;
me.s1 = me.s2;
return me.s2 = t - (me.c = t | 0);
}; // Apply the seeding algorithm from Baagoe.
me.c = 1;
me.s0 = mash(' ');
me.s1 = mash(' ');
me.s2 = mash(' ');
me.s0 -= mash(seed);
if (me.s0 < 0) {
me.s0 += 1;
}
me.s1 -= mash(seed);
if (me.s1 < 0) {
me.s1 += 1;
}
me.s2 -= mash(seed);
if (me.s2 < 0) {
me.s2 += 1;
}
mash = null;
}
function copy(f, t) {
t.c = f.c;
t.s0 = f.s0;
t.s1 = f.s1;
t.s2 = f.s2;
return t;
}
function impl(seed, opts) {
var xg = new Alea(seed),
state = opts && opts.state,
prng = xg.next;
prng.int32 = function () {
return xg.next() * 0x100000000 | 0;
};
prng.double = function () {
return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53
};
prng.quick = prng;
if (state) {
if (typeof state == 'object') copy(state, xg);
prng.state = function () {
return copy(xg, {});
};
}
return prng;
}
function Mash() {
var n = 0xefc8249d;
var mash = function mash(data) {
data = data.toString();
for (var i = 0; i < data.length; i++) {
n += data.charCodeAt(i);
var h = 0.02519603282416938 * n;
n = h >>> 0;
h -= n;
h *= n;
n = h >>> 0;
h -= n;
n += h * 0x100000000; // 2^32
}
return (n >>> 0) * 2.3283064365386963e-10; // 2^-32
};
return mash;
}
if (module && module.exports) {
module.exports = impl;
} else if (define && define.amd) {
define(function () {
return impl;
});
} else {
this.alea = impl;
}
})(commonjsGlobal, 'object' == 'object' && module, // present in node.js
typeof undefined == 'function' && undefined // present with an AMD loader
);
});
var xor128 = createCommonjsModule(function (module) {
// A Javascript implementaion of the "xor128" prng algorithm by
// George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper
(function (global, module, define) {
function XorGen(seed) {
var me = this,
strseed = '';
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0; // Set up generator function.
me.next = function () {
var t = me.x ^ me.x << 11;
me.x = me.y;
me.y = me.z;
me.z = me.w;
return me.w ^= me.w >>> 19 ^ t ^ t >>> 8;
};
if (seed === (seed | 0)) {
// Integer seed.
me.x = seed;
} else {
// String seed.
strseed += seed;
} // Mix in string seed, then discard an initial batch of 64 values.
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
return t;
}
function impl(seed, opts) {
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function prng() {
return (xg.next() >>> 0) / 0x100000000;
};
prng.double = function () {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (typeof state == 'object') copy(state, xg);
prng.state = function () {
return copy(xg, {});
};
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else if (define && define.amd) {
define(function () {
return impl;
});
} else {
this.xor128 = impl;
}
})(commonjsGlobal, 'object' == 'object' && module, // present in node.js
typeof undefined == 'function' && undefined // present with an AMD loader
);
});
var xorwow = createCommonjsModule(function (module) {
// A Javascript implementaion of the "xorwow" prng algorithm by
// George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper
(function (global, module, define) {
function XorGen(seed) {
var me = this,
strseed = ''; // Set up generator function.
me.next = function () {
var t = me.x ^ me.x >>> 2;
me.x = me.y;
me.y = me.z;
me.z = me.w;
me.w = me.v;
return (me.d = me.d + 362437 | 0) + (me.v = me.v ^ me.v << 4 ^ (t ^ t << 1)) | 0;
};
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0;
me.v = 0;
if (seed === (seed | 0)) {
// Integer seed.
me.x = seed;
} else {
// String seed.
strseed += seed;
} // Mix in string seed, then discard an initial batch of 64 values.
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
if (k == strseed.length) {
me.d = me.x << 10 ^ me.x >>> 4;
}
me.next();
}
}
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
t.v = f.v;
t.d = f.d;
return t;
}
function impl(seed, opts) {
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function prng() {
return (xg.next() >>> 0) / 0x100000000;
};
prng.double = function () {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (typeof state == 'object') copy(state, xg);
prng.state = function () {
return copy(xg, {});
};
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else if (define && define.amd) {
define(function () {
return impl;
});
} else {
this.xorwow = impl;
}
})(commonjsGlobal, 'object' == 'object' && module, // present in node.js
typeof undefined == 'function' && undefined // present with an AMD loader
);
});
var xorshift7 = createCommonjsModule(function (module) {
// A Javascript implementaion of the "xorshift7" algorithm by
// François Panneton and Pierre L'ecuyer:
// "On the Xorgshift Random Number Generators"
// http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf
(function (global, module, define) {
function XorGen(seed) {
var me = this; // Set up generator function.
me.next = function () {
// Update xor generator.
var X = me.x,
i = me.i,
t,
v,
w;
t = X[i];
t ^= t >>> 7;
v = t ^ t << 24;
t = X[i + 1 & 7];
v ^= t ^ t >>> 10;
t = X[i + 3 & 7];
v ^= t ^ t >>> 3;
t = X[i + 4 & 7];
v ^= t ^ t << 7;
t = X[i + 7 & 7];
t = t ^ t << 13;
v ^= t ^ t << 9;
X[i] = v;
me.i = i + 1 & 7;
return v;
};
function init(me, seed) {
var j,
w,
X = [];
if (seed === (seed | 0)) {
// Seed state array using a 32-bit integer.
w = X[0] = seed;
} else {
// Seed state using a string.
seed = '' + seed;
for (j = 0; j < seed.length; ++j) {
X[j & 7] = X[j & 7] << 15 ^ seed.charCodeAt(j) + X[j + 1 & 7] << 13;
}
} // Enforce an array length of 8, not all zeroes.
while (X.length < 8) {
X.push(0);
}
for (j = 0; j < 8 && X[j] === 0; ++j) {
;
}
if (j == 8) w = X[7] = -1;else w = X[j];
me.x = X;
me.i = 0; // Discard an initial 256 values.
for (j = 256; j > 0; --j) {
me.next();
}
}
init(me, seed);
}
function copy(f, t) {
t.x = f.x.slice();
t.i = f.i;
return t;
}
function impl(seed, opts) {
if (seed == null) seed = +new Date();
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function prng() {
return (xg.next() >>> 0) / 0x100000000;
};
prng.double = function () {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (state.x) copy(state, xg);
prng.state = function () {
return copy(xg, {});
};
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else if (define && define.amd) {
define(function () {
return impl;
});
} else {
this.xorshift7 = impl;
}
})(commonjsGlobal, 'object' == 'object' && module, // present in node.js
typeof undefined == 'function' && undefined // present with an AMD loader
);
});
var xor4096 = createCommonjsModule(function (module) {
// A Javascript implementation of Richard Brent's Xorgens xor4096 algorithm.
//
// This fast non-cryptographic random number generator is designed for
// use in Monte-Carlo algorithms. It combines a long-period xorshift
// generator with a Weyl generator, and it passes all common batteries
// of statistical tests for randomness while consuming only a few nanoseconds
// for each prng generated. For background on the generator, see Brent's
// paper: "Some long-period random number generators using shifts and xors."
// http://arxiv.org/pdf/1004.3115v1.pdf
//
// Usage:
//
// var xor4096 = require('xor4096');
// random = xor4096(1); // Seed with int32 or string.
// assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits.
// assert.equal(random.int32(), 1806534897); // signed int32, 32 bits.
//
// For nonzero numeric keys, this implementation provides a sequence
// identical to that by Brent's xorgens 3 implementation in C. This
// implementation also provides for initializing the generator with
// string seeds, or for saving and restoring the state of the generator.
//
// On Chrome, this prng benchmarks about 2.1 times slower than
// Javascript's built-in Math.random().
(function (global, module, define) {
function XorGen(seed) {
var me = this; // Set up generator function.
me.next = function () {
var w = me.w,
X = me.X,
i = me.i,
t,
v; // Update Weyl generator.
me.w = w = w + 0x61c88647 | 0; // Update xor generator.
v = X[i + 34 & 127];
t = X[i = i + 1 & 127];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12; // Update Xor generator array state.
v = X[i] = v ^ t;
me.i = i; // Result is the combination.
return v + (w ^ w >>> 16) | 0;
};
function init(me, seed) {
var t,
v,
i,
j,
w,
X = [],
limit = 128;
if (seed === (seed | 0)) {
// Numeric seeds initialize v, which is used to generate X.
v = seed;
seed = null;
} else {
// String seeds are mixed into v and X one character at a time.
seed = seed + '\0';
v = 0;
limit = Math.max(limit, seed.length);
} // Initialize circular array and weyl value.
for (i = 0, j = -32; j < limit; ++j) {
// Put the unicode characters into the array, and shuffle them.
if (seed) v ^= seed.charCodeAt((j + 32) % seed.length); // After 32 shuffles, take v as the starting w value.
if (j === 0) w = v;
v ^= v << 10;
v ^= v >>> 15;
v ^= v << 4;
v ^= v >>> 13;
if (j >= 0) {
w = w + 0x61c88647 | 0; // Weyl.
t = X[j & 127] ^= v + w; // Combine xor and weyl to init array.
i = 0 == t ? i + 1 : 0; // Count zeroes.
}
} // We have detected all zeroes; make the key nonzero.
if (i >= 128) {
X[(seed && seed.length || 0) & 127] = -1;
} // Run the generator 512 times to further mix the state before using it.
// Factoring this as a function slows the main generator, so it is just
// unrolled here. The weyl generator is not advanced while warming up.
i = 127;
for (j = 4 * 128; j > 0; --j) {
v = X[i + 34 & 127];
t = X[i = i + 1 & 127];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12;
X[i] = v ^ t;
} // Storing state as object members is faster than using closure variables.
me.w = w;
me.X = X;
me.i = i;
}
init(me, seed);
}
function copy(f, t) {
t.i = f.i;
t.w = f.w;
t.X = f.X.slice();
return t;
}
;
function impl(seed, opts) {
if (seed == null) seed = +new Date();
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function prng() {
return (xg.next() >>> 0) / 0x100000000;
};
prng.double = function () {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (state.X) copy(state, xg);
prng.state = function () {
return copy(xg, {});
};
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else if (define && define.amd) {
define(function () {
return impl;
});
} else {
this.xor4096 = impl;
}
})(commonjsGlobal, // window object or global
'object' == 'object' && module, // present in node.js
typeof undefined == 'function' && undefined // present with an AMD loader
);
});
var tychei = createCommonjsModule(function (module) {
// A Javascript implementaion of the "Tyche-i" prng algorithm by
// Samuel Neves and Filipe Araujo.
// See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
(function (global, module, define) {
function XorGen(seed) {
var me = this,
strseed = ''; // Set up generator function.
me.next = function () {
var b = me.b,
c = me.c,
d = me.d,
a = me.a;
b = b << 25 ^ b >>> 7 ^ c;
c = c - d | 0;
d = d << 24 ^ d >>> 8 ^ a;
a = a - b | 0;
me.b = b = b << 20 ^ b >>> 12 ^ c;
me.c = c = c - d | 0;
me.d = d << 16 ^ c >>> 16 ^ a;
return me.a = a - b | 0;
};
/* The following is non-inverted tyche, which has better internal
* bit diffusion, but which is about 25% slower than tyche-i in JS.
me.next = function() {
var a = me.a, b = me.b, c = me.c, d = me.d;
a = (me.a + me.b | 0) >>> 0;
d = me.d ^ a; d = d << 16 ^ d >>> 16;
c = me.c + d | 0;
b = me.b ^ c; b = b << 12 ^ d >>> 20;
me.a = a = a + b | 0;
d = d ^ a; me.d = d = d << 8 ^ d >>> 24;
me.c = c = c + d | 0;
b = b ^ c;
return me.b = (b << 7 ^ b >>> 25);
}
*/
me.a = 0;
me.b = 0;
me.c = 2654435769 | 0;
me.d = 1367130551;
if (seed === Math.floor(seed)) {
// Integer seed.
me.a = seed / 0x100000000 | 0;
me.b = seed | 0;
} else {
// String seed.
strseed += seed;
} // Mix in string seed, then discard an initial batch of 64 values.
for (var k = 0; k < strseed.length + 20; k++) {
me.b ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
function copy(f, t) {
t.a = f.a;
t.b = f.b;
t.c = f.c;
t.d = f.d;
return t;
}
;
function impl(seed, opts) {
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function prng() {
return (xg.next() >>> 0) / 0x100000000;
};
prng.double = function () {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (typeof state == 'object') copy(state, xg);
prng.state = function () {
return copy(xg, {});
};
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else if (define && define.amd) {
define(function () {
return impl;
});
} else {
this.tychei = impl;
}
})(commonjsGlobal, 'object' == 'object' && module, // present in node.js
typeof undefined == 'function' && undefined // present with an AMD loader
);
});
var seedrandom = createCommonjsModule(function (module) {
/*
Copyright 2014 David Bau.
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
(function (pool, math) {
//
// The following constants are related to IEEE 754 limits.
//
var global = this,
width = 256,
// each RC4 output is 0 <= x < 256
chunks = 6,
// at least six RC4 outputs for each double
digits = 52,
// there are 52 significant digits in a double
rngname = 'random',
// rngname: name for Math.random and Math.seedrandom
startdenom = math.pow(width, chunks),
significance = math.pow(2, digits),
overflow = significance * 2,
mask = width - 1,
nodecrypto; // node.js crypto module, initialized at the bottom.
//
// seedrandom()
// This is the seedrandom function described above.
//
function seedrandom(seed, options, callback) {
var key = [];
options = options == true ? {
entropy: true
} : options || {}; // Flatten the seed string or build one from local entropy if needed.
var shortseed = mixkey(flatten(options.entropy ? [seed, tostring(pool)] : seed == null ? autoseed() : seed, 3), key); // Use the seed to initialize an ARC4 generator.
var arc4 = new ARC4(key); // This function returns a random double in [0, 1) that contains
// randomness in every bit of the mantissa of the IEEE 754 value.
var prng = function prng() {
var n = arc4.g(chunks),
// Start with a numerator n < 2 ^ 48
d = startdenom,
// and denominator d = 2 ^ 48.
x = 0; // and no 'extra last byte'.
while (n < significance) {
// Fill up all significant digits by
n = (n + x) * width; // shifting numerator and
d *= width; // denominator and generating a
x = arc4.g(1); // new least-significant-byte.
}
while (n >= overflow) {
// To avoid rounding up, before adding
n /= 2; // last byte, shift everything
d /= 2; // right using integer math until
x >>>= 1; // we have exactly the desired bits.
}
return (n + x) / d; // Form the number within [0, 1).
};
prng.int32 = function () {
return arc4.g(4) | 0;
};
prng.quick = function () {
return arc4.g(4) / 0x100000000;
};
prng.double = prng; // Mix the randomness into accumulated entropy.
mixkey(tostring(arc4.S), pool); // Calling convention: what to return as a function of prng, seed, is_math.
return (options.pass || callback || function (prng, seed, is_math_call, state) {
if (state) {
// Load the arc4 state from the given state if it has an S array.
if (state.S) {
copy(state, arc4);
} // Only provide the .state method if requested via options.state.
prng.state = function () {
return copy(arc4, {});
};
} // If called as a method of Math (Math.seedrandom()), mutate
// Math.random because that is how seedrandom.js has worked since v1.0.
if (is_math_call) {
math[rngname] = prng;
return seed;
} // Otherwise, it is a newer calling convention, so return the
// prng directly.
else return prng;
})(prng, shortseed, 'global' in options ? options.global : this == math, options.state);
}
math['seed' + rngname] = seedrandom; //
// ARC4
//
// An ARC4 implementation. The constructor takes a key in the form of
// an array of at most (width) integers that should be 0 <= x < (width).
//
// The g(count) method returns a pseudorandom integer that concatenates
// the next (count) outputs from ARC4. Its return value is a number x
// that is in the range 0 <= x < (width ^ count).
//
function ARC4(key) {
var t,
keylen = key.length,
me = this,
i = 0,
j = me.i = me.j = 0,
s = me.S = []; // The empty key [] is treated as [0].
if (!keylen) {
key = [keylen++];
} // Set up S using the standard key scheduling algorithm.
while (i < width) {
s[i] = i++;
}
for (i = 0; i < width; i++) {
s[i] = s[j = mask & j + key[i % keylen] + (t = s[i])];
s[j] = t;
} // The "g" method returns the next (count) outputs as one number.
(me.g = function (count) {
// Using instance members instead of closure state nearly doubles speed.
var t,
r = 0,
i = me.i,
j = me.j,
s = me.S;
while (count--) {
t = s[i = mask & i + 1];
r = r * width + s[mask & (s[i] = s[j = mask & j + t]) + (s[j] = t)];
}
me.i = i;
me.j = j;
return r; // For robust unpredictability, the function call below automatically
// discards an initial batch of values. This is called RC4-drop[256].
// See http://google.com/search?q=rsa+fluhrer+response&btnI
})(width);
} //
// copy()
// Copies internal state of ARC4 to or from a plain object.
//
function copy(f, t) {
t.i = f.i;
t.j = f.j;
t.S = f.S.slice();
return t;
}
; //
// flatten()
// Converts an object tree to nested arrays of strings.
//
function flatten(obj, depth) {
var result = [],
typ = typeof obj,
prop;
if (depth && typ == 'object') {
for (prop in obj) {
try {
result.push(flatten(obj[prop], depth - 1));
} catch (e) {}
}
}
return result.length ? result : typ == 'string' ? obj : obj + '\0';
} //
// mixkey()
// Mixes a string seed into a key that is an array of integers, and
// returns a shortened string seed that is equivalent to the result key.
//
function mixkey(seed, key) {
var stringseed = seed + '',
smear,
j = 0;
while (j < stringseed.length) {
key[mask & j] = mask & (smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++);
}
return tostring(key);
} //
// autoseed()
// Returns an object for autoseeding, using window.crypto and Node crypto
// module if available.
//
function autoseed() {
try {
var out;
if (nodecrypto && (out = nodecrypto.randomBytes)) {
// The use of 'out' to remember randomBytes makes tight minified code.
out = out(width);
} else {
out = new Uint8Array(width);
(global.crypto || global.msCrypto).getRandomValues(out);
}
return tostring(out);
} catch (e) {
var browser = global.navigator,
plugins = browser && browser.plugins;
return [+new Date(), global, plugins, global.screen, tostring(pool)];
}
} //
// tostring()
// Converts an array of charcodes to a string
//
function tostring(a) {
return String.fromCharCode.apply(0, a);
} //
// When seedrandom.js is loaded, we immediately mix a few bits
// from the built-in RNG into the entropy pool. Because we do
// not want to interfere with deterministic PRNG state later,
// seedrandom will not call math.random on its own again after
// initialization.
//
mixkey(math.random(), pool); //
// Nodejs and AMD support: export the implementation as a module using
// either convention.
//
if ('object' == 'object' && module.exports) {
module.exports = seedrandom; // When in node.js, try using crypto package for autoseeding.
try {
nodecrypto = require('crypto');
} catch (ex) {}
} else if (typeof undefined == 'function' && undefined.amd) {
undefined(function () {
return seedrandom;
});
} // End anonymous scope, and pass initial values.
})([], // pool: entropy pool starts empty
Math // math: package containing random, pow, and seedrandom
);
});
//
// Usage:
//
// var seedrandom = require('seedrandom');
// var random = seedrandom(1); // or any seed.
// var x = random(); // 0 <= x < 1. Every bit is random.
// var x = random.quick(); // 0 <= x < 1. 32 bits of randomness.
// alea, a 53-bit multiply-with-carry generator by Johannes Baagøe.
// Period: ~2^116
// Reported to pass all BigCrush tests.
// xor128, a pure xor-shift generator by George Marsaglia.
// Period: 2^128-1.
// Reported to fail: MatrixRank and LinearComp.
// xorwow, George Marsaglia's 160-bit xor-shift combined plus weyl.
// Period: 2^192-2^32
// Reported to fail: CollisionOver, SimpPoker, and LinearComp.
// xorshift7, by François Panneton and Pierre L'ecuyer, takes
// a different approach: it adds robustness by allowing more shifts
// than Marsaglia's original three. It is a 7-shift generator
// with 256 bits, that passes BigCrush with no systematic failures.
// Period 2^256-1.
// No systematic BigCrush failures reported.
// xor4096, by Richard Brent, is a 4096-bit xor-shift with a
// very long period that also adds a Weyl generator. It also passes
// BigCrush with no systematic failures. Its long period may
// be useful if you have many generators and need to avoid
// collisions.
// Period: 2^4128-2^32.
// No systematic BigCrush failures reported.
// Tyche-i, by Samuel Neves and Filipe Araujo, is a bit-shifting random
// number generator derived from ChaCha, a modern stream cipher.
// https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf
// Period: ~2^127
// No systematic BigCrush failures reported.
// The original ARC4-based prng included in this library.
// Period: ~2^1600
seedrandom.alea = alea;
seedrandom.xor128 = xor128;
seedrandom.xorwow = xorwow;
seedrandom.xorshift7 = xorshift7;
seedrandom.xor4096 = xor4096;
seedrandom.tychei = tychei;
var seedrandom$1 = seedrandom;
var seedrandom_1 = seedrandom$1.alea;
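/**
 * Illustrative sketch (not a public API): the samplers defined below seed the
 * bundled alea generator (aliased here as `seedrandom_1`) with a string and
 * draw floats in [0, 1); the same seed reproduces the same sequence.
 *
 * ```js
 * const rng = seedrandom_1('42'); // internal alias, shown for illustration
 * console.log(rng(), rng());      // deterministic given the seed
 * ```
 */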
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MPRandGauss = /*#__PURE__*/function () {
function MPRandGauss(mean, stdDeviation, dtype, truncated, seed) {
this.mean = mean;
this.stdDev = stdDeviation;
this.dtype = dtype;
this.nextVal = NaN;
this.truncated = truncated;
if (this.truncated) {
this.upper = this.mean + this.stdDev * 2;
this.lower = this.mean - this.stdDev * 2;
}
var seedValue = seed ? seed : Math.random();
this.random = seedrandom_1(seedValue.toString());
}
/** Returns next sample from a Gaussian distribution. */
var _proto = MPRandGauss.prototype;
_proto.nextValue = function nextValue() {
if (!isNaN(this.nextVal)) {
var value = this.nextVal;
this.nextVal = NaN;
return value;
}
var resultX, resultY;
var isValid = false;
while (!isValid) {
var v1 = void 0,
v2 = void 0,
s = void 0;
do {
v1 = 2 * this.random() - 1;
v2 = 2 * this.random() - 1;
s = v1 * v1 + v2 * v2;
} while (s >= 1 || s === 0);
var mul = Math.sqrt(-2.0 * Math.log(s) / s);
resultX = this.mean + this.stdDev * v1 * mul;
resultY = this.mean + this.stdDev * v2 * mul;
if (!this.truncated || this.isValidTruncated(resultX)) {
isValid = true;
}
}
if (!this.truncated || this.isValidTruncated(resultY)) {
this.nextVal = this.convertValue(resultY);
}
return this.convertValue(resultX);
}
/** Handles proper rounding for non-floating-point numbers. */
;
_proto.convertValue = function convertValue(value) {
if (this.dtype == null || this.dtype === 'float32') {
return value;
}
return Math.round(value);
}
/** Returns true if less than 2 standard deviations from the mean. */
;
_proto.isValidTruncated = function isValidTruncated(value) {
return value <= this.upper && value >= this.lower;
};
return MPRandGauss;
}(); // Marsaglia, George, and Wai Wan Tsang. 2000. "A Simple Method for Generating
// Gamma Variables."
var RandGamma = /*#__PURE__*/function () {
function RandGamma(alpha, beta, dtype, seed) {
this.alpha = alpha;
this.beta = 1 / beta; // convert rate to scale parameter
this.dtype = dtype;
var seedValue = seed ? seed : Math.random();
this.randu = seedrandom_1(seedValue.toString());
this.randn = new MPRandGauss(0, 1, dtype, false, this.randu());
if (alpha < 1) {
this.d = alpha + 2 / 3;
} else {
this.d = alpha - 1 / 3;
}
this.c = 1 / Math.sqrt(9 * this.d);
}
/** Returns next sample from a gamma distribution. */
var _proto2 = RandGamma.prototype;
_proto2.nextValue = function nextValue() {
var x2, v0, v1, x, u, v;
while (true) {
do {
x = this.randn.nextValue();
v = 1 + this.c * x;
} while (v <= 0);
v *= v * v;
x2 = x * x;
v0 = 1 - 0.331 * x2 * x2;
v1 = 0.5 * x2 + this.d * (1 - v + Math.log(v));
u = this.randu();
if (u < v0 || Math.log(u) < v1) {
break;
}
}
v = 1 / this.beta * this.d * v;
if (this.alpha < 1) {
v *= Math.pow(this.randu(), 1 / this.alpha);
}
return this.convertValue(v);
}
/** Handles proper rounding for non-floating-point numbers. */
;
_proto2.convertValue = function convertValue(value) {
if (this.dtype === 'float32') {
return value;
}
return Math.round(value);
};
return RandGamma;
}();
var UniformRandom = /*#__PURE__*/function () {
function UniformRandom(min, max, dtype, seed) {
var _this = this;
if (min === void 0) {
min = 0;
}
if (max === void 0) {
max = 1;
}
/** Handles proper rounding for non floating point numbers. */
this.canReturnFloat = function () {
return _this.dtype == null || _this.dtype === 'float32';
};
this.min = min;
this.range = max - min;
this.dtype = dtype;
if (seed == null) {
seed = Math.random();
}
if (typeof seed === 'number') {
seed = seed.toString();
}
if (!this.canReturnFloat() && this.range <= 1) {
throw new Error("The difference between " + min + " - " + max + " <= 1 and dtype is not float");
}
this.random = seedrandom_1(seed);
}
var _proto3 = UniformRandom.prototype;
_proto3.convertValue = function convertValue(value) {
if (this.canReturnFloat()) {
return value;
}
return Math.round(value);
};
_proto3.nextValue = function nextValue() {
return this.convertValue(this.min + this.range * this.random());
};
return UniformRandom;
}();
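/**
 * Illustrative sketch of the internal samplers above (constructor arguments
 * assumed for demonstration): a standard normal, a gamma(alpha=2, beta=1), and
 * a uniform [0, 1) generator, each seeded for reproducibility.
 *
 * ```js
 * const gauss = new MPRandGauss(0, 1, 'float32', false, 42);
 * const gamma = new RandGamma(2, 1, 'float32', 42);
 * const uniform = new UniformRandom(0, 1, 'float32', 42);
 * console.log(gauss.nextValue(), gamma.nextValue(), uniform.nextValue());
 * ```
 */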
function jarqueBeraNormalityTest(values) {
// https://en.wikipedia.org/wiki/Jarque%E2%80%93Bera_test
var n = values.length;
var s = skewness(values);
var k = kurtosis(values);
var jb = n / 6 * (Math.pow(s, 2) + 0.25 * Math.pow(k - 3, 2)); // JB test requires 2 degrees of freedom from Chi-Square @ 0.95:
// http://www.itl.nist.gov/div898/handbook/eda/section3/eda3674.htm
var CHI_SQUARE_2DEG = 5.991;
if (jb > CHI_SQUARE_2DEG) {
throw new Error("Invalid p-value for JB: " + jb);
}
}
function expectArrayInMeanStdRange(actual, expectedMean, expectedStdDev, epsilon) {
if (epsilon == null) {
epsilon = testEpsilon();
}
var actualMean = mean$1(actual);
expectNumbersClose(actualMean, expectedMean, epsilon);
expectNumbersClose(standardDeviation(actual, actualMean), expectedStdDev, epsilon);
}
function mean$1(values) {
var sum = 0;
for (var i = 0; i < values.length; i++) {
sum += values[i];
}
return sum / values.length;
}
function standardDeviation(values, mean) {
var squareDiffSum = 0;
for (var i = 0; i < values.length; i++) {
var diff = values[i] - mean;
squareDiffSum += diff * diff;
}
return Math.sqrt(squareDiffSum / values.length);
}
function kurtosis(values) {
// https://en.wikipedia.org/wiki/Kurtosis
var valuesMean = mean$1(values);
var n = values.length;
var sum2 = 0;
var sum4 = 0;
for (var i = 0; i < n; i++) {
var v = values[i] - valuesMean;
sum2 += Math.pow(v, 2);
sum4 += Math.pow(v, 4);
}
return 1 / n * sum4 / Math.pow(1 / n * sum2, 2);
}
function skewness(values) {
// https://en.wikipedia.org/wiki/Skewness
var valuesMean = mean$1(values);
var n = values.length;
var sum2 = 0;
var sum3 = 0;
for (var i = 0; i < n; i++) {
var v = values[i] - valuesMean;
sum2 += Math.pow(v, 2);
sum3 += Math.pow(v, 3);
}
return 1 / n * sum3 / Math.pow(1 / (n - 1) * sum2, 3 / 2);
}
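/**
 * Illustrative sketch (internal test helpers, sample values assumed): computing
 * the sample moments used by the Jarque-Bera normality check above.
 *
 * ```js
 * const values = [0.9, 1.1, 1.0, 1.2, 0.8, 1.05];
 * const m = mean$1(values);
 * console.log(m, standardDeviation(values, m), skewness(values), kurtosis(values));
 * ```
 */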
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with values sampled from a gamma distribution.
*
* ```js
* tf.randomGamma([2, 2], 1).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param alpha The shape parameter of the gamma distribution.
* @param beta The inverse scale parameter of the gamma distribution. Defaults
* to 1.
* @param dtype The data type of the output. Defaults to float32.
* @param seed The seed for the random number generator.
*
* @doc {heading: 'Tensors', subheading: 'Random'}
*/
function randomGamma_(shape, alpha, beta, dtype, seed) {
if (beta === void 0) {
beta = 1;
}
if (dtype === void 0) {
dtype = 'float32';
}
if (beta == null) {
beta = 1;
}
if (dtype == null) {
dtype = 'float32';
}
if (dtype !== 'float32' && dtype !== 'int32') {
throw new Error("Unsupported data type " + dtype);
}
var rgamma = new RandGamma(alpha, beta, dtype, seed);
var res = buffer(shape, dtype);
for (var i = 0; i < res.values.length; i++) {
res.values[i] = rgamma.nextValue();
}
return res.toTensor();
}
var randomGamma = op({
randomGamma_: randomGamma_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with values sampled from a normal distribution.
*
* ```js
* tf.randomNormal([2, 2]).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param mean The mean of the normal distribution.
* @param stdDev The standard deviation of the normal distribution.
* @param dtype The data type of the output.
* @param seed The seed for the random number generator.
*
* @doc {heading: 'Tensors', subheading: 'Random'}
*/
function randomNormal_(shape, mean, stdDev, dtype, seed) {
if (mean === void 0) {
mean = 0;
}
if (stdDev === void 0) {
stdDev = 1;
}
if (dtype != null && dtype === 'bool') {
throw new Error("Unsupported data type " + dtype);
}
var randGauss = new MPRandGauss(mean, stdDev, dtype, false
/* truncated */
, seed);
var res = buffer(shape, dtype);
for (var i = 0; i < res.values.length; i++) {
res.values[i] = randGauss.nextValue();
}
return res.toTensor();
}
var randomNormal = op({
randomNormal_: randomNormal_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with values sampled from a uniform distribution.
*
* The generated values follow a uniform distribution in the range [minval,
* maxval). The lower bound minval is included in the range, while the upper
* bound maxval is excluded.
*
* ```js
* tf.randomUniform([2, 2]).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param minval The lower bound on the range of random values to generate.
* Defaults to 0.
* @param maxval The upper bound on the range of random values to generate.
* Defaults to 1.
* @param dtype The data type of the output tensor. Defaults to 'float32'.
*
* @doc {heading: 'Tensors', subheading: 'Random'}
*/
function randomUniform_(shape, minval, maxval, dtype, seed) {
if (minval === void 0) {
minval = 0;
}
if (maxval === void 0) {
maxval = 1;
}
if (dtype === void 0) {
dtype = 'float32';
}
var res = buffer(shape, dtype);
var random = new UniformRandom(minval, maxval, null, seed);
for (var i = 0; i < res.values.length; i++) {
res.values[i] = random.nextValue();
}
return res.toTensor();
}
var randomUniform = op({
randomUniform_: randomUniform_
});
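/**
 * Additional illustrative example (argument values assumed): integer samples
 * drawn from [0, 10) with a fixed seed, using the optional `dtype` and `seed`
 * arguments accepted by the function above.
 *
 * ```js
 * tf.randomUniform([2, 2], 0, 10, 'int32', 42).print();
 * ```
 */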
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a new `tf.Tensor1D` filled with the numbers in the range provided.
*
 * The tensor is a half-open interval, meaning it includes start but
 * excludes stop. Decrementing ranges and negative step values are also
 * supported.
 *
* ```js
* tf.range(0, 9, 2).print();
* ```
*
* @param start An integer start value
* @param stop An integer stop value
* @param step An integer increment (will default to 1 or -1)
* @param dtype The data type of the output tensor. Defaults to 'float32'.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function range(start, stop, step, dtype) {
if (step === void 0) {
step = 1;
}
if (dtype === void 0) {
dtype = 'float32';
}
if (step === 0) {
throw new Error('Cannot have a step of zero');
}
var attrs = {
start: start,
stop: stop,
step: step,
dtype: dtype
};
return ENGINE.runKernel(Range, {}
/* inputs */
, attrs);
}
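/**
 * Additional illustrative example: a decrementing range using a negative step,
 * as described in the doc above.
 *
 * ```js
 * tf.range(9, 0, -2).print(); // [9, 7, 5, 3, 1]
 * ```
 */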
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns the real part of a complex (or real) tensor.
*
* Given a tensor input, this operation returns a tensor of type float that is
* the real part of each element in input considered as a complex number.
*
* If the input is real, it simply makes a clone.
*
* ```js
* const x = tf.complex([-2.25, 3.25], [4.75, 5.75]);
* tf.real(x).print();
* ```
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function real_(input) {
var $input = convertToTensor(input, 'input', 'real');
var inputs = {
input: $input
};
return ENGINE.runKernel(Real, inputs);
}
var real = op({
real_: real_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes reciprocal of x element-wise: `1 / x`
*
* ```js
* const x = tf.tensor1d([0, 1, 2]);
*
* x.reciprocal().print(); // or tf.reciprocal(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function reciprocal_(x) {
var $x = convertToTensor(x, 'x', 'reciprocal');
var inputs = {
x: $x
};
return ENGINE.runKernel(Reciprocal, inputs);
}
var reciprocal = op({
reciprocal_: reciprocal_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes rectified linear element-wise: `max(x, 0)`.
*
* ```js
* const x = tf.tensor1d([-1, 2, -3, 4]);
*
* x.relu().print(); // or tf.relu(x)
* ```
* @param x The input tensor. If the dtype is `bool`, the output dtype will be
 * `int32`.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function relu_(x) {
var $x = convertToTensor(x, 'x', 'relu');
var inputs = {
x: $x
};
return ENGINE.runKernel(Relu, inputs);
}
var relu = op({
relu_: relu_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`.
*
* ```js
* const x = tf.tensor1d([-1, 2, -3, 8]);
*
* x.relu6().print(); // or tf.relu6(x)
* ```
* @param x The input tensor. If the dtype is `bool`, the output dtype will be
 * `int32`.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function relu6_(x) {
var $x = convertToTensor(x, 'x', 'relu6');
var inputs = {
x: $x
};
return ENGINE.runKernel(Relu6, inputs);
}
var relu6 = op({
relu6_: relu6_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Reverses a `tf.Tensor` along a specified axis.
*
* Also available are stricter rank-specific methods that assert that `x` is
* of the given rank:
* - `tf.reverse1d`
* - `tf.reverse2d`
* - `tf.reverse3d`
* - `tf.reverse4d`
*
 * Except `tf.reverse1d` (which does not have an axis param), all methods have
 * the same signature as this method.
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
*
* x.reverse().print();
* ```
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* const axis = 1;
* x.reverse(axis).print();
* ```
* @param x The input tensor to be reversed.
* @param axis The set of dimensions to reverse. Must be in the
* range [-rank(x), rank(x)). Defaults to all axes.
*
* @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
*/
function reverse_(x, axis) {
var $x = convertToTensor(x, 'x', 'reverse');
var inputs = {
x: $x
};
var attrs = {
dims: axis
};
return ENGINE.runKernel(Reverse, inputs, attrs);
}
var reverse = op({
reverse_: reverse_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Reverses a `tf.Tensor1D`.
*
* @param x The input tensor.
*/
function reverse1d_(x) {
var $x = convertToTensor(x, 'x', 'reverse');
assert($x.rank === 1, function () {
return "Error in reverse1D: x must be rank 1 but got rank " + $x.rank + ".";
});
return reverse($x, 0);
}
var reverse1d = op({
reverse1d_: reverse1d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Reverses a `tf.Tensor2D` along a specified axis.
*
* @param x The input tensor.
* @param axis The set of dimensions to reverse. Must be in the
* range [-rank(x), rank(x)). Defaults to all axes.
*/
function reverse2d_(x, axis) {
var $x = convertToTensor(x, 'x', 'reverse');
assert($x.rank === 2, function () {
return "Error in reverse2D: x must be rank 2 but got rank " + $x.rank + ".";
});
return reverse($x, axis);
}
var reverse2d = op({
reverse2d_: reverse2d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Reverses a `tf.Tensor3D` along a specified axis.
*
* @param x The input tensor.
* @param axis The set of dimensions to reverse. Must be in the
* range [-rank(x), rank(x)). Defaults to all axes.
*/
function reverse3d_(x, axis) {
var $x = convertToTensor(x, 'x', 'reverse');
assert($x.rank === 3, function () {
return "Error in reverse3D: x must be rank 3 but got rank " + $x.rank + ".";
});
return reverse($x, axis);
}
var reverse3d = op({
reverse3d_: reverse3d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Reverses a `tf.Tensor4D` along a specified axis.
*
* @param x The input tensor.
* @param axis The set of dimensions to reverse. Must be in the
* range [-rank(x), rank(x)). Defaults to all axes.
*/
function reverse4d_(x, axis) {
var $x = convertToTensor(x, 'x', 'reverse');
assert($x.rank === 4, function () {
return "Error in reverse4D: x must be rank 4 but got rank " + $x.rank + ".";
});
return reverse($x, axis);
}
var reverse4d = op({
reverse4d_: reverse4d_
});
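/**
 * Illustrative sketch: each rank-specific variant above only asserts the rank
 * and then delegates to `tf.reverse`.
 *
 * ```js
 * const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
 * tf.reverse3d(x, [0]).print(); // equivalent to tf.reverse(x, [0])
 * ```
 */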
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes round of input `tf.Tensor` element-wise: `round(x)`.
* It implements banker's rounding.
*
* ```js
* const x = tf.tensor1d([.6, 1.1, -3.3]);
*
* x.round().print(); // or tf.round(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function round_(x) {
var $x = convertToTensor(x, 'x', 'round');
var inputs = {
x: $x
};
return ENGINE.runKernel(Round, inputs);
}
var round$1 = op({
round_: round_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes reciprocal of square root of the input `tf.Tensor` element-wise:
* `y = 1 / sqrt(x)`
*
* ```js
* const x = tf.tensor1d([1, 2, 4, -1]);
*
* x.rsqrt().print(); // or tf.rsqrt(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function rsqrt_(x) {
var $x = convertToTensor(x, 'x', 'rsqrt');
var inputs = {
x: $x
};
return ENGINE.runKernel(Rsqrt, inputs);
}
var rsqrt = op({
rsqrt_: rsqrt_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype.
*
* The same functionality can be achieved with `tf.tensor`, but in general
* we recommend using `tf.scalar` as it makes the code more readable.
*
* ```js
* tf.scalar(3.14).print();
* ```
*
* @param value The value of the scalar.
* @param dtype The data type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function scalar(value, dtype) {
if ((isTypedArray$1(value) && dtype !== 'string' || Array.isArray(value)) && dtype !== 'complex64') {
throw new Error('Error creating a new Scalar: value must be a primitive ' + '(number|boolean|string)');
}
if (dtype === 'string' && isTypedArray$1(value) && !(value instanceof Uint8Array)) {
throw new Error('When making a scalar from encoded string, ' + 'the value must be `Uint8Array`.');
}
var shape = [];
var inferredShape = [];
return makeTensor(value, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes scaled exponential linear element-wise.
*
* `x < 0 ? scale * alpha * (exp(x) - 1) : x`
*
* ```js
* const x = tf.tensor1d([-1, 2, -3, 4]);
*
* x.selu().print(); // or tf.selu(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function selu_(x) {
var $x = convertToTensor(x, 'x', 'selu');
var inputs = {
x: $x
};
return ENGINE.runKernel(Selu, inputs);
}
var selu = op({
selu_: selu_
});
/**
* 2-D convolution with separable filters.
*
* Performs a depthwise convolution that acts separately on channels followed
* by a pointwise convolution that mixes channels. Note that this is
* separability between dimensions [1, 2] and 3, not spatial separability
* between dimensions 1 and 2.
*
* See
* [https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d](
* https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d)
* for more details.
*
* @param x The input tensor, of rank 4 or rank 3, of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
* assumed.
* @param depthwiseFilter The depthwise filter tensor, rank 4, of shape
* `[filterHeight, filterWidth, inChannels, channelMultiplier]`. This is
* the filter used in the first step.
* @param pointwiseFilter The pointwise filter tensor, rank 4, of shape
* `[1, 1, inChannels * channelMultiplier, outChannels]`. This is
* the filter used in the second step.
* @param strides The strides of the convolution: `[strideHeight,
* strideWidth]`. If strides is a single number, then `strideHeight ==
* strideWidth`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
* in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
* number, then `dilationHeight == dilationWidth`. If it is greater than
* 1, then all values of `strides` must be 1.
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
* "NHWC". Specify the data format of the input and output data. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels]. Only "NHWC" is currently supported.
*
* @doc {heading: 'Operations', subheading: 'Convolution'}
*/
function separableConv2d_(x, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) {
if (dilation === void 0) {
dilation = [1, 1];
}
if (dataFormat === void 0) {
dataFormat = 'NHWC';
}
var $x = convertToTensor(x, 'x', 'separableConv2d');
var $depthwiseFilter = convertToTensor(depthwiseFilter, 'depthwiseFilter', 'separableConv2d');
var $pointwiseFilter = convertToTensor(pointwiseFilter, 'pointwiseFilter', 'separableConv2d');
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
if (dataFormat === 'NCHW') {
throw new Error('separableConv2d currently does not support dataFormat NCHW; only ' + 'NHWC is supported');
}
assert(x4D.rank === 4, function () {
return "Error in separableConv2d: input must be rank 4, but got " + ("rank " + x4D.rank + ".");
});
assert($depthwiseFilter.rank === 4, function () {
return "Error in separableConv2d: depthwise filter must be rank 4, but " + ("got rank " + $depthwiseFilter.rank + ".");
});
assert($pointwiseFilter.rank === 4, function () {
return "Error in separableConv2d: pointwise filter must be rank 4, but " + ("got rank " + $depthwiseFilter.rank + ".");
});
assert($pointwiseFilter.shape[0] === 1, function () {
return "Error in separableConv2d: the first dimension of pointwise filter " + (" must be 1, but got " + $pointwiseFilter.shape[0] + ".");
});
assert($pointwiseFilter.shape[1] === 1, function () {
return "Error in separableConv2d: the second dimension of pointwise " + ("filter must be 1, but got " + $pointwiseFilter.shape[1] + ".");
});
var inChannels = $depthwiseFilter.shape[2];
var channelMultiplier = $depthwiseFilter.shape[3];
assert($pointwiseFilter.shape[2] === inChannels * channelMultiplier, function () {
return "Error in separableConv2d: the third dimension of pointwise filter " + ("must be " + inChannels * channelMultiplier + ", ") + ("but got " + $pointwiseFilter.shape[2] + ".");
});
var depthwise = depthwiseConv2d(x4D, $depthwiseFilter, strides, pad, dataFormat, dilation);
var pointwiseStride = 1;
var res = conv2d(depthwise, $pointwiseFilter, pointwiseStride, 'valid', dataFormat);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var separableConv2d = op({
separableConv2d_: separableConv2d_
});
/**
* Computes the difference between two lists of numbers.
*
* Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out`
* that represents all values that are in `x` but not in `y`. The returned
* Tensor `out` is sorted in the same order that the numbers appear in `x`
* (duplicates are preserved). This operation also returns a Tensor indices that
* represents the position of each out element in `x`. In other words:
*
* `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]`
*
* ```js
* const x = [1, 2, 3, 4, 5, 6];
* const y = [1, 3, 5];
*
* const [out, indices] = await tf.setdiff1dAsync(x, y);
* out.print(); // [2, 4, 6]
* indices.print(); // [1, 3, 5]
* ```
*
* @param x 1-D Tensor. Values to keep.
* @param y 1-D Tensor. Must have the same type as x. Values to exclude in the
* output.
* @returns Promise of Tensor tuple [out, indices].
* out: Tensor with the same type as x.
* indices: A Tensor of type int32.
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function setdiff1dAsync_(_x, _x2) {
return _setdiff1dAsync_.apply(this, arguments);
}
function _setdiff1dAsync_() {
_setdiff1dAsync_ = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(x, y) {
var $x, $y, xVals, yVals, ySet, outputSize, i, buffer, indices, _i, p;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
$x = convertToTensor(x, 'x', 'setdiff1d');
$y = convertToTensor(y, 'y', 'setdiff1d');
assert($x.dtype === $y.dtype, function () {
return "x and y should have the same dtype, but got x (" + $x.dtype + ") and y (" + $y.dtype + ").";
});
assert($x.rank === 1, function () {
return "x should be 1D tensor, but got x (" + $x.shape + ").";
});
assert($y.rank === 1, function () {
return "y should be 1D tensor, but got y (" + $y.shape + ").";
});
_context.next = 7;
return $x.data();
case 7:
xVals = _context.sent;
_context.next = 10;
return $y.data();
case 10:
yVals = _context.sent;
ySet = new Set(yVals);
outputSize = 0;
for (i = 0; i < xVals.length; i++) {
if (!ySet.has(xVals[i])) {
outputSize++;
}
}
buffer = new TensorBuffer([outputSize], $x.dtype);
indices = new TensorBuffer([outputSize], 'int32');
for (_i = 0, p = 0; _i < xVals.length; _i++) {
if (!ySet.has(xVals[_i])) {
buffer.values[p] = xVals[_i];
indices.values[p] = _i;
p++;
}
}
return _context.abrupt("return", [buffer.toTensor(), indices.toTensor()]);
case 18:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _setdiff1dAsync_.apply(this, arguments);
}
var setdiff1dAsync = setdiff1dAsync_;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns an element-wise indication of the sign of a number.
*
* ```js
* const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]);
*
* x.sign().print(); // or tf.sign(x)
* ```
* @param x The input Tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function sign_(x) {
var $x = convertToTensor(x, 'x', 'sign');
var inputs = {
x: $x
};
return ENGINE.runKernel(Sign, inputs);
}
var sign = op({
sign_: sign_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes sin of the input Tensor element-wise: `sin(x)`
*
* ```js
* const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
*
* x.sin().print(); // or tf.sin(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function sin_(x) {
var $x = convertToTensor(x, 'x', 'sin');
var inputs = {
x: $x
};
return ENGINE.runKernel(Sin, inputs);
}
var sin = op({
sin_: sin_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)`
*
* ```js
* const x = tf.tensor1d([0, 1, -1, .7]);
*
* x.sinh().print(); // or tf.sinh(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function sinh_(x) {
var $x = convertToTensor(x, 'x', 'sinh');
var inputs = {
x: $x
};
return ENGINE.runKernel(Sinh, inputs);
}
var sinh = op({
sinh_: sinh_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Extracts a 1D slice from 1D array starting at coordinates `begin` and is
* of length `size`. See `slice` for details.
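 *
 * A minimal usage sketch (added for illustration): take 3 elements starting at
 * index 1.
 *
 * ```js
 * const x = tf.tensor1d([1, 2, 3, 4, 5]);
 *
 * tf.slice1d(x, 1, 3).print(); // [2, 3, 4]
 * ```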
*/
function slice1d_(x, begin, size) {
var $x = convertToTensor(x, 'x', 'slice1d');
assert($x.rank === 1, function () {
return "slice1d expects a rank-1 tensor, but got a rank-" + $x.rank + " tensor";
});
return slice$2($x, [begin], [size]);
}
var slice1d = op({
slice1d_: slice1d_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Extracts a 2D slice from a 2D array starting at coordinates `begin` and
* is of size `size`. See `slice` for details.
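 *
 * A minimal usage sketch (added for illustration): take a 2x2 block starting
 * at column 1.
 *
 * ```js
 * const x = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
 *
 * tf.slice2d(x, [0, 1], [2, 2]).print(); // [[2, 3], [5, 6]]
 * ```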
*/
function slice2d_(x, begin, size) {
var $x = convertToTensor(x, 'x', 'slice2d');
assert($x.rank === 2, function () {
return "slice2d expects a rank-2 tensor, but got a rank-" + $x.rank + " tensor";
});
return slice$2($x, begin, size);
}
var slice2d = op({
slice2d_: slice2d_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Extracts a 3D slice from a 3D array starting at coordinates `begin` and
* is of size `size`. See `slice` for details.
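 *
 * A minimal usage sketch (added for illustration): take the second 2x2 slab.
 *
 * ```js
 * const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);
 *
 * tf.slice3d(x, [1, 0, 0], [1, 2, 2]).print(); // [[[5, 6], [7, 8]]]
 * ```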
*/
function slice3d_(x, begin, size) {
var $x = convertToTensor(x, 'x', 'slice3d');
assert($x.rank === 3, function () {
return "slice3d expects a rank-3 tensor, but got a rank-" + $x.rank + " tensor";
});
return slice$2($x, begin, size);
}
var slice3d = op({
slice3d_: slice3d_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Extracts a 4D slice from a 4D array starting at coordinates `begin` and
* is of size `size`. See `slice` for details.
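 *
 * A minimal usage sketch (added for illustration): take the second 2x2 block
 * of the single batch entry.
 *
 * ```js
 * const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2]);
 *
 * tf.slice4d(x, [0, 1, 0, 0], [1, 1, 2, 2]).print(); // [[[[5, 6], [7, 8]]]]
 * ```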
*/
function slice4d_(x, begin, size) {
var $x = convertToTensor(x, 'x', 'slice4d');
assert($x.rank === 4, function () {
return "slice4d expects a rank-4 tensor, but got a rank-" + $x.rank + " tensor";
});
return slice$2($x, begin, size);
}
var slice4d = op({
slice4d_: slice4d_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the softmax normalized vector given the logits.
*
* ```js
* const a = tf.tensor1d([1, 2, 3]);
*
* a.softmax().print(); // or tf.softmax(a)
* ```
*
* ```js
* const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]);
*
* a.softmax().print(); // or tf.softmax(a)
* ```
*
* @param logits The logits array.
* @param dim The dimension softmax would be performed on. Defaults to `-1`
* which indicates the last dimension.
*
* @doc {heading: 'Operations', subheading: 'Normalization'}
*/
function softmax_(logits, dim) {
if (dim === void 0) {
dim = -1;
}
var $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');
if (dim === -1) {
dim = $logits.rank - 1;
}
if (dim !== $logits.rank - 1) {
throw Error('Softmax along a non-last dimension is not yet supported. ' + ("Logits was rank " + $logits.rank + " and dim was " + dim));
}
var inputs = {
logits: $logits
};
var attrs = {
dim: dim
};
return ENGINE.runKernel(Softmax, inputs, attrs);
}
var softmax = op({
softmax_: softmax_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Fast Fourier transform.
*
* Computes the 1-dimensional discrete Fourier transform over the inner-most
* dimension of input.
*
* ```js
* const real = tf.tensor1d([1, 2, 3]);
* const imag = tf.tensor1d([1, 2, 3]);
* const x = tf.complex(real, imag);
*
* x.fft().print(); // tf.spectral.fft(x).print();
* ```
* @param input The complex input to compute an fft over.
*
* @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
*/
function fft_(input) {
assert(input.dtype === 'complex64', function () {
return "The dtype for tf.spectral.fft() must be complex64 " + ("but got " + input.dtype + ".");
});
var inputs = {
input: input
};
return ENGINE.runKernel(FFT, inputs);
}
var fft = op({
fft_: fft_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Inverse fast Fourier transform.
*
* Computes the inverse 1-dimensional discrete Fourier transform over the
* inner-most dimension of input.
*
* ```js
* const real = tf.tensor1d([1, 2, 3]);
* const imag = tf.tensor1d([1, 2, 3]);
* const x = tf.complex(real, imag);
*
* x.ifft().print(); // tf.spectral.ifft(x).print();
* ```
* @param input The complex input to compute an ifft over.
*
* @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
*/
function ifft_(input) {
assert(input.dtype === 'complex64', function () {
return "The dtype for tf.spectral.ifft() must be complex64 " + ("but got " + input.dtype + ".");
});
var inputs = {
input: input
};
return ENGINE.runKernel(IFFT, inputs);
}
var ifft = op({
ifft_: ifft_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Inverse real-valued fast Fourier transform (the inverse of `rfft`).
 *
 * Computes the 1-dimensional inverse discrete Fourier transform over the
 * inner-most dimension of the input and returns a real-valued tensor.
*
* ```js
* const real = tf.tensor1d([1, 2, 3]);
* const imag = tf.tensor1d([0, 0, 0]);
* const x = tf.complex(real, imag);
*
* x.irfft().print();
* ```
 * @param input The complex input to compute an irfft over.
*
* @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
*/
function irfft_(input) {
var innerDimensionSize = input.shape[input.shape.length - 1];
var batch = input.size / innerDimensionSize;
var ret;
if (innerDimensionSize <= 2) {
var complexInput = reshape(input, [batch, innerDimensionSize]);
ret = ifft(complexInput);
} else {
// The length of unique components of the DFT of a real-valued signal
// is 2 * (input_len - 1)
var outputShape = [batch, 2 * (innerDimensionSize - 1)];
var realInput = reshape(real(input), [batch, innerDimensionSize]);
var imagInput = reshape(imag(input), [batch, innerDimensionSize]);
var realConjugate = reverse(slice$2(realInput, [0, 1], [batch, innerDimensionSize - 2]), 1);
var imagConjugate = mul(reverse(slice$2(imagInput, [0, 1], [batch, innerDimensionSize - 2]), 1), scalar(-1));
var r = concat([realInput, realConjugate], 1);
var i = concat([imagInput, imagConjugate], 1);
var _complexInput = reshape(complex(r, i), [outputShape[0], outputShape[1]]);
ret = ifft(_complexInput);
}
ret = real(ret); // reshape the result if the input is 3D tensor.
if (input.rank === 3 && input.shape[0] !== 0) {
var temp = ret;
var _batch = input.shape[0];
ret = reshape(ret, [_batch, ret.shape[0] / _batch, ret.shape[1]]);
temp.dispose();
}
return ret;
}
var irfft = op({
irfft_: irfft_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Splits a `tf.Tensor` into sub tensors.
*
* If `numOrSizeSplits` is a number, splits `x` along dimension `axis`
* into `numOrSizeSplits` smaller tensors.
* Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`.
*
* If `numOrSizeSplits` is a number array, splits `x` into
* `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the
* same size as `x` except along dimension `axis` where the size is
* `numOrSizeSplits[i]`.
*
* ```js
* const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]);
* const [a, b] = tf.split(x, 2, 1);
* a.print();
* b.print();
*
* const [c, d, e] = tf.split(x, [1, 2, 1], 1);
* c.print();
* d.print();
* e.print();
* ```
*
* @param x The input tensor to split.
* @param numOrSizeSplits Either an integer indicating the number of
* splits along the axis or an array of integers containing the sizes of
* each output tensor along the axis. If a number then it must evenly divide
* `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`.
* Can contain one -1 indicating that dimension is to be inferred.
* @param axis The dimension along which to split. Defaults to 0 (the first
* dim).
*
* @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
*/
function split_(x, numOrSizeSplits, axis) {
if (axis === void 0) {
axis = 0;
}
var $x = convertToTensor(x, 'x', 'split');
var inputs = {
x: $x
};
var attr = {
numOrSizeSplits: numOrSizeSplits,
axis: axis
};
return ENGINE.runKernel(SplitV, inputs, attr);
}
var split$1 = op({
split_: split_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Real value input fast Fourier transform.
*
* Computes the 1-dimensional discrete Fourier transform over the
* inner-most dimension of the real input.
*
* ```js
* const real = tf.tensor1d([1, 2, 3]);
*
* real.rfft().print();
* ```
* @param input The real value input to compute an rfft over.
*
* @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
*/
function rfft_(input, fftLength) {
assert(input.dtype === 'float32', function () {
return "The dtype for rfft() must be real value but got " + input.dtype;
});
var innerDimensionSize = input.shape[input.shape.length - 1];
var batch = input.size / innerDimensionSize;
var adjustedInput;
if (fftLength != null && fftLength < innerDimensionSize) {
// Need to crop
var begin = input.shape.map(function (v) {
return 0;
});
var size = input.shape.map(function (v) {
return v;
});
size[input.shape.length - 1] = fftLength;
adjustedInput = slice$2(input, begin, size);
innerDimensionSize = fftLength;
} else if (fftLength != null && fftLength > innerDimensionSize) {
// Need to pad with zeros
var zerosShape = input.shape.map(function (v) {
return v;
});
zerosShape[input.shape.length - 1] = fftLength - innerDimensionSize;
adjustedInput = concat([input, zeros(zerosShape)], input.shape.length - 1);
innerDimensionSize = fftLength;
} else {
adjustedInput = input;
} // Complement the input with zero imaginary numbers.
var zerosInput = zerosLike(adjustedInput);
var complexInput = reshape(complex(adjustedInput, zerosInput), [batch, innerDimensionSize]);
var ret = fft(complexInput); // Exclude complex conjugations. These conjugations are put symmetrically.
var half = Math.floor(innerDimensionSize / 2) + 1;
var realValues = real(ret);
var imagValues = imag(ret);
var realComplexConjugate = split$1(realValues, [half, innerDimensionSize - half], realValues.shape.length - 1);
var imagComplexConjugate = split$1(imagValues, [half, innerDimensionSize - half], imagValues.shape.length - 1);
var outputShape = adjustedInput.shape.slice();
outputShape[adjustedInput.shape.length - 1] = half;
return reshape(complex(realComplexConjugate[0], imagComplexConjugate[0]), outputShape);
}
var rfft = op({
rfft_: rfft_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)`
*
* ```js
* const x = tf.tensor1d([1, 2, 4, -1]);
*
* x.sqrt().print(); // or tf.sqrt(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function sqrt_(x) {
var $x = convertToTensor(x, 'x', 'sqrt');
var inputs = {
x: $x
};
return ENGINE.runKernel(Sqrt, inputs);
}
var sqrt$3 = op({
sqrt_: sqrt_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Returns (a - b) * (a - b) element-wise.
* Supports broadcasting.
*
* ```js
* const a = tf.tensor1d([1, 4, 3, 16]);
* const b = tf.tensor1d([1, 2, 9, 4]);
*
* a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
* ```
*
* ```js
* // Broadcast squared difference a with b.
* const a = tf.tensor1d([2, 4, 6, 8]);
* const b = tf.scalar(5);
*
* a.squaredDifference(b).print(); // or tf.squaredDifference(a, b)
* ```
*
* @param a The first tensor.
* @param b The second tensor. Must have the same type as `a`.
*
* @doc {heading: 'Operations', subheading: 'Arithmetic'}
*/
function squaredDifference_(a, b) {
var $a = convertToTensor(a, 'a', 'squaredDifference');
var $b = convertToTensor(b, 'b', 'squaredDifference');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
assertAndGetBroadcastShape($a.shape, $b.shape);
var inputs = {
a: $a,
b: $b
};
var attrs = {};
return ENGINE.runKernel(SquaredDifference, inputs, attrs);
}
var squaredDifference = op({
squaredDifference_: squaredDifference_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Removes dimensions of size 1 from the shape of a `tf.Tensor`.
*
* ```js
* const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]);
* x.squeeze().print();
* ```
*
* @param x The input tensor to be squeezed.
* @param axis An optional list of numbers. If specified, only
* squeezes the dimensions listed. The dimension index starts at 0. It
* is an error to squeeze a dimension that is not 1.
*
* @doc {heading: 'Tensors', subheading: 'Transformations'}
*/
function squeeze_(x, axis) {
var $x = convertToTensor(x, 'x', 'squeeze');
return reshape($x, squeezeShape($x.shape, axis).newShape);
}
var squeeze = op({
squeeze_: squeeze_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`.
*
* ```js
* const a = tf.tensor1d([1, 2]);
* const b = tf.tensor1d([3, 4]);
* const c = tf.tensor1d([5, 6]);
* tf.stack([a, b, c]).print();
* ```
*
* @param tensors A list of tensor objects with the same shape and dtype.
* @param axis The axis to stack along. Defaults to 0 (the first dim).
*
* @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
*/
function stack_(tensors, axis) {
if (axis === void 0) {
axis = 0;
}
var $tensors = convertToTensorArray(tensors, 'tensors', 'stack', 'string_or_numeric');
assert($tensors.length >= 1, function () {
return 'Pass at least one tensor to tf.stack';
});
if ($tensors.length > 0) {
assert(axis <= $tensors[0].rank, function () {
return 'Axis must be <= rank of the tensor';
});
}
var inputs = $tensors;
var attrs = {
axis: axis
};
return ENGINE.runKernel(Pack, inputs, attrs);
}
var stack = op({
stack_: stack_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha`
*
* ```js
* const x = tf.tensor1d([0, 2, -1, -3]);
*
* x.step(.5).print(); // or tf.step(x, .5)
* ```
* @param x The input tensor.
* @param alpha The gradient when input is negative.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function step_(x, alpha) {
if (alpha === void 0) {
alpha = 0.0;
}
var $x = convertToTensor(x, 'x', 'step');
var inputs = {
x: $x
};
var attrs = {
alpha: alpha
};
return ENGINE.runKernel(Step, inputs, attrs);
}
var step = op({
step_: step_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Extracts a strided slice of a tensor.
*
* Roughly speaking, this op extracts a slice of size (end-begin)/stride from
* the given input tensor (x). Starting at the location specified by begin the
* slice continues by adding stride to the index until all dimensions are not
* less than end. Note that a stride can be negative, which causes a reverse
* slice.
*
* ```js
* const t = tf.tensor3d([1, 1, 1 ,2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
* [3, 2, 3]);
* t.stridedSlice([1, 0, 0], [2, 1, 3], [1, 1, 1]).print() // [[[3, 3, 3]]]
* t.stridedSlice([1, 0, 0], [2, 2, 3], [1, 1, 1]).print() // [[[3, 3, 3],
* // [4, 4, 4]]]
* t.stridedSlice([1, -1, 0], [2, -3, 3], [1, -1, 1]).print() // [[[4, 4, 4],
* // [3, 3, 3]]]
* ```
*
* @param x The tensor to stride slice.
* @param begin The coordinates to start the slice from.
 * @param end The coordinates to end the slice at.
 * @param strides The stride of the slice, i.e. how much the index is
 * incremented in each dimension. A negative stride produces a reversed slice.
 * @param beginMask If the ith bit of beginMask is set, begin[i] is ignored
 * and the fullest possible range in that dimension is used instead.
 * @param endMask If the ith bit of endMask is set, end[i] is ignored
 * and the fullest possible range in that dimension is used instead.
 * @param shrinkAxisMask A bitmask where bit i implies that
 * the ith specification should shrink the dimensionality. begin and end must
 * imply a slice of size 1 in the dimension.
*
* @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
*/
function stridedSlice_(x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
if (beginMask === void 0) {
beginMask = 0;
}
if (endMask === void 0) {
endMask = 0;
}
if (ellipsisMask === void 0) {
ellipsisMask = 0;
}
if (newAxisMask === void 0) {
newAxisMask = 0;
}
if (shrinkAxisMask === void 0) {
shrinkAxisMask = 0;
}
var $x = convertToTensor(x, 'x', 'stridedSlice', 'string_or_numeric');
var inputs = {
x: $x
};
var attrs = {
begin: begin,
end: end,
strides: strides,
beginMask: beginMask,
endMask: endMask,
ellipsisMask: ellipsisMask,
newAxisMask: newAxisMask,
shrinkAxisMask: shrinkAxisMask
};
return ENGINE.runKernel(StridedSlice, inputs, attrs);
}
var stridedSlice = op({
stridedSlice_: stridedSlice_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes tan of the input `tf.Tensor` element-wise, `tan(x)`
*
* ```js
* const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]);
*
* x.tan().print(); // or tf.tan(x)
* ```
* @param x The input tensor.
*
* @doc {heading: 'Operations', subheading: 'Basic math'}
*/
function tan_(x) {
var $x = convertToTensor(x, 'x', 'tan');
var inputs = {
x: $x
};
return ENGINE.runKernel(Tan, inputs);
}
var tan = op({
tan_: tan_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates rank-1 `tf.Tensor` with the provided values, shape and dtype.
*
* The same functionality can be achieved with `tf.tensor`, but in general
* we recommend using `tf.tensor1d` as it makes the code more readable.
*
* ```js
* tf.tensor1d([1, 2, 3]).print();
* ```
*
* @param values The values of the tensor. Can be array of numbers,
* or a `TypedArray`.
* @param dtype The data type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function tensor1d(values, dtype) {
assertNonNull(values);
var inferredShape = inferShape(values, dtype);
if (inferredShape.length !== 1) {
throw new Error('tensor1d() requires values to be a flat/TypedArray');
}
var shape = null;
return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates rank-2 `tf.Tensor` with the provided values, shape and dtype.
*
* The same functionality can be achieved with `tf.tensor`, but in general
* we recommend using `tf.tensor2d` as it makes the code more readable.
*
* ```js
* // Pass a nested array.
* tf.tensor2d([[1, 2], [3, 4]]).print();
* ```
* ```js
* // Pass a flat array and specify a shape.
* tf.tensor2d([1, 2, 3, 4], [2, 2]).print();
* ```
*
* @param values The values of the tensor. Can be nested array of numbers,
* or a flat array, or a `TypedArray`.
* @param shape The shape of the tensor. If not provided, it is inferred from
* `values`.
* @param dtype The data type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function tensor2d(values, shape, dtype) {
assertNonNull(values);
if (shape != null && shape.length !== 2) {
throw new Error('tensor2d() requires shape to have two numbers');
}
var inferredShape = inferShape(values, dtype);
if (inferredShape.length !== 2 && inferredShape.length !== 1) {
throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray');
}
if (inferredShape.length === 1 && shape == null) {
throw new Error('tensor2d() requires shape to be provided when `values` ' + 'are a flat/TypedArray');
}
return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates rank-4 `tf.Tensor` with the provided values, shape and dtype.
*
* The same functionality can be achieved with `tf.tensor`, but in general
* we recommend using `tf.tensor4d` as it makes the code more readable.
*
* ```js
* // Pass a nested array.
* tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print();
* ```
* ```js
* // Pass a flat array and specify a shape.
* tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print();
* ```
*
* @param values The values of the tensor. Can be nested array of numbers,
* or a flat array, or a `TypedArray`.
* @param shape The shape of the tensor. Optional. If not provided,
* it is inferred from `values`.
* @param dtype The data type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function tensor4d(values, shape, dtype) {
assertNonNull(values);
if (shape != null && shape.length !== 4) {
throw new Error('tensor4d() requires shape to have four numbers');
}
var inferredShape = inferShape(values, dtype);
if (inferredShape.length !== 4 && inferredShape.length !== 1) {
throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray');
}
if (inferredShape.length === 1 && shape == null) {
throw new Error('tensor4d() requires shape to be provided when `values` ' + 'are a flat array');
}
return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates rank-5 `tf.Tensor` with the provided values, shape and dtype.
*
* The same functionality can be achieved with `tf.tensor`, but in general
* we recommend using `tf.tensor5d` as it makes the code more readable.
*
* ```js
* // Pass a nested array.
* tf.tensor5d([[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]).print();
* ```
* ```js
* // Pass a flat array and specify a shape.
* tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print();
* ```
*
* @param values The values of the tensor. Can be nested array of numbers,
* or a flat array, or a `TypedArray`.
* @param shape The shape of the tensor. Optional. If not provided,
* it is inferred from `values`.
* @param dtype The data type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function tensor5d(values, shape, dtype) {
assertNonNull(values);
if (shape != null && shape.length !== 5) {
throw new Error('tensor5d() requires shape to have five numbers');
}
var inferredShape = inferShape(values, dtype);
if (inferredShape.length !== 5 && inferredShape.length !== 1) {
throw new Error('tensor5d() requires values to be ' + 'number[][][][][] or flat/TypedArray');
}
if (inferredShape.length === 1 && shape == null) {
throw new Error('tensor5d() requires shape to be provided when `values` ' + 'are a flat array');
}
return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates rank-6 `tf.Tensor` with the provided values, shape and dtype.
*
* The same functionality can be achieved with `tf.tensor`, but in general
* we recommend using `tf.tensor6d` as it makes the code more readable.
*
* ```js
* // Pass a nested array.
* tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print();
* ```
* ```js
* // Pass a flat array and specify a shape.
* tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print();
* ```
*
* @param values The values of the tensor. Can be nested array of numbers,
* or a flat array, or a `TypedArray`.
* @param shape The shape of the tensor. Optional. If not provided,
* it is inferred from `values`.
* @param dtype The data type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function tensor6d(values, shape, dtype) {
assertNonNull(values);
if (shape != null && shape.length !== 6) {
throw new Error('tensor6d() requires shape to have six numbers');
}
var inferredShape = inferShape(values, dtype);
if (inferredShape.length !== 6 && inferredShape.length !== 1) {
throw new Error('tensor6d() requires values to be number[][][][][][] or ' + 'flat/TypedArray');
}
if (inferredShape.length === 1 && shape == null) {
throw new Error('tensor6d() requires shape to be provided when `values` ' + 'are a flat array');
}
shape = shape || inferredShape;
return makeTensor(values, shape, inferredShape, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Finds the values and indices of the `k` largest entries along the last
* dimension.
*
* If the input is a vector (rank=1), finds the k largest entries in the vector
* and outputs their values and indices as vectors. Thus values[j] is the j-th
* largest entry in input, and its index is indices[j].
* For higher rank inputs, computes the top k entries along the last dimension.
*
* If two elements are equal, the lower-index element appears first.
*
* ```js
* const a = tf.tensor2d([[1, 5], [4, 3]]);
* const {values, indices} = tf.topk(a);
* values.print();
* indices.print();
* ```
* @param x 1-D or higher `tf.Tensor` with last dimension being at least `k`.
* @param k Number of top elements to look for along the last dimension.
* @param sorted If true, the resulting `k` elements will be sorted by the
* values in descending order.
*
* @doc {heading: 'Operations', subheading: 'Evaluation'}
*/
function topk_(x, k, sorted) {
if (k === void 0) {
k = 1;
}
if (sorted === void 0) {
sorted = true;
}
var $x = convertToTensor(x, 'x', 'topk');
if ($x.rank === 0) {
throw new Error('topk() expects the input to be of rank 1 or higher');
}
var lastDim = $x.shape[$x.shape.length - 1];
if (k < 0) {
throw new Error("'k' passed to topk() must be >= 0 but got " + k);
}
if (k > lastDim) {
throw new Error("'k' passed to topk() must be <= the last dimension (" + lastDim + ") " + ("but got " + k));
}
var inputs = {
x: $x
};
var attrs = {
k: k,
sorted: sorted
};
var _ENGINE$runKernel = ENGINE.runKernel(TopK, inputs, attrs),
values = _ENGINE$runKernel[0],
indices = _ENGINE$runKernel[1];
return {
values: values,
indices: indices
};
}
var topk = op({
topk_: topk_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a `tf.Tensor` with values sampled from a truncated normal
* distribution.
*
* ```js
* tf.truncatedNormal([2, 2]).print();
* ```
*
* The generated values follow a normal distribution with specified mean and
* standard deviation, except that values whose magnitude is more than 2
* standard deviations from the mean are dropped and re-picked.
*
* @param shape An array of integers defining the output tensor shape.
* @param mean The mean of the normal distribution.
* @param stdDev The standard deviation of the normal distribution.
* @param dtype The data type of the output tensor.
* @param seed The seed for the random number generator.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function truncatedNormal_(shape, mean, stdDev, dtype, seed) {
if (mean === void 0) {
mean = 0;
}
if (stdDev === void 0) {
stdDev = 1;
}
if (dtype != null && dtype === 'bool') {
throw new Error("Unsupported data type $ { dtype }");
}
var randGauss = new MPRandGauss(mean, stdDev, dtype, true
/* truncated */
, seed);
var res = buffer(shape, dtype);
for (var i = 0; i < res.values.length; i++) {
res.values[i] = randGauss.nextValue();
}
return res.toTensor();
}
var truncatedNormal = op({
truncatedNormal_: truncatedNormal_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Finds unique elements along an axis of a tensor.
*
* It returns a tensor `values` containing all of the unique elements along the
* `axis` of the given tensor `x` in the same order that they occur along the
* `axis` in `x`; `x` does not need to be sorted. It also returns a tensor
 * `indices` whose size equals the number of elements in `x` along the `axis`
 * dimension; each entry is the index of the corresponding element of `x` in the
 * unique output `values`.
*
* ```js
* // A 1-D tensor
* const a = tf.tensor1d([1, 1, 2, 4, 4, 4, 7, 8, 8]);
* const {values, indices} = tf.unique(a);
* values.print(); // [1, 2, 4, 7, 8,]
* indices.print(); // [0, 0, 1, 2, 2, 2, 3, 4, 4]
* ```
*
* ```js
* // A 2-D tensor with axis=0
* //
* // 'a' is: [[1, 0, 0],
* // [1, 0, 0],
* // [2, 0, 0]]
* const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
* const {values, indices} = tf.unique(a, 0)
* values.print(); // [[1, 0, 0],
* // [2, 0, 0]]
* indices.print(); // [0, 0, 1]
* ```
*
* ```js
* // A 2-D tensor with axis=1
* //
* // 'a' is: [[1, 0, 0],
* // [1, 0, 0],
* // [2, 0, 0]]
* const a = tf.tensor2d([[1, 0, 0], [1, 0, 0], [2, 0, 0]]);
* const {values, indices} = tf.unique(a, 1)
* values.print(); // [[1, 0],
* // [1, 0],
* // [2, 0]]
* indices.print(); // [0, 1, 1]
* ```
* @param x A tensor (int32, string, bool).
* @param axis The axis of the tensor to find the unique elements.
* @returns [uniqueElements, indices] (see above for details)
*
* @doc {heading: 'Operations', subheading: 'Evaluation'}
*/
function unique_(x, axis) {
if (axis === void 0) {
axis = 0;
}
var $x = convertToTensor(x, 'x', 'unique', 'string_or_numeric');
assert($x.rank > 0, function () {
return 'The input tensor must be at least 1D';
});
var inputs = {
x: $x
};
var attrs = {
axis: axis
};
var _ENGINE$runKernel = ENGINE.runKernel(Unique, inputs, attrs),
values = _ENGINE$runKernel[0],
indices = _ENGINE$runKernel[1];
return {
values: values,
indices: indices
};
}
var unique = op({
unique_: unique_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the sum along segments of a `tf.Tensor`.
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
* const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32');
* const numSegments = 3;
*
* x.unsortedSegmentSum(segmentIds, numSegments).print()
* //or tf.unsortedSegmentSum(x, segmentIds, numSegments)
* ```
* @param x The `tf.Tensor` that will be summed along its segments.
 * @param segmentIds A `tf.Tensor1D` whose length equals the size of `x` along
 * its first dimension. Maps each element of `x` along that dimension to a
 * segment.
* @param numSegments The number of distinct `segmentIds`.
*
* @doc {heading: 'Operations', subheading: 'Segment'}
*/
function unsortedSegmentSum_(x, segmentIds, numSegments) {
var $x = convertToTensor(x, 'x', 'unsortedSegmentSum');
var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32');
assert(isInt(numSegments), function () {
return 'numSegments must be of dtype int';
});
var inputs = {
x: $x,
segmentIds: $segmentIds
};
var attrs = {
numSegments: numSegments
};
return ENGINE.runKernel(UnsortedSegmentSum, inputs, attrs);
}
var unsortedSegmentSum = op({
unsortedSegmentSum_: unsortedSegmentSum_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s.
*
* ```js
* const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
*
* tf.unstack(a).forEach(tensor => tensor.print());
* ```
*
* @param x A tensor object.
* @param axis The axis to unstack along. Defaults to 0 (the first dim).
*
* @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
*/
function unstack_(x, axis) {
if (axis === void 0) {
axis = 0;
}
var $x = convertToTensor(x, 'x', 'unstack', 'string_or_numeric');
assert(axis >= -$x.shape.length && axis < $x.shape.length, function () {
return "Axis = " + axis + " is not in [-" + $x.shape.length + ", " + $x.shape.length + ")";
});
var inputs = {
value: $x
};
var attrs = {
axis: axis
};
return ENGINE.runKernel(Unpack, inputs, attrs);
}
var unstack = op({
unstack_: unstack_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a new variable with the provided initial value.
* ```js
* const x = tf.variable(tf.tensor([1, 2, 3]));
* x.assign(tf.tensor([4, 5, 6]));
*
* x.print();
* ```
*
* @param initialValue Initial value for the tensor.
* @param trainable If true, optimizers are allowed to update it.
* @param name Name of the variable. Defaults to a unique id.
* @param dtype If set, initialValue will be converted to the given type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
function variable(initialValue, trainable, name, dtype) {
if (trainable === void 0) {
trainable = true;
}
return ENGINE.makeVariable(initialValue, trainable, name, dtype);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
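// whereImpl: CPU helper shared by tf.whereAsync. It scans the flattened condition
// values and, for every truthy entry, converts the flat index into a row-major
// coordinate, producing an int32 tensor of shape [numTrueElements, condition.rank].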
function whereImpl(condShape, condVals) {
var indices = [];
for (var i = 0; i < condVals.length; i++) {
if (condVals[i]) {
indices.push(i);
}
}
var inBuffer = buffer(condShape, 'int32');
var out = buffer([indices.length, condShape.length], 'int32');
for (var _i = 0; _i < indices.length; _i++) {
var loc = inBuffer.indexToLoc(indices[_i]);
var offset = _i * condShape.length;
out.values.set(loc, offset);
}
return out.toTensor();
}
/**
* Returns the coordinates of true elements of condition.
*
* The coordinates are returned in a 2-D tensor where the first dimension (rows)
* represents the number of true elements, and the second dimension (columns)
* represents the coordinates of the true elements. Keep in mind, the shape of
* the output tensor can vary depending on how many true values there are in
* input. Indices are output in row-major order. The resulting tensor has the
* shape `[numTrueElems, condition.rank]`.
*
 * This is analogous to calling the Python `tf.where(cond)` without an x or y.
*
* ```js
* const cond = tf.tensor1d([false, false, true], 'bool');
* const result = await tf.whereAsync(cond);
* result.print();
* ```
*
* @doc {heading: 'Operations', subheading: 'Logical'}
*/
function whereAsync_(_x) {
return _whereAsync_.apply(this, arguments);
}
function _whereAsync_() {
_whereAsync_ = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(condition) {
var $condition, vals, res;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
$condition = convertToTensor(condition, 'condition', 'whereAsync', 'bool');
_context.next = 3;
return $condition.data();
case 3:
vals = _context.sent;
res = whereImpl($condition.shape, vals);
if (condition !== $condition) {
$condition.dispose();
}
return _context.abrupt("return", res);
case 7:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _whereAsync_.apply(this, arguments);
}
var whereAsync = whereAsync_;
/**
* Apply boolean mask to tensor.
*
* ```js
* const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
* const mask = tf.tensor1d([1, 0, 1], 'bool');
* const result = await tf.booleanMaskAsync(tensor, mask);
* result.print();
* ```
*
* @param tensor N-D tensor.
* @param mask K-D boolean tensor, K <= N and K must be known statically.
* @param axis A 0-D int Tensor representing the axis in tensor to mask from.
* By default, axis is 0 which will mask from the first dimension.
* Otherwise K + axis <= N.
*
* @doc {heading: 'Tensors', subheading: 'Slicing and Joining'}
*/
function booleanMaskAsync_(_x, _x2, _x3) {
return _booleanMaskAsync_.apply(this, arguments);
}
function _booleanMaskAsync_() {
_booleanMaskAsync_ = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(tensor, mask, axis) {
var $tensor, $mask, axisFrom, maskDim, tensorShape, leadingSize, i, targetTensorShape, reshapedTensor, reshapedMask, positivePositions, indices, res;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
$tensor = convertToTensor(tensor, 'tensor', 'boolMask');
$mask = convertToTensor(mask, 'mask', 'boolMask', 'bool');
axisFrom = axis == null ? 0 : axis;
maskDim = $mask.rank;
tensorShape = $tensor.shape;
assert(maskDim > 0, function () {
return 'mask cannot be scalar';
});
assertShapesMatch(tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, "mask's shape must match the first K dimensions of tensor's shape,");
leadingSize = 1;
for (i = axisFrom; i < axisFrom + maskDim; i++) {
leadingSize *= tensorShape[i];
}
targetTensorShape = tensorShape.slice(0, axisFrom).concat([leadingSize], tensorShape.slice(axisFrom + maskDim));
reshapedTensor = reshape($tensor, targetTensorShape);
reshapedMask = reshape($mask, [-1]);
_context.next = 14;
return whereAsync(reshapedMask);
case 14:
positivePositions = _context.sent;
indices = squeeze(positivePositions, [1]);
res = gather(reshapedTensor, indices, axisFrom); // Ensure no memory leak.
if (tensor !== $tensor) {
$tensor.dispose();
}
if (mask !== $mask) {
$mask.dispose();
}
indices.dispose();
reshapedTensor.dispose();
reshapedMask.dispose();
positivePositions.dispose();
return _context.abrupt("return", res);
case 24:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _booleanMaskAsync_.apply(this, arguments);
}
var booleanMaskAsync = booleanMaskAsync_;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the norm of scalar, vectors, and matrices.
* This function can compute several different vector norms (the 1-norm, the
* Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0)
* and matrix norms (Frobenius, 1-norm, and inf-norm).
*
* ```js
* const x = tf.tensor1d([1, 2, 3, 4]);
*
* x.norm().print(); // or tf.norm(x)
* ```
*
* @param x The input array.
* @param ord Optional. Order of the norm. Supported norm types are
* following:
*
* | ord | norm for matrices | norm for vectors
* |------------|---------------------------|---------------------
* |'euclidean' |Frobenius norm |2-norm
* |'fro' |Frobenius norm |
* |Infinity |max(sum(abs(x), axis=1)) |max(abs(x))
* |-Infinity |min(sum(abs(x), axis=1)) |min(abs(x))
* |1 |max(sum(abs(x), axis=0)) |sum(abs(x))
* |2 | |sum(abs(x)^2)^1/2*
*
* @param axis Optional. If axis is null (the default), the input is
* considered a vector and a single vector norm is computed over the entire
* set of values in the Tensor, i.e. norm(x, ord) is equivalent
 * to norm(x.reshape([-1]), ord). If axis is an integer, the input
 * is considered a batch of vectors, and axis determines the axis in x
 * over which to compute vector norms. If axis is a 2-tuple of integers, it is
 * considered a batch of matrices and axis determines the axes in NDArray
 * over which to compute a matrix norm.
 * @param keepDims Optional. If true, the norm has the same dimensionality
 * as the input.
*
* @doc {heading: 'Operations', subheading: 'Matrices'}
*/
function norm_(x, ord, axis, keepDims) {
if (ord === void 0) {
ord = 'euclidean';
}
if (axis === void 0) {
axis = null;
}
if (keepDims === void 0) {
keepDims = false;
}
x = convertToTensor(x, 'x', 'norm');
var norm = normImpl(x, ord, axis);
var keepDimsShape = norm.shape;
if (keepDims) {
var axes = parseAxisParam(axis, x.shape);
keepDimsShape = expandShapeToKeepDim(norm.shape, axes);
}
return reshape(norm, keepDimsShape);
}
function normImpl(x, p, axis) {
if (axis === void 0) {
axis = null;
}
if (x.rank === 0) {
return abs$8(x);
} // consider vector when no axis is specified
if (x.rank !== 1 && axis === null) {
return normImpl(reshape(x, [-1]), p, axis);
} // vector
if (x.rank === 1 || typeof axis === 'number' || Array.isArray(axis) && axis.length === 1) {
if (p === 1) {
return sum$1(abs$8(x), axis);
}
if (p === Infinity) {
return max$5(abs$8(x), axis);
}
if (p === -Infinity) {
return min$9(abs$8(x), axis);
}
if (p === 'euclidean' || p === 2) {
// norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2
return sqrt$3(sum$1(pow$5(abs$8(x), scalar(2, 'int32')), axis));
}
throw new Error("Error in norm: invalid ord value: " + p);
} // matrix (assumption axis[0] < axis[1])
if (Array.isArray(axis) && axis.length === 2) {
if (p === 1) {
return max$5(sum$1(abs$8(x), axis[0]), axis[1] - 1);
}
if (p === Infinity) {
return max$5(sum$1(abs$8(x), axis[1]), axis[0]);
}
if (p === -Infinity) {
return min$9(sum$1(abs$8(x), axis[1]), axis[0]);
}
if (p === 'fro' || p === 'euclidean') {
// norm(x) = sqrt(sum(pow(x, 2)))
return sqrt$3(sum$1(square(x), axis));
}
throw new Error("Error in norm: invalid ord value: " + p);
}
throw new Error("Error in norm: invalid axis: " + axis);
}
var norm = op({
norm_: norm_
});
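/**
 * A few extra usage sketches for `norm` (not part of the library source; they assume
 * the bundle has been loaded and is exposed as `tf`):
 *
 * ```js
 * const m = tf.tensor2d([[1, -2], [3, 4]]);
 * tf.norm(m, 1, [0, 1]).print();     // matrix 1-norm: max column abs-sum -> 6
 * tf.norm(m, 'fro', [0, 1]).print(); // Frobenius norm -> sqrt(30) ~= 5.477
 * tf.norm(tf.tensor1d([3, -4])).print(); // default Euclidean vector norm -> 5
 * ```
 */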
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Compute the moving average of a variable.
*
* Without zeroDebias, the moving average operation is defined by:
* `v += delta`
* where
* `delta = (1 - decay) * (x - v)`
*
* With zeroDebias (default), the `delta` term is scaled to debias the
* effect of the (assumed) zero-initialization of `v`.
* `delta /= (1 - decay ^ step)`
*
* For more details on the zero-debiasing algorithm, see:
* https://arxiv.org/abs/1412.6980
*
* Note that this function is completely stateless and does not keep track of
* step count. The step count needs to be maintained by the caller and passed
* in as `step`.
*
* @param v The current moving average value.
* @param x New input value, must have the same shape and dtype as `v`.
* @param decay The decay factor. Typical values are 0.95 and 0.99.
* @param step Step count.
* @param zeroDebias: Whether zeroDebias is to be performed (default: `true`).
* @returns The new moving average value.
*
* @doc {heading: 'Operations', subheading: 'Moving Average'}
*/
function movingAverage_(v, x, decay, step, zeroDebias) {
if (zeroDebias === void 0) {
zeroDebias = true;
}
var $v = convertToTensor(v, 'v', 'movingAverage');
var $x = convertToTensor(x, 'x', 'movingAverage');
var $decay = convertToTensor(decay, 'decay', 'movingAverage');
assertTypesMatch($v, $x);
assert(arraysEqual($v.shape, $x.shape), function () {
return 'Shape mismatch in v and x';
});
var one = scalar(1);
var oneMinusDecay = sub(one, $decay);
var update = mul(sub($x, $v), oneMinusDecay);
if (zeroDebias) {
assert(step != null, function () {
return 'When using zeroDebias: true, step is required.';
});
var $step = convertToTensor(step, 'step', 'movingAverage');
update = div(update, sub(one, pow$5($decay, $step)));
}
return add$1($v, update);
}
var movingAverage = op({
movingAverage_: movingAverage_
});
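/**
 * Usage sketch for `movingAverage` (not part of the library source; assumes the bundle
 * is exposed as `tf`). With a zero-initialized average, zero-debiasing makes the first
 * update equal to the observation itself:
 *
 * ```js
 * const v = tf.zeros([2]);          // current (zero-initialized) moving average
 * const x = tf.tensor1d([10, 20]);  // new observation
 * // delta = (1 - 0.99) * (x - v) / (1 - 0.99^1) = x - v
 * tf.movingAverage(v, x, 0.99, 1).print(); // [10, 20]
 * ```
 */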
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates a new tensor by applying sparse updates to individual
* values or slices within a zero tensor of the given shape tensor according to
* indices. This operator is the inverse of the `tf.gatherND` operator which
* extracts values or slices from a given tensor.
*
* ```js
* const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32');
* const updates = tf.tensor1d([9, 10, 11, 12]);
* const shape = [8];
* tf.scatterND(indices, updates, shape).print() //[0, 11, 0, 10, 9, 0, 0, 12]
* ```
*
* @param indices The tensor contains the indices into the output tensor.
* @param updates The tensor contains the value for the indices.
* @param shape: The shape of the output tensor.
*
* @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
*/
function scatterND_(indices, updates, shape) {
var $indices = convertToTensor(indices, 'indices', 'scatterND', 'int32');
var $updates = convertToTensor(updates, 'updates', 'scatterND');
validateInput($updates, $indices, shape);
var inputs = {
indices: $indices,
updates: $updates
};
var attrs = {
shape: shape
}; // tslint:disable-next-line: no-unnecessary-type-assertion
return ENGINE.runKernel(ScatterNd, inputs, attrs);
}
var scatterND = op({
scatterND_: scatterND_
});
/**
* Validate sparseToDense inputs.
*
* @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
* sparseIndices[i] contains the complete index where sparseValues[i] will be
* placed.
* @param sparseValues A 0-D or 1-D Tensor. Values
* corresponding to each row of sparseIndices, or a scalar value to be used for
* all sparse indices.
* @param outputShape number[]. Shape of the dense output tensor.
 * @param validateIndices boolean. Index validation is not supported; an error
 * will be thrown if it is set.
*/
function validateInput$1(sparseIndices, sparseValues, outputShape, defaultValues) {
if (sparseIndices.dtype !== 'int32') {
throw new Error('tf.sparseToDense() expects the indices to be int32 type,' + (" but the dtype was " + sparseIndices.dtype + "."));
}
if (sparseIndices.rank > 2) {
throw new Error('sparseIndices should be a scalar, vector, or matrix,' + (" but got shape " + sparseIndices.shape + "."));
}
var numElems = sparseIndices.rank > 0 ? sparseIndices.shape[0] : 1;
var numDims = sparseIndices.rank > 1 ? sparseIndices.shape[1] : 1;
if (outputShape.length !== numDims) {
    throw new Error('outputShape has incorrect number of elements:' + (" " + outputShape.length + ", should be: " + numDims + "."));
}
var numValues = sparseValues.size;
if (!(sparseValues.rank === 0 || sparseValues.rank === 1 && numValues === numElems)) {
throw new Error('sparseValues has incorrect shape ' + (sparseValues.shape + ", should be [] or [" + numElems + "]"));
}
if (sparseValues.dtype !== defaultValues.dtype) {
throw new Error('sparseValues.dtype must match defaultValues.dtype');
}
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Converts a sparse representation into a dense tensor.
*
* Builds an array dense with shape outputShape such that:
*
* // If sparseIndices is scalar
* dense[i] = (i == sparseIndices ? sparseValues : defaultValue)
*
* // If sparseIndices is a vector, then for each i
* dense[sparseIndices[i]] = sparseValues[i]
*
* // If sparseIndices is an n by d matrix, then for each i in [0, n)
* dense[sparseIndices[i][0], ..., sparseIndices[i][d-1]] = sparseValues[i]
* All other values in dense are set to defaultValue. If sparseValues is a
* scalar, all sparse indices are set to this single value.
*
* If indices are repeated the final value is summed over all values for those
* indices.
*
* ```js
* const indices = tf.tensor1d([4, 5, 6, 1, 2, 3], 'int32');
* const values = tf.tensor1d([10, 11, 12, 13, 14, 15], 'float32');
* const shape = [8];
* tf.sparseToDense(indices, values, shape).print();
* ```
*
* @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32.
* sparseIndices[i] contains the complete index where sparseValues[i] will be
* placed.
* @param sparseValues A 0-D or 1-D Tensor. Values
* corresponding to each row of sparseIndices, or a scalar value to be used for
* all sparse indices.
 * @param outputShape Shape of the dense output tensor. The type is inferred.
* @param defaultValue Scalar. Value to set for indices not specified in
* sparseIndices. Defaults to zero.
*
* @doc {heading: 'Operations', subheading: 'Normalization'}
*/
function sparseToDense_(sparseIndices, sparseValues, outputShape, defaultValue) {
if (defaultValue === void 0) {
defaultValue = 0;
}
var $sparseIndices = convertToTensor(sparseIndices, 'sparseIndices', 'sparseToDense', 'int32');
var $sparseValues = convertToTensor(sparseValues, 'sparseValues', 'sparseToDense');
var $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseToDense', $sparseValues.dtype);
validateInput$1($sparseIndices, $sparseValues, outputShape, $defaultValue);
var inputs = {
sparseIndices: $sparseIndices,
sparseValues: $sparseValues,
defaultValue: $defaultValue
};
var attrs = {
outputShape: outputShape
};
return ENGINE.runKernel(SparseToDense, inputs, attrs);
}
var sparseToDense = op({
sparseToDense_: sparseToDense_
});
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Gather slices from input tensor into a Tensor with shape specified by
* `indices`.
*
 * `indices` is a K-dimensional integer tensor, best thought of as a
* (K-1)-dimensional tensor of indices into input, where each element defines a
* slice of input:
* output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]]
*
* Whereas in `tf.gather`, `indices` defines slices into the first dimension of
* input, in `tf.gatherND`, `indices` defines slices into the first N dimensions
* of input, where N = indices.shape[-1].
*
* The last dimension of indices can be at most the rank of input:
* indices.shape[-1] <= input.rank
*
* The last dimension of `indices` corresponds to elements
* (if indices.shape[-1] == input.rank) or slices
* (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of
* input.
* The output tensor has shape
* indices.shape[:-1] + input.shape[indices.shape[-1]:]
*
* Note that on CPU, if an out of bound index is found, an error is returned. On
* GPU, if an out of bound index is found, a 0 is stored in the corresponding
* output value.
*
* ```js
* const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32');
* const input = tf.tensor2d([9, 10, 11, 12], [2, 2]);
* tf.gatherND(input, indices).print() // [10, 11]
* ```
*
* @param x The tensor from which to gather values.
* @param indices Index tensor, must be of type int32.
*
* @doc {heading: 'Operations', subheading: 'Slicing and Joining'}
*/
function gatherND_(x, indices) {
var $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32');
var $x = convertToTensor(x, 'x', 'gatherND', 'string_or_numeric');
var inputs = {
params: $x,
indices: $indices
};
return ENGINE.runKernel(GatherNd, inputs);
}
var gatherND = op({
gatherND_: gatherND_
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Normalize noise shape based on provided tensor and noise shape.
*
* @param x Tensor.
* @param noiseShape The shape for the randomly generated keep/drop flags, as
* an array of numbers. Optional.
* @returns Normalized noise shape.
*/
function getNoiseShape(x, noiseShape) {
if (noiseShape == null) {
return x.shape.slice();
}
if (arraysEqual(x.shape, noiseShape)) {
return noiseShape;
}
if (x.shape.length === noiseShape.length) {
var newDimension = [];
for (var i = 0; i < x.shape.length; i++) {
if (noiseShape[i] == null && x.shape[i] != null) {
newDimension.push(x.shape[i]);
} else {
newDimension.push(noiseShape[i]);
}
}
return newDimension;
}
return noiseShape;
}
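// For example, getNoiseShape above behaves as follows for an x of shape [3, 4]:
//   getNoiseShape(x, null)      -> [3, 4]  (copy of x.shape)
//   getNoiseShape(x, [null, 4]) -> [3, 4]  (null entries filled from x.shape)
//   getNoiseShape(x, [1, 4])    -> [1, 4]  (used as-is)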
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes dropout.
*
* ```js
* const x = tf.tensor1d([1, 2, 2, 1]);
* const rate = 0.75;
* const output = tf.dropout(x, rate);
* output.print();
* ```
*
* @param x A floating point Tensor or TensorLike.
* @param rate A float in the range [0, 1). The probability that each element
* of x is discarded.
 * @param noiseShape An array of numbers of type int32, representing the
 * shape for randomly generated keep/drop flags. Any null entries in
 * noiseShape are automatically replaced with the corresponding dimension
 * size of x. Optional.
* @param seed Used to create random seeds. Optional.
* @returns A Tensor of the same shape of x.
*
* @doc {heading: 'Operations', subheading: 'Dropout'}
*/
function dropout_(x, rate, noiseShape, seed) {
var $x = convertToTensor(x, 'x', 'dropout');
assert($x.dtype === 'float32', function () {
return "x has to be a floating point tensor since it's going to be " + ("scaled, but got a " + $x.dtype + " tensor instead.");
});
assert(rate >= 0 && rate < 1, function () {
return "rate must be a float in the range [0, 1), but got " + rate + ".";
});
if (rate === 0) {
return x instanceof Tensor ? $x.clone() : $x;
}
var $noiseShape = getNoiseShape($x, noiseShape);
var keepProb = 1 - rate;
var multiplier = div(floor$a(add$1(randomUniform($noiseShape, 0, 1, 'float32', seed), keepProb)), keepProb);
return mul($x, multiplier);
}
var dropout = op({
dropout_: dropout_
});
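/**
 * Additional sketch showing `noiseShape` (not part of the library source; assumes the
 * bundle is exposed as `tf`). A [4, 1] noise shape broadcasts one keep/drop decision
 * across each row, so whole rows are zeroed (and kept values are scaled by 1/(1-rate)):
 *
 * ```js
 * const x = tf.ones([4, 3]);
 * tf.dropout(x, 0.5, [4, 1], 42).print();
 * ```
 */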
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function enclosingPowerOfTwo(value) {
// Return 2**N for integer N such that 2**N >= value.
return Math.floor(Math.pow(2, Math.ceil(Math.log(value) / Math.log(2.0))));
}
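// cosineWindow: builds a generalized cosine window, w[i] = a - b * cos(2 * pi * i / L),
// where L is windowLength - 1 for odd lengths and windowLength for even lengths.
// Hann and Hamming windows correspond to (a, b) = (0.5, 0.5) and (0.54, 0.46).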
function cosineWindow(windowLength, a, b) {
var even = 1 - windowLength % 2;
var newValues = new Float32Array(windowLength);
for (var i = 0; i < windowLength; ++i) {
var cosArg = 2.0 * Math.PI * i / (windowLength + even - 1);
newValues[i] = a - b * Math.cos(cosArg);
}
return tensor1d(newValues, 'float32');
}
/**
* Returns whether the targets are in the top K predictions.
*
* ```js
* const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
* const targets = tf.tensor1d([2, 0]);
* const precision = await tf.inTopKAsync(predictions, targets);
* precision.print();
* ```
* @param predictions 2-D or higher `tf.Tensor` with last dimension being
* at least `k`.
* @param targets 1-D or higher `tf.Tensor`.
 * @param k Optional. Number of top elements to look at for computing
 * precision; defaults to 1.
*
* @doc {heading: 'Operations', subheading: 'Evaluation'}
*/
function inTopKAsync_(_x, _x2, _x3) {
return _inTopKAsync_.apply(this, arguments);
}
function _inTopKAsync_() {
_inTopKAsync_ = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(predictions, targets, k) {
var $predictions, $targets, lastDim, predictionsVals, targetsVals, batch, size, precision, b, offset, vals, valAndInd, i, _i;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (k === void 0) {
k = 1;
}
$predictions = convertToTensor(predictions, 'predictions', 'inTopK');
$targets = convertToTensor(targets, 'targets', 'inTopK');
assert($predictions.rank > 1, function () {
return 'inTopK() expects the predictions to be of rank 2 or higher, ' + ("but got " + $predictions.rank);
});
assert($predictions.rank - 1 === $targets.rank, function () {
return "predictions rank should be 1 larger than " + "targets rank, but got predictions rank " + ($predictions.rank + " and targets rank " + $targets.rank);
});
              assertShapesMatch($predictions.shape.slice(0, $predictions.shape.length - 1), $targets.shape, "predictions' shape should align with the targets' shape, " + 'except the last dimension.');
lastDim = $predictions.shape[$predictions.shape.length - 1];
assert(k > 0 && k <= lastDim, function () {
return "'k' passed to inTopK() must be > 0 && <= the predictions last " + ("dimension (" + lastDim + "), but got " + k);
});
_context.next = 10;
return $predictions.data();
case 10:
predictionsVals = _context.sent;
_context.next = 13;
return $targets.data();
case 13:
targetsVals = _context.sent;
// Reshape predictionsVals into a 2d tensor [batch, lastDim]
// and look up topK along lastDim.
batch = predictionsVals.length / lastDim, size = lastDim;
precision = getTypedArrayFromDType('bool', batch);
b = 0;
case 17:
if (!(b < batch)) {
_context.next = 35;
break;
}
offset = b * size;
vals = predictionsVals.subarray(offset, offset + size);
valAndInd = [];
for (i = 0; i < vals.length; i++) {
valAndInd.push({
value: vals[i],
index: i
});
}
valAndInd.sort(function (a, b) {
return b.value - a.value;
});
precision[b] = 0;
_i = 0;
case 25:
if (!(_i < k)) {
_context.next = 32;
break;
}
if (!(valAndInd[_i].index === targetsVals[b])) {
_context.next = 29;
break;
}
precision[b] = 1;
return _context.abrupt("break", 32);
case 29:
_i++;
_context.next = 25;
break;
case 32:
b++;
_context.next = 17;
break;
case 35:
if (predictions !== $predictions) {
$predictions.dispose();
}
if (targets !== $targets) {
$targets.dispose();
} // Output precision has the same shape as targets.
return _context.abrupt("return", tensor(precision, $targets.shape, 'bool'));
case 38:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _inTopKAsync_.apply(this, arguments);
}
var inTopKAsync = inTopKAsync_;
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the derivative of the filter of a 2D convolution.
*
* @param x The input tensor, of rank 4 or rank 3 of shape
* [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed.
* @param dy The dy image, of rank 4 or rank 3, of shape
* [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed.
* @param filterShape The shape of the filter, length 4,
* [filterHeight, filterWidth, inDepth, outDepth].
* @param strides The strides of the convolution: [strideHeight,
* strideWidth].
* @param pad A string from: 'same', 'valid'. The type of padding algorithm
* used in the forward prop of the op.
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
* "NHWC". Specify the data format of the input and output data. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels].
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*/
function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat, dimRoundingMode) {
if (dataFormat === void 0) {
dataFormat = 'NHWC';
}
var x4D = x;
if (x.rank === 3) {
x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
}
var dy4D = dy;
if (dy4D.rank === 3) {
dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
}
assert(x4D.rank === 4, function () {
return "Error in conv2dDerFilter: input must be rank 4, but got shape " + (x4D.shape + ".");
});
assert(dy4D.rank === 4, function () {
return "Error in conv2dDerFilter: dy must be rank 4, but got shape " + (dy4D.shape + ".");
});
assert(filterShape.length === 4, function () {
return "Error in conv2dDerFilter: filterShape must be length 4, but got " + (filterShape + ".");
});
var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
var outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
assert(inDepth === filterShape[2], function () {
    return "Error in conv2dDerFilter: depth of input (" + inDepth + ") must " + ("match input depth in filter (" + filterShape[2] + ").");
});
assert(outDepth === filterShape[3], function () {
return "Error in conv2dDerFilter: depth of dy (" + outDepth + ") must " + ("match output depth for filter (" + filterShape[3] + ").");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
      return "Error in conv2dDerFilter: pad must be an integer when using " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
x: x4D,
dy: dy4D
};
var attrs = {
strides: strides,
pad: pad,
dataFormat: dataFormat,
dimRoundingMode: dimRoundingMode,
filterShape: filterShape
}; // tslint:disable-next-line: no-unnecessary-type-assertion
return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs);
}
var conv2DBackpropFilter = op({
conv2DBackpropFilter_: conv2DBackpropFilter_
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
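// getFusedDyActivation: maps the upstream gradient dy through the derivative of the
// fused activation. 'linear' passes dy through unchanged; 'relu' masks dy with
// step(y), i.e. 1 where the forward output y was positive and 0 elsewhere.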
function getFusedDyActivation(dy, y, activation) {
if (activation == null || activation === 'linear') {
return dy;
}
if (activation === 'relu') {
return mul(dy, step(y));
}
throw new Error("Cannot compute gradient for fused activation " + activation + ".");
} // Returns gradient for fused bias.
function getFusedBiasGradient(bias, dyActivation) {
var res = dyActivation;
var reduceAxes = getReductionAxes(bias.shape, dyActivation.shape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(res, bias.shape);
}
function applyActivation(x, activation, preluActivationWeights, leakyreluAlpha) {
if (activation === 'linear') {
return x;
} else if (activation === 'relu') {
return relu(x);
} else if (activation === 'elu') {
return elu(x);
} else if (activation === 'relu6') {
return relu6(x);
} else if (activation === 'prelu') {
return prelu(x, preluActivationWeights);
} else if (activation === 'leakyrelu') {
return leakyRelu(x, leakyreluAlpha);
} else if (activation === 'sigmoid') {
return sigmoid(x);
}
throw new Error("Unknown fused activation " + activation + ".");
} // Whether we should call fused ops.
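// Fused kernels are only used when no gradient is being recorded (gradientDepth === 0)
// or when the activation is 'linear'; otherwise the callers below fall back to the
// unfused op followed by bias-add and activation.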
var shouldFuse = function shouldFuse(gradientDepth, activation) {
var gradientMode = gradientDepth > 0;
return !gradientMode || activation === 'linear';
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes a 2D convolution over the input x, optionally fused with adding a
* bias and applying an activation.
*
* ```js
* const inputDepth = 2;
* const inShape = [2, 2, 2, inputDepth];
* const outputDepth = 2;
* const fSize = 1;
* const pad = 0;
* const strides = 1;
*
* const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
* 16], inShape);
* const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth,
* outputDepth]);
*
* tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC',
* dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print();
* ```
*
* @param obj An object with the following properties:
* @param x The input tensor, of rank 4 or rank 3, of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
* assumed.
* @param filter The filter, rank 4, of shape
* `[filterHeight, filterWidth, inDepth, outDepth]`.
* @param strides The strides of the convolution: `[strideHeight,
* strideWidth]`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid` output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
* "NHWC". Specify the data format of the input and output data. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels]. Only "NHWC" is currently supported.
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
* in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
* number, then `dilationHeight == dilationWidth`. If it is greater than
* 1, then all values of `strides` must be 1.
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
* @param bias Tensor to be added to the result.
* @param activation Name of activation kernel (defaults to `linear`) to be
* applied
* after biasAdd.
* @param preluActivationWeights Tensor of prelu weights to be applied as part
* of a `prelu` activation, typically the same shape as `x`.
* @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
* activation.
*/
function fusedConv2d_(_ref) {
var x = _ref.x,
filter = _ref.filter,
strides = _ref.strides,
pad = _ref.pad,
_ref$dataFormat = _ref.dataFormat,
dataFormat = _ref$dataFormat === void 0 ? 'NHWC' : _ref$dataFormat,
_ref$dilations = _ref.dilations,
dilations = _ref$dilations === void 0 ? [1, 1] : _ref$dilations,
dimRoundingMode = _ref.dimRoundingMode,
bias = _ref.bias,
_ref$activation = _ref.activation,
activation = _ref$activation === void 0 ? 'linear' : _ref$activation,
preluActivationWeights = _ref.preluActivationWeights,
leakyreluAlpha = _ref.leakyreluAlpha;
activation = activation || 'linear';
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
var result = conv2d(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
if (bias != null) {
result = add$1(result, bias);
}
return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha);
}
var $x = convertToTensor(x, 'x', 'conv2d');
var $filter = convertToTensor(filter, 'filter', 'conv2d');
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(x4D.rank === 4, function () {
return "Error in fused conv2d: input must be rank 4, but got rank " + (x4D.rank + ".");
});
assert($filter.rank === 4, function () {
return "Error in fused conv2d: filter must be rank 4, but got rank " + ($filter.rank + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
      return "Error in fused conv2d: pad must be an integer when using " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
assert(x4D.shape[3] === $filter.shape[2], function () {
return "Error in conv2d: depth of input (" + x4D.shape[3] + ") must match " + ("input depth for filter " + $filter.shape[2] + ".");
});
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in conv2D: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
assert(dataFormat === 'NHWC', function () {
return "Error in conv2d: got dataFormat of " + dataFormat + " but only NHWC is currently supported.";
});
var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode);
var $bias;
if (bias != null) {
$bias = convertToTensor(bias, 'bias', 'fused conv2d');
var _makeTypesMatch = makeTypesMatch($bias, $x);
$bias = _makeTypesMatch[0];
assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
}
var $preluActivationWeights;
if (preluActivationWeights != null) {
$preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused conv2d');
}
var grad = function grad(dy, saved) {
var $filter = saved[0],
x4D = saved[1],
y = saved[2],
$bias = saved[3];
var dyActivation = getFusedDyActivation(dy, y, activation);
assert(tupleValuesAreOne(dilations), function () {
return 'Error in gradient of fused conv2D: ' + "dilation rates greater than 1 " + ("are not yet supported in gradients. Got dilations '" + dilations + "'");
});
var xDer = conv2DBackpropInput(x4D.shape, dyActivation, $filter, strides, pad);
var filterDer = conv2DBackpropFilter(x4D, dyActivation, $filter.shape, strides, pad);
var der = [xDer, filterDer];
if ($bias != null) {
var biasDer = getFusedBiasGradient($bias, dyActivation);
der.push(biasDer);
}
return der;
};
var inputs = {
x: x4D,
filter: $filter,
bias: $bias,
preluActivationWeights: $preluActivationWeights
};
var attrs = {
strides: strides,
pad: pad,
dataFormat: dataFormat,
dilations: dilations,
dimRoundingMode: dimRoundingMode,
activation: activation,
leakyreluAlpha: leakyreluAlpha
  }; // Depending on the params passed in, we will have a different number of
  // inputs and thus a different number of elements in the gradient.
if (bias == null) {
var customOp = customGrad(function (x4D, filter, save) {
var res = // tslint:disable-next-line: no-unnecessary-type-assertion
ENGINE.runKernel(FusedConv2D, inputs, attrs);
save([filter, x4D, res]);
if (reshapedTo4D) {
// tslint:disable-next-line: no-unnecessary-type-assertion
res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return {
value: res,
gradFunc: grad
};
});
return customOp(x4D, $filter);
} else {
var customOpWithBias = customGrad(function (x4D, filter, bias, save) {
var res = ENGINE.runKernel(FusedConv2D, inputs, attrs);
save([filter, x4D, res, bias]);
if (reshapedTo4D) {
// tslint:disable-next-line: no-unnecessary-type-assertion
res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return {
value: res,
gradFunc: grad
};
});
return customOpWithBias(x4D, $filter, $bias);
}
}
var conv2d$1 = op({
fusedConv2d_: fusedConv2d_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
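// Computes the gradient of a depthwise conv2d with respect to its filter. Rank-3
// inputs are promoted to rank 4 (batch of 1) before the
// DepthwiseConv2dNativeBackpropFilter kernel runs; the result has shape `filterShape`.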
function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad, dilations, dimRoundingMode) {
if (dilations === void 0) {
dilations = [1, 1];
}
var x4D = x;
if (x.rank === 3) {
x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
}
var dy4D = dy;
if (dy4D.rank === 3) {
dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
}
var inputs = {
x: x4D,
dy: dy4D
};
var attrs = {
strides: strides,
pad: pad,
dimRoundingMode: dimRoundingMode,
dilations: dilations,
filterShape: filterShape
}; // tslint:disable-next-line: no-unnecessary-type-assertion
return ENGINE.runKernel(DepthwiseConv2dNativeBackpropFilter, inputs, attrs);
}
var depthwiseConv2dNativeBackpropFilter = op({
depthwiseConv2dNativeBackpropFilter_: depthwiseConv2dNativeBackpropFilter_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
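// Computes the gradient of a depthwise conv2d with respect to its input. A rank-3 dy
// is promoted to rank 4 (batch of 1) before the DepthwiseConv2dNativeBackpropInput
// kernel runs, and the result is reshaped back to rank 3 afterwards.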
function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad, dilations, dimRoundingMode) {
if (dilations === void 0) {
dilations = [1, 1];
}
var dy4D = dy;
var reshapedTo4D = false;
if (dy.rank === 3) {
reshapedTo4D = true;
dy4D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
}
var inputs = {
dy: dy4D,
filter: filter
};
var attrs = {
strides: strides,
pad: pad,
dimRoundingMode: dimRoundingMode,
dilations: dilations,
inputShape: xShape
};
var res = // tslint:disable-next-line: no-unnecessary-type-assertion
ENGINE.runKernel(DepthwiseConv2dNativeBackpropInput, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var depthwiseConv2dNativeBackpropInput = op({
depthwiseConv2dNativeBackpropInput_: depthwiseConv2dNativeBackpropInput_
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes depthwise 2D convolution, optionally fused with adding a
* bias and applying an activation.
*
* Given a 4D `input` array and a `filter` array of shape
* `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing
* `inChannels` convolutional filters of depth 1, this op applies a
* different filter to each input channel (expanding from 1 channel to
* `channelMultiplier` channels for each), then concatenates the results
* together. The output has `inChannels * channelMultiplier` channels.
*
* See
* [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d](
* https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d)
* for more details.
*
* @param obj An object with the following properties:
* @param x The input tensor, of rank 4 or rank 3, of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is
* assumed.
* @param filter The filter tensor, rank 4, of shape
* `[filterHeight, filterWidth, inChannels, channelMultiplier]`.
* @param strides The strides of the convolution: `[strideHeight,
* strideWidth]`. If strides is a single number, then `strideHeight ==
* strideWidth`.
* @param pad The type of padding algorithm.
* - `same` and stride 1: output will be of same size as input,
* regardless of filter size.
* - `valid`: output will be smaller than input if filter is larger
* than 1x1.
* - For more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
* in which we sample input values across the height and width dimensions
* in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
* number, then `dilationHeight == dilationWidth`. If it is greater than
* 1, then all values of `strides` must be 1.
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
* "NHWC". Specify the data format of the input and output data. With the
* default format "NHWC", the data is stored in the order of: [batch,
* height, width, channels]. Only "NHWC" is currently supported.
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
* @param bias Tensor to be added to the result.
* @param activation Name of activation kernel (defaults to `linear`).
* @param preluActivationWeights Tensor of prelu weights to be applied as part
* of a `prelu` activation, typically the same shape as `x`.
* @param leakyreluAlpha Optional. Alpha to be applied as part of a `leakyrelu`
* activation.
*/
function fusedDepthwiseConv2d_(_ref) {
var x = _ref.x,
filter = _ref.filter,
strides = _ref.strides,
pad = _ref.pad,
_ref$dataFormat = _ref.dataFormat,
dataFormat = _ref$dataFormat === void 0 ? 'NHWC' : _ref$dataFormat,
_ref$dilations = _ref.dilations,
dilations = _ref$dilations === void 0 ? [1, 1] : _ref$dilations,
dimRoundingMode = _ref.dimRoundingMode,
bias = _ref.bias,
_ref$activation = _ref.activation,
activation = _ref$activation === void 0 ? 'linear' : _ref$activation,
preluActivationWeights = _ref.preluActivationWeights,
leakyreluAlpha = _ref.leakyreluAlpha;
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
var result = depthwiseConv2d(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
if (bias != null) {
result = add$1(result, bias);
}
return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha);
}
var $x = convertToTensor(x, 'x', 'depthwiseConv2d');
var $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d');
var x4D = $x;
var reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert(x4D.rank === 4, function () {
return "Error in fused depthwiseConv2d: input must be rank 4, but got " + ("rank " + x4D.rank + ".");
});
assert($filter.rank === 4, function () {
return "Error in fused depthwiseConv2d: filter must be rank 4, " + ("but got rank " + $filter.rank + ".");
});
assert(x4D.shape[3] === $filter.shape[2], function () {
return "Error in fused depthwiseConv2d: number of input channels " + ("(" + x4D.shape[3] + ") must match the inChannels dimension in ") + ("filter " + $filter.shape[2] + ".");
});
if (dilations == null) {
dilations = [1, 1];
}
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in fused depthwiseConv2d: Either strides or dilations must ' + ("be 1. Got strides " + strides + " and dilations '" + dilations + "'");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in fused depthwiseConv2d: pad must be an integer when " + ("using dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true
/* depthwise */
);
var $bias;
if (bias != null) {
$bias = convertToTensor(bias, 'bias', 'fused conv2d');
var _makeTypesMatch = makeTypesMatch($bias, $x);
$bias = _makeTypesMatch[0];
assertAndGetBroadcastShape(convInfo.outShape, $bias.shape);
}
var $preluActivationWeights;
if (preluActivationWeights != null) {
$preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d');
}
var grad = function grad(dy, saved) {
assert(tupleValuesAreOne(dilations), function () {
return 'Error in gradient of fused depthwiseConv2d: dilation rates ' + "greater than 1 are not yet supported. Got dilations " + ("'" + dilations + "'");
});
var $filter = saved[0],
x4D = saved[1],
y = saved[2],
bias = saved[3];
var dyActivation = getFusedDyActivation(dy, y, activation);
var xDer = depthwiseConv2dNativeBackpropInput(x4D.shape, dyActivation, $filter, strides, pad, dilations, dimRoundingMode);
var filterDer = depthwiseConv2dNativeBackpropFilter(x4D, dyActivation, $filter.shape, strides, pad, dilations, dimRoundingMode);
if (bias != null) {
var biasDer = getFusedBiasGradient($bias, dyActivation);
return [xDer, filterDer, biasDer];
}
return [xDer, filterDer];
};
var inputs = {
x: x4D,
filter: $filter,
bias: $bias,
preluActivationWeights: $preluActivationWeights
};
var attrs = {
strides: strides,
pad: pad,
dataFormat: dataFormat,
dilations: dilations,
dimRoundingMode: dimRoundingMode,
activation: activation,
leakyreluAlpha: leakyreluAlpha
  }; // Depending on the params passed in, we will have a different number of
  // inputs and thus a different number of elements in the gradient.
if (bias == null) {
var customOp = customGrad(function (x4D, filter, save) {
// tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
save([filter, x4D, res]);
if (reshapedTo4D) {
// tslint:disable-next-line: no-unnecessary-type-assertion
res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return {
value: res,
gradFunc: grad
};
});
return customOp(x4D, $filter);
} else {
var customOpWithBias = customGrad(function (x4D, filter, bias, save) {
// tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(FusedDepthwiseConv2D, inputs, attrs);
save([filter, x4D, res, bias]);
if (reshapedTo4D) {
// tslint:disable-next-line: no-unnecessary-type-assertion
res = reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return {
value: res,
gradFunc: grad
};
});
return customOpWithBias(x4D, $filter, $bias);
}
}
var depthwiseConv2d$1 = op({
fusedDepthwiseConv2d_: fusedDepthwiseConv2d_
});
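/**
 * Usage sketch for the fused depthwise convolution (not part of the library source;
 * assumes the bundle is exposed as `tf`):
 *
 * ```js
 * const x = tf.ones([1, 3, 3, 2]);       // [batch, height, width, inChannels]
 * const filter = tf.ones([2, 2, 2, 1]);  // [fH, fW, inChannels, channelMultiplier]
 * tf.fused.depthwiseConv2d({
 *   x, filter, strides: 1, pad: 'valid',
 *   bias: tf.scalar(1), activation: 'relu'
 * }).print();
 * ```
 */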
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the dot product of two matrices with optional activation and bias.
*
* ```js
* const a = tf.tensor2d([-1, -2], [1, 2]);
* const b = tf.tensor2d([1, 2, 3, 4], [2, 2]);
* const bias = tf.tensor2d([1, 2], [1, 2]);
*
* tf.fused.matMul({a, b, bias, activation: 'relu'}).print();
* ```
*
* @param obj An object with the following properties:
* - `a` First matrix in dot product operation.
* - `b` Second matrix in dot product operation.
* - `transposeA` If true, `a` is transposed before multiplication.
* - `transposeB` If true, `b` is transposed before multiplication.
* - `bias` Matrix to be added to the result.
* - `activation` Name of activation kernel (defaults to `linear`).
* - `preluActivationWeights` Tensor of prelu weights.
* - `leakyreluAlpha` Alpha of leakyrelu.
*/
function fusedMatMul_(_ref) {
var a = _ref.a,
b = _ref.b,
_ref$transposeA = _ref.transposeA,
transposeA = _ref$transposeA === void 0 ? false : _ref$transposeA,
_ref$transposeB = _ref.transposeB,
transposeB = _ref$transposeB === void 0 ? false : _ref$transposeB,
bias = _ref.bias,
_ref$activation = _ref.activation,
activation = _ref$activation === void 0 ? 'linear' : _ref$activation,
preluActivationWeights = _ref.preluActivationWeights,
leakyreluAlpha = _ref.leakyreluAlpha;
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
var result = matMul(a, b, transposeA, transposeB);
if (bias != null) {
result = add$1(result, bias);
}
return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha);
}
var $a = convertToTensor(a, 'a', 'fused matMul');
var $b = convertToTensor(b, 'b', 'fused matMul');
var _makeTypesMatch = makeTypesMatch($a, $b);
$a = _makeTypesMatch[0];
$b = _makeTypesMatch[1];
var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
var outerDimsA = $a.shape.slice(0, -2);
var outerDimsB = $b.shape.slice(0, -2);
var batchDimA = sizeFromShape(outerDimsA);
var batchDimB = sizeFromShape(outerDimsB);
assert($a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, function () {
return "Error in fused matMul: inputs must have the same rank of at " + ("least 2, got ranks " + $a.rank + " and " + $b.rank + ".");
});
assert(arraysEqual(outerDimsA, outerDimsB), function () {
return "Error in fused matMul: outer dimensions (" + outerDimsA + ") and (" + (outerDimsB + ") of Tensors with shapes " + $a.shape + " and ") + ($b.shape + " must match.");
});
assert(innerShapeA === innerShapeB, function () {
return "Error in fused matMul: inner shapes (" + innerShapeA + ") and (" + (innerShapeB + ") of Tensors with shapes " + $a.shape + " and ") + ($b.shape + " and transposeA=" + transposeA) + (" and transposeB=" + transposeB + " must match.");
});
var outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]);
var a3D = transposeA ? reshape($a, [batchDimA, innerShapeA, outerShapeA]) : reshape($a, [batchDimA, outerShapeA, innerShapeA]);
var b3D = transposeB ? reshape($b, [batchDimB, outerShapeB, innerShapeB]) : reshape($b, [batchDimB, innerShapeB, outerShapeB]);
var $bias;
if (bias != null) {
$bias = convertToTensor(bias, 'bias', 'fused matMul');
var _makeTypesMatch2 = makeTypesMatch($bias, $a);
$bias = _makeTypesMatch2[0];
assertAndGetBroadcastShape(outShape, $bias.shape);
}
var $preluActivationWeights;
if (preluActivationWeights != null) {
$preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul');
}
var grad = function grad(dy, saved) {
var a3D = saved[0],
b3D = saved[1],
y = saved[2],
$bias = saved[3]; // we reshape dy because the result of the forward is not
// necessarily going to be a 3d tensor due to a reshape done at the end of
// the customOp.
var dyActivation = getFusedDyActivation(reshape(dy, y.shape), y, activation);
var aDer;
var bDer;
if (!transposeA && !transposeB) {
aDer = matMul(dyActivation, b3D, false, true);
bDer = matMul(a3D, dyActivation, true, false);
} else if (!transposeA && transposeB) {
aDer = matMul(dyActivation, b3D, false, false);
bDer = matMul(dyActivation, a3D, true, false);
} else if (transposeA && !transposeB) {
aDer = matMul(b3D, dyActivation, false, true);
bDer = matMul(a3D, dyActivation, false, false);
} else {
aDer = matMul(b3D, dyActivation, true, true);
bDer = matMul(dyActivation, a3D, true, true);
}
if (bias != null) {
var biasDer = getFusedBiasGradient($bias, dyActivation);
return [aDer, bDer, biasDer];
} else {
return [aDer, bDer];
}
};
var inputs = {
a: a3D,
b: b3D,
bias: $bias,
preluActivationWeights: $preluActivationWeights
};
var attrs = {
transposeA: transposeA,
transposeB: transposeB,
activation: activation,
leakyreluAlpha: leakyreluAlpha
  }; // Depending on the params passed in, we will have a different number of
  // inputs and thus a different number of elements in the gradient.
if (bias == null) {
var customOp = customGrad(function (a3D, b3D, save) {
var res = // tslint:disable-next-line: no-unnecessary-type-assertion
ENGINE.runKernel(_FusedMatMul, inputs, attrs);
save([a3D, b3D, res]);
return {
value: reshape(res, outShape),
gradFunc: grad
};
});
return customOp(a3D, b3D);
} else {
var customOpWithBias = customGrad(function (a3D, b3D, $bias, save) {
var res = // tslint:disable-next-line: no-unnecessary-type-assertion
ENGINE.runKernel(_FusedMatMul, inputs, attrs);
save([a3D, b3D, res, $bias]);
return {
value: reshape(res, outShape),
gradFunc: grad
};
});
return customOpWithBias(a3D, b3D, $bias);
}
}
var matMul$1 = op({
fusedMatMul_: fusedMatMul_
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var fused_ops = {
__proto__: null,
conv2d: conv2d$1,
depthwiseConv2d: depthwiseConv2d$1,
matMul: matMul$1
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Generate a Hamming window.
*
* See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
*
* ```js
* tf.signal.hammingWindow(10).print();
* ```
 * @param windowLength The length of the window.
*
* @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
*/
function hammingWindow_(windowLength) {
return cosineWindow(windowLength, 0.54, 0.46);
}
var hammingWindow = op({
hammingWindow_: hammingWindow_
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Generate a Hann window.
*
* See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
*
* ```js
* tf.signal.hannWindow(10).print();
* ```
 * @param windowLength The length of the window.
*
* @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
*/
function hannWindow_(windowLength) {
return cosineWindow(windowLength, 0.5, 0.5);
}
var hannWindow = op({
hannWindow_: hannWindow_
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Expands the input into frames of length frameLength,
 * sliding the window forward by frameStep samples between frames.
*
* ```js
* tf.signal.frame([1, 2, 3], 2, 1).print();
* ```
* @param signal The input tensor to be expanded
* @param frameLength Length of each frame
* @param frameStep The frame hop size in samples.
* @param padEnd Whether to pad the end of signal with padValue.
 * @param padValue A number to use where the input signal does
 *     not exist when padEnd is true.
*
* @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
*/
function frame_(signal, frameLength, frameStep, padEnd, padValue) {
if (padEnd === void 0) {
padEnd = false;
}
if (padValue === void 0) {
padValue = 0;
}
var start = 0;
var output = [];
while (start + frameLength <= signal.size) {
output.push(slice$2(signal, start, frameLength));
start += frameStep;
}
if (padEnd) {
while (start < signal.size) {
var padLen = start + frameLength - signal.size;
var pad = concat([slice$2(signal, start, frameLength - padLen), fill([padLen], padValue)]);
output.push(pad);
start += frameStep;
}
}
if (output.length === 0) {
return tensor2d([], [0, frameLength]);
}
return reshape(concat(output), [output.length, frameLength]);
}
var frame = op({
frame_: frame_
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the Short-time Fourier Transform of signals
* See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform
*
* ```js
 * const input = tf.tensor1d([1, 1, 1, 1, 1]);
* tf.signal.stft(input, 3, 1).print();
* ```
 * @param signal 1-dimensional real-valued tensor.
 * @param frameLength The window length in samples.
 * @param frameStep The number of samples to step.
 * @param fftLength The size of the FFT to apply.
 * @param windowFn A callable that takes a window length and returns a 1-d tensor.
*
* @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'}
*/
function stft_(signal, frameLength, frameStep, fftLength, windowFn) {
if (windowFn === void 0) {
windowFn = hannWindow;
}
if (fftLength == null) {
fftLength = enclosingPowerOfTwo(frameLength);
}
var framedSignal = frame(signal, frameLength, frameStep);
var windowedSignal = mul(framedSignal, windowFn(frameLength));
return rfft(windowedSignal, fftLength);
}
var stft = op({
stft_: stft_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Extracts crops from the input image tensor and resizes them using bilinear
* sampling or nearest neighbor sampling (possibly with aspect ratio change)
* to a common output size specified by cropSize.
*
 * @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`,
* where imageHeight and imageWidth must be positive, specifying the
* batch of images from which to take crops
* @param boxes 2d float32 tensor of shape `[numBoxes, 4]`. Each entry is
* `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the normalized
* coordinates of the box in the boxInd[i]'th image in the batch
* @param boxInd 1d int32 tensor of shape `[numBoxes]` with values in range
* `[0, batch)` that specifies the image that the `i`-th box refers to.
 * @param cropSize 1d int32 tensor of 2 elements `[cropHeight, cropWidth]`
 *     specifying the size to which all crops are resized.
 * @param method Optional string from `'bilinear' | 'nearest'`,
 *     defaults to bilinear, which specifies the sampling method for resizing
 * @param extrapolationValue A value to use for extrapolation when a crop
 *     extends outside the bounds of the input image. Defaults to 0.
 * @return A 4D tensor of the shape `[numBoxes, cropHeight, cropWidth, depth]`
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
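 *
 * A minimal usage sketch; the tensor values below are illustrative only:
 *
 * ```js
 * const image = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * const boxes = tf.tensor2d([[0, 0, 1, 1]], [1, 4]);
 * const boxInd = tf.tensor1d([0], 'int32');
 * // Crop the whole first (and only) image and resize it to 3x3.
 * tf.image.cropAndResize(image, boxes, boxInd, [3, 3]).print();
 * ```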
*/
function cropAndResize_(image, boxes, boxInd, cropSize, method, extrapolationValue) {
if (method === void 0) {
method = 'bilinear';
}
if (extrapolationValue === void 0) {
extrapolationValue = 0;
}
var $image = convertToTensor(image, 'image', 'cropAndResize');
var $boxes = convertToTensor(boxes, 'boxes', 'cropAndResize', 'float32');
var $boxInd = convertToTensor(boxInd, 'boxInd', 'cropAndResize', 'int32');
var numBoxes = $boxes.shape[0];
assert($image.rank === 4, function () {
    return 'Error in cropAndResize: image must be rank 4, ' + ("but got rank " + $image.rank + ".");
});
assert($boxes.rank === 2 && $boxes.shape[1] === 4, function () {
return "Error in cropAndResize: boxes must be have size [" + numBoxes + ",4] " + ("but had shape " + $boxes.shape + ".");
});
assert($boxInd.rank === 1 && $boxInd.shape[0] === numBoxes, function () {
return "Error in cropAndResize: boxInd must be have size [" + numBoxes + "] " + ("but had shape " + $boxes.shape + ".");
});
assert(cropSize.length === 2, function () {
return "Error in cropAndResize: cropSize must be of length 2, but got " + ("length " + cropSize.length + ".");
});
assert(cropSize[0] >= 1 && cropSize[1] >= 1, function () {
return "cropSize must be atleast [1,1], but was " + cropSize;
});
assert(method === 'bilinear' || method === 'nearest', function () {
return "method must be bilinear or nearest, but was " + method;
});
var inputs = {
image: $image,
boxes: $boxes,
boxInd: $boxInd
};
var attrs = {
method: method,
extrapolationValue: extrapolationValue,
cropSize: cropSize
};
var res = ENGINE.runKernel(CropAndResize, inputs, attrs);
return res;
}
var cropAndResize = op({
cropAndResize_: cropAndResize_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Flips the image left to right. Currently available in the CPU, WebGL, and
* WASM backends.
*
* @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
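 *
 * A minimal usage sketch; the tensor values are illustrative only:
 *
 * ```js
 * const image = tf.tensor4d([1, 2, 3, 4], [1, 1, 4, 1]);
 * tf.image.flipLeftRight(image).print();
 * ```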
*/
/** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */
function flipLeftRight_(image) {
var $image = convertToTensor(image, 'image', 'flipLeftRight', 'float32');
assert($image.rank === 4, function () {
    return 'Error in flipLeftRight: image must be rank 4, ' + ("but got rank " + $image.rank + ".");
});
var inputs = {
image: $image
};
var res = ENGINE.runKernel(FlipLeftRight, inputs, {});
return res;
}
var flipLeftRight = op({
flipLeftRight_: flipLeftRight_
});
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Converts images from grayscale to RGB format.
*
 * @param image A grayscale tensor to convert. The `image`'s last dimension
 *     must be of size 1, and the tensor must be at least rank 2.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
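 *
 * A minimal usage sketch; the tensor values are illustrative only:
 *
 * ```js
 * const gray = tf.tensor3d([0.2, 0.8], [1, 2, 1]);
 * tf.image.grayscaleToRGB(gray).print(); // shape [1, 2, 3]
 * ```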
*/
function grayscaleToRGB_(image) {
var $image = convertToTensor(image, 'image', 'grayscaleToRGB');
var lastDimsIdx = $image.rank - 1;
var lastDims = $image.shape[lastDimsIdx];
assert($image.rank >= 2, function () {
return 'Error in grayscaleToRGB: images must be at least rank 2, ' + ("but got rank " + $image.rank + ".");
});
assert(lastDims === 1, function () {
return 'Error in grayscaleToRGB: last dimension of a grayscale image ' + ("should be size 1, but got size " + lastDims + ".");
});
var reps = new Array($image.rank);
reps.fill(1, 0, lastDimsIdx);
reps[lastDimsIdx] = 3;
return tile($image, reps);
}
var grayscaleToRGB = op({
grayscaleToRGB_: grayscaleToRGB_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Rotates the input image tensor counter-clockwise with an optional offset
* center of rotation. Currently available in the CPU, WebGL, and WASM backends.
*
* @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
* @param radians The amount of rotation.
* @param fillValue The value to fill in the empty space leftover
* after rotation. Can be either a single grayscale value (0-255), or an
* array of three numbers `[red, green, blue]` specifying the red, green,
* and blue channels. Defaults to `0` (black).
* @param center The center of rotation. Can be either a single value (0-1), or
* an array of two numbers `[centerX, centerY]`. Defaults to `0.5` (rotates
* the image around its center).
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
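 *
 * A minimal usage sketch; the tensor values are illustrative only:
 *
 * ```js
 * const image = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * // Rotate 90 degrees counter-clockwise around the image center.
 * tf.image.rotateWithOffset(image, Math.PI / 2).print();
 * ```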
*/
function rotateWithOffset_(image, radians, fillValue, center) {
if (fillValue === void 0) {
fillValue = 0;
}
if (center === void 0) {
center = 0.5;
}
var $image = convertToTensor(image, 'image', 'rotateWithOffset', 'float32');
assert($image.rank === 4, function () {
    return 'Error in rotateWithOffset: image must be rank 4, ' + ("but got rank " + $image.rank + ".");
});
var inputs = {
image: $image
};
var attrs = {
radians: radians,
fillValue: fillValue,
center: center
};
var res = ENGINE.runKernel(RotateWithOffset, inputs, attrs);
return res;
}
var rotateWithOffset = op({
rotateWithOffset_: rotateWithOffset_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function nonMaxSuppSanityCheck(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
if (iouThreshold == null) {
iouThreshold = 0.5;
}
if (scoreThreshold == null) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
if (softNmsSigma == null) {
softNmsSigma = 0.0;
}
var numBoxes = boxes.shape[0];
maxOutputSize = Math.min(maxOutputSize, numBoxes);
assert(0 <= iouThreshold && iouThreshold <= 1, function () {
return "iouThreshold must be in [0, 1], but was '" + iouThreshold + "'";
});
assert(boxes.rank === 2, function () {
return "boxes must be a 2D tensor, but was of rank '" + boxes.rank + "'";
});
assert(boxes.shape[1] === 4, function () {
return "boxes must have 4 columns, but 2nd dimension was " + boxes.shape[1];
});
assert(scores.rank === 1, function () {
return 'scores must be a 1D tensor';
});
assert(scores.shape[0] === numBoxes, function () {
return "scores has incompatible shape with boxes. Expected " + numBoxes + ", " + ("but was " + scores.shape[0]);
});
assert(0 <= softNmsSigma && softNmsSigma <= 1, function () {
return "softNmsSigma must be in [0, 1], but was '" + softNmsSigma + "'";
});
return {
maxOutputSize: maxOutputSize,
iouThreshold: iouThreshold,
scoreThreshold: scoreThreshold,
softNmsSigma: softNmsSigma
};
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Performs non maximum suppression of bounding boxes based on
* iou (intersection over union).
*
* @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
* `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
* the bounding box.
* @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
* @param maxOutputSize The maximum number of boxes to be selected.
* @param iouThreshold A float representing the threshold for deciding whether
* boxes overlap too much with respect to IOU. Must be between [0, 1].
* Defaults to 0.5 (50% box overlap).
* @param scoreThreshold A threshold for deciding when to remove boxes based
* on score. Defaults to -inf, which means any score is accepted.
* @return A 1D tensor with the selected box indices.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
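 *
 * A minimal usage sketch; the boxes and scores are illustrative only:
 *
 * ```js
 * const boxes = tf.tensor2d([[0, 0, 1, 1], [0, 0.1, 1, 1.1]], [2, 4]);
 * const scores = tf.tensor1d([0.9, 0.8]);
 * tf.image.nonMaxSuppression(boxes, scores, 1).print();
 * ```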
*/
function nonMaxSuppression_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
if (iouThreshold === void 0) {
iouThreshold = 0.5;
}
if (scoreThreshold === void 0) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
var inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
maxOutputSize = inputs.maxOutputSize;
iouThreshold = inputs.iouThreshold;
scoreThreshold = inputs.scoreThreshold;
var attrs = {
maxOutputSize: maxOutputSize,
iouThreshold: iouThreshold,
scoreThreshold: scoreThreshold
};
return ENGINE.runKernel(NonMaxSuppressionV3, {
boxes: $boxes,
scores: $scores
}, attrs);
}
var nonMaxSuppression = op({
nonMaxSuppression_: nonMaxSuppression_
});
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Inserts a value into a sorted array. Duplicates are allowed: if the value
 * already exists in the array, the element is inserted at the lowest index
 * holding that value.
* @param arr The array to modify.
* @param element The element to insert.
 * @param comparator Optional. If no comparator is specified, elements are
 * compared using array_util.defaultComparator, which is suitable for Strings
 * and Numbers in ascending arrays. If the array contains multiple instances of
 * the target value, the left-most instance will be returned. A custom
 * comparator should take 2 arguments to compare and return a negative, zero,
 * or a positive number.
*/
function binaryInsert(arr, element, comparator) {
var index = binarySearch(arr, element, comparator);
var insertionPoint = index < 0 ? -(index + 1) : index;
arr.splice(insertionPoint, 0, element);
}
/**
* Searches the array for the target using binary search, returns the index
 * of the found element, or the position at which to insert it if the element
 * is not found. If no comparator is specified, elements are compared using
 * array_util.defaultComparator, which is suitable for Strings and Numbers in
* ascending arrays. If the array contains multiple instances of the target
* value, the left-most instance will be returned.
* @param arr The array to be searched in.
* @param target The target to be searched for.
* @param comparator Should take 2 arguments to compare and return a negative,
* zero, or a positive number.
* @return Lowest index of the target value if found, otherwise the insertion
* point where the target should be inserted, in the form of
* (-insertionPoint - 1).
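 *
 * A small illustration of the return convention (hypothetical inputs):
 *
 * ```js
 * binarySearch([1, 3, 5, 7], 5);  // 2: found at index 2
 * binarySearch([1, 3, 5, 7], 4);  // -3: not found; insertion point is 2, so -(2) - 1
 * ```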
*/
function binarySearch(arr, target, comparator) {
return binarySearch_(arr, target, comparator || defaultComparator);
}
/**
* Compares its two arguments for order.
* @param a The first element to be compared.
* @param b The second element to be compared.
* @return A negative number, zero, or a positive number as the first
* argument is less than, equal to, or greater than the second.
*/
function defaultComparator(a, b) {
return a > b ? 1 : a < b ? -1 : 0;
}
function binarySearch_(arr, target, comparator) {
var left = 0;
var right = arr.length;
var middle = 0;
var found = false;
while (left < right) {
middle = left + (right - left >>> 1);
var compareResult = comparator(target, arr[middle]);
if (compareResult > 0) {
left = middle + 1;
} else {
      right = middle; // If compareResult is 0, the value is found. We record that it is
      // found, and then keep looking because there may be duplicates.
found = !compareResult;
}
}
return found ? left : -left - 1;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function nonMaxSuppressionV3Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0
/* softNmsSigma */
);
}
function nonMaxSuppressionV4Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0
/* softNmsSigma */
, false
/* returnScoresTensor */
, padToMaxOutputSize
/* padToMaxOutputSize */
, true
/* returnValidOutputs */
);
}
function nonMaxSuppressionV5Impl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, true
/* returnScoresTensor */
);
}
function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor, padToMaxOutputSize, returnValidOutputs) {
if (returnScoresTensor === void 0) {
returnScoresTensor = false;
}
if (padToMaxOutputSize === void 0) {
padToMaxOutputSize = false;
}
if (returnValidOutputs === void 0) {
returnValidOutputs = false;
}
// The list is sorted in ascending order, so that we can always pop the
// candidate with the largest score in O(1) time.
var candidates = [];
for (var i = 0; i < scores.length; i++) {
if (scores[i] > scoreThreshold) {
candidates.push({
score: scores[i],
boxIndex: i,
suppressBeginIndex: 0
});
}
}
  candidates.sort(ascendingComparator); // If softNmsSigma is 0, the outcome of this algorithm is exactly the same
  // as standard (hard) NMS.
var scale = softNmsSigma > 0 ? -0.5 / softNmsSigma : 0.0;
var selectedIndices = [];
var selectedScores = [];
while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
var candidate = candidates.pop();
var originalScore = candidate.score,
boxIndex = candidate.boxIndex,
suppressBeginIndex = candidate.suppressBeginIndex;
if (originalScore < scoreThreshold) {
break;
} // Overlapping boxes are likely to have similar scores, therefore we
// iterate through the previously selected boxes backwards in order to
// see if candidate's score should be suppressed. We use
// suppressBeginIndex to track and ensure a candidate can be suppressed
// by a selected box no more than once. Also, if the overlap exceeds
// iouThreshold, we simply ignore the candidate.
var ignoreCandidate = false;
for (var j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
var iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
if (iou >= iouThreshold) {
ignoreCandidate = true;
break;
}
candidate.score = candidate.score * suppressWeight(iouThreshold, scale, iou);
if (candidate.score <= scoreThreshold) {
break;
}
} // At this point, if `candidate.score` has not dropped below
// `scoreThreshold`, then we know that we went through all of the
// previous selections and can safely update `suppressBeginIndex` to the
// end of the selected array. Then we can re-insert the candidate with
// the updated score and suppressBeginIndex back in the candidate list.
// If on the other hand, `candidate.score` has dropped below the score
// threshold, we will not add it back to the candidates list.
candidate.suppressBeginIndex = selectedIndices.length;
if (!ignoreCandidate) {
// Candidate has passed all the tests, and is not suppressed, so
// select the candidate.
if (candidate.score === originalScore) {
selectedIndices.push(boxIndex);
selectedScores.push(candidate.score);
} else if (candidate.score > scoreThreshold) {
// Candidate's score is suppressed but is still high enough to be
// considered, so add back to the candidates list.
binaryInsert(candidates, candidate, ascendingComparator);
}
}
} // NonMaxSuppressionV4 feature: padding output to maxOutputSize.
var validOutputs = selectedIndices.length;
var elemsToPad = maxOutputSize - validOutputs;
if (padToMaxOutputSize && elemsToPad > 0) {
selectedIndices.push.apply(selectedIndices, new Array(elemsToPad).fill(0));
selectedScores.push.apply(selectedScores, new Array(elemsToPad).fill(0.0));
}
var result = {
selectedIndices: selectedIndices
};
if (returnScoresTensor) {
result['selectedScores'] = selectedScores;
}
if (returnValidOutputs) {
result['validOutputs'] = validOutputs;
}
return result;
}
function intersectionOverUnion(boxes, i, j) {
var iCoord = boxes.subarray(i * 4, i * 4 + 4);
var jCoord = boxes.subarray(j * 4, j * 4 + 4);
var yminI = Math.min(iCoord[0], iCoord[2]);
var xminI = Math.min(iCoord[1], iCoord[3]);
var ymaxI = Math.max(iCoord[0], iCoord[2]);
var xmaxI = Math.max(iCoord[1], iCoord[3]);
var yminJ = Math.min(jCoord[0], jCoord[2]);
var xminJ = Math.min(jCoord[1], jCoord[3]);
var ymaxJ = Math.max(jCoord[0], jCoord[2]);
var xmaxJ = Math.max(jCoord[1], jCoord[3]);
var areaI = (ymaxI - yminI) * (xmaxI - xminI);
var areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
if (areaI <= 0 || areaJ <= 0) {
return 0.0;
}
var intersectionYmin = Math.max(yminI, yminJ);
var intersectionXmin = Math.max(xminI, xminJ);
var intersectionYmax = Math.min(ymaxI, ymaxJ);
var intersectionXmax = Math.min(xmaxI, xmaxJ);
var intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) * Math.max(intersectionXmax - intersectionXmin, 0.0);
return intersectionArea / (areaI + areaJ - intersectionArea);
} // A Gaussian penalty function; this method always returns values in [0, 1].
// The weight is a function of similarity: the more two boxes overlap, the
// smaller the weight, meaning highly overlapping boxes will be significantly
// penalized. On the other hand, a non-overlapping box will not be penalized.
function suppressWeight(iouThreshold, scale, iou) {
var weight = Math.exp(scale * iou * iou);
return iou <= iouThreshold ? weight : 0.0;
}
function ascendingComparator(c1, c2) {
  // For objects with the same score, we make the object with the larger index go
// first. In an array that pops from the end, this means that the object with
// the smaller index will be popped first. This ensures the same output as
// the TensorFlow python version.
return c1.score - c2.score || c1.score === c2.score && c2.boxIndex - c1.boxIndex;
}
/**
* Performs non maximum suppression of bounding boxes based on
* iou (intersection over union).
*
* This is the async version of `nonMaxSuppression`
*
* @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
* `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
* the bounding box.
* @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
* @param maxOutputSize The maximum number of boxes to be selected.
* @param iouThreshold A float representing the threshold for deciding whether
* boxes overlap too much with respect to IOU. Must be between [0, 1].
* Defaults to 0.5 (50% box overlap).
* @param scoreThreshold A threshold for deciding when to remove boxes based
* on score. Defaults to -inf, which means any score is accepted.
* @return A 1D tensor with the selected box indices.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
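 *
 * A minimal usage sketch; the boxes and scores are illustrative only:
 *
 * ```js
 * const boxes = tf.tensor2d([[0, 0, 1, 1], [0, 0.1, 1, 1.1]], [2, 4]);
 * const scores = tf.tensor1d([0.9, 0.8]);
 * const indices = await tf.image.nonMaxSuppressionAsync(boxes, scores, 1);
 * indices.print();
 * ```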
*/
function nonMaxSuppressionAsync_(_x, _x2, _x3, _x4, _x5) {
return _nonMaxSuppressionAsync_.apply(this, arguments);
}
function _nonMaxSuppressionAsync_() {
_nonMaxSuppressionAsync_ = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
var $boxes, $scores, inputs, boxesAndScores, boxesVals, scoresVals, _nonMaxSuppressionV3I, selectedIndices;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (iouThreshold === void 0) {
iouThreshold = 0.5;
}
if (scoreThreshold === void 0) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
$boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
$scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold);
maxOutputSize = inputs.maxOutputSize;
iouThreshold = inputs.iouThreshold;
scoreThreshold = inputs.scoreThreshold;
_context.next = 10;
return Promise.all([$boxes.data(), $scores.data()]);
case 10:
boxesAndScores = _context.sent;
boxesVals = boxesAndScores[0];
scoresVals = boxesAndScores[1]; // We call a cpu based impl directly with the typedarray data here rather
// than a kernel because all kernels are synchronous (and thus cannot await
// .data()).
_nonMaxSuppressionV3I = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold), selectedIndices = _nonMaxSuppressionV3I.selectedIndices;
if ($boxes !== boxes) {
$boxes.dispose();
}
if ($scores !== scores) {
$scores.dispose();
}
return _context.abrupt("return", tensor1d(selectedIndices, 'int32'));
case 17:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _nonMaxSuppressionAsync_.apply(this, arguments);
}
var nonMaxSuppressionAsync = nonMaxSuppressionAsync_;
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Performs non maximum suppression of bounding boxes based on
* iou (intersection over union).
*
* This op also supports a Soft-NMS mode (c.f.
* Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
* of other overlapping boxes, therefore favoring different regions of the image
* with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
* parameter to be larger than 0.
*
* @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
* `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
* the bounding box.
* @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
* @param maxOutputSize The maximum number of boxes to be selected.
* @param iouThreshold A float representing the threshold for deciding whether
* boxes overlap too much with respect to IOU. Must be between [0, 1].
* Defaults to 0.5 (50% box overlap).
* @param scoreThreshold A threshold for deciding when to remove boxes based
* on score. Defaults to -inf, which means any score is accepted.
* @param softNmsSigma A float representing the sigma parameter for Soft NMS.
* When sigma is 0, it falls back to nonMaxSuppression.
* @return A map with the following properties:
* - selectedIndices: A 1D tensor with the selected box indices.
* - selectedScores: A 1D tensor with the corresponding scores for each
* selected box.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
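 *
 * A minimal usage sketch; the boxes, scores, and Soft-NMS sigma are
 * illustrative only:
 *
 * ```js
 * const boxes = tf.tensor2d([[0, 0, 1, 1], [0, 0.1, 1, 1.1]], [2, 4]);
 * const scores = tf.tensor1d([0.9, 0.8]);
 * const {selectedIndices, selectedScores} =
 *     tf.image.nonMaxSuppressionWithScore(boxes, scores, 2, 0.5, 0, 0.5);
 * selectedIndices.print();
 * selectedScores.print();
 * ```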
*/
function nonMaxSuppressionWithScore_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
if (iouThreshold === void 0) {
iouThreshold = 0.5;
}
if (scoreThreshold === void 0) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
if (softNmsSigma === void 0) {
softNmsSigma = 0.0;
}
var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
var params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
maxOutputSize = params.maxOutputSize;
iouThreshold = params.iouThreshold;
scoreThreshold = params.scoreThreshold;
softNmsSigma = params.softNmsSigma;
var inputs = {
boxes: $boxes,
scores: $scores
};
var attrs = {
maxOutputSize: maxOutputSize,
iouThreshold: iouThreshold,
scoreThreshold: scoreThreshold,
softNmsSigma: softNmsSigma
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var result = ENGINE.runKernel(NonMaxSuppressionV5, inputs, attrs);
return {
selectedIndices: result[0],
selectedScores: result[1]
};
}
var nonMaxSuppressionWithScore = op({
nonMaxSuppressionWithScore_: nonMaxSuppressionWithScore_
});
/**
* Asynchronously performs non maximum suppression of bounding boxes based on
* iou (intersection over union).
*
* This op also supports a Soft-NMS mode (c.f.
* Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
* of other overlapping boxes, therefore favoring different regions of the image
* with high scores. To enable this Soft-NMS mode, set the `softNmsSigma`
* parameter to be larger than 0.
*
* @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
* `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
* the bounding box.
* @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
* @param maxOutputSize The maximum number of boxes to be selected.
* @param iouThreshold A float representing the threshold for deciding whether
* boxes overlap too much with respect to IOU. Must be between [0, 1].
* Defaults to 0.5 (50% box overlap).
* @param scoreThreshold A threshold for deciding when to remove boxes based
* on score. Defaults to -inf, which means any score is accepted.
* @param softNmsSigma A float representing the sigma parameter for Soft NMS.
* When sigma is 0, it falls back to nonMaxSuppression.
* @return A map with the following properties:
* - selectedIndices: A 1D tensor with the selected box indices.
* - selectedScores: A 1D tensor with the corresponding scores for each
* selected box.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
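 *
 * A minimal usage sketch; the inputs are illustrative only:
 *
 * ```js
 * const boxes = tf.tensor2d([[0, 0, 1, 1], [0, 0.1, 1, 1.1]], [2, 4]);
 * const scores = tf.tensor1d([0.9, 0.8]);
 * const {selectedIndices, selectedScores} =
 *     await tf.image.nonMaxSuppressionWithScoreAsync(boxes, scores, 2, 0.5, 0, 0.5);
 * selectedIndices.print();
 * ```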
*/
function nonMaxSuppressionWithScoreAsync_(_x, _x2, _x3, _x4, _x5, _x6) {
return _nonMaxSuppressionWithScoreAsync_.apply(this, arguments);
}
function _nonMaxSuppressionWithScoreAsync_() {
_nonMaxSuppressionWithScoreAsync_ = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
var $boxes, $scores, params, boxesAndScores, boxesVals, scoresVals, _nonMaxSuppressionV5I, selectedIndices, selectedScores;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (iouThreshold === void 0) {
iouThreshold = 0.5;
}
if (scoreThreshold === void 0) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
if (softNmsSigma === void 0) {
softNmsSigma = 0.0;
}
$boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
$scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
maxOutputSize = params.maxOutputSize;
iouThreshold = params.iouThreshold;
scoreThreshold = params.scoreThreshold;
softNmsSigma = params.softNmsSigma;
_context.next = 12;
return Promise.all([$boxes.data(), $scores.data()]);
case 12:
boxesAndScores = _context.sent;
boxesVals = boxesAndScores[0];
scoresVals = boxesAndScores[1]; // We call a cpu based impl directly with the typedarray data here rather
// than a kernel because all kernels are synchronous (and thus cannot await
// .data()).
_nonMaxSuppressionV5I = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma), selectedIndices = _nonMaxSuppressionV5I.selectedIndices, selectedScores = _nonMaxSuppressionV5I.selectedScores;
if ($boxes !== boxes) {
$boxes.dispose();
}
if ($scores !== scores) {
$scores.dispose();
}
return _context.abrupt("return", {
selectedIndices: tensor1d(selectedIndices, 'int32'),
selectedScores: tensor1d(selectedScores)
});
case 19:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _nonMaxSuppressionWithScoreAsync_.apply(this, arguments);
}
var nonMaxSuppressionWithScoreAsync = nonMaxSuppressionWithScoreAsync_;
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Asynchronously performs non maximum suppression of bounding boxes based on
* iou (intersection over union), with an option to pad results.
*
* @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
* `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
* the bounding box.
* @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
* @param maxOutputSize The maximum number of boxes to be selected.
* @param iouThreshold A float representing the threshold for deciding whether
* boxes overlap too much with respect to IOU. Must be between [0, 1].
* Defaults to 0.5 (50% box overlap).
* @param scoreThreshold A threshold for deciding when to remove boxes based
* on score. Defaults to -inf, which means any score is accepted.
 * @param padToMaxOutputSize Defaults to false. If true, the size of the output
 *     `selectedIndices` is padded to maxOutputSize.
* @return A map with the following properties:
* - selectedIndices: A 1D tensor with the selected box indices.
* - validOutputs: A scalar denoting how many elements in `selectedIndices`
* are valid. Valid elements occur first, then padding.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
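 *
 * A minimal usage sketch; the inputs are illustrative only:
 *
 * ```js
 * const boxes = tf.tensor2d([[0, 0, 1, 1], [0, 0.1, 1, 1.1]], [2, 4]);
 * const scores = tf.tensor1d([0.9, 0.8]);
 * const {selectedIndices, validOutputs} = tf.image.nonMaxSuppressionPadded(
 *     boxes, scores, 2, 0.5, Number.NEGATIVE_INFINITY, true);
 * selectedIndices.print(); // zero-padded up to maxOutputSize
 * validOutputs.print();    // number of valid (non-padding) entries
 * ```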
*/
function nonMaxSuppressionPadded_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
if (iouThreshold === void 0) {
iouThreshold = 0.5;
}
if (scoreThreshold === void 0) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
if (padToMaxOutputSize === void 0) {
padToMaxOutputSize = false;
}
var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression');
var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression');
var params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null
/* softNmsSigma */
);
var $maxOutputSize = params.maxOutputSize;
var $iouThreshold = params.iouThreshold;
var $scoreThreshold = params.scoreThreshold;
var inputs = {
boxes: $boxes,
scores: $scores
};
var attrs = {
maxOutputSize: $maxOutputSize,
iouThreshold: $iouThreshold,
scoreThreshold: $scoreThreshold,
padToMaxOutputSize: padToMaxOutputSize
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var result = ENGINE.runKernel(NonMaxSuppressionV4, inputs, attrs);
return {
selectedIndices: result[0],
validOutputs: result[1]
};
}
var nonMaxSuppressionPadded = op({
nonMaxSuppressionPadded_: nonMaxSuppressionPadded_
});
/**
* Asynchronously performs non maximum suppression of bounding boxes based on
* iou (intersection over union), with an option to pad results.
*
* @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is
* `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of
* the bounding box.
* @param scores a 1d tensor providing the box scores of shape `[numBoxes]`.
* @param maxOutputSize The maximum number of boxes to be selected.
* @param iouThreshold A float representing the threshold for deciding whether
* boxes overlap too much with respect to IOU. Must be between [0, 1].
* Defaults to 0.5 (50% box overlap).
* @param scoreThreshold A threshold for deciding when to remove boxes based
* on score. Defaults to -inf, which means any score is accepted.
 * @param padToMaxOutputSize Defaults to false. If true, the size of the output
 *     `selectedIndices` is padded to maxOutputSize.
* @return A map with the following properties:
* - selectedIndices: A 1D tensor with the selected box indices.
* - validOutputs: A scalar denoting how many elements in `selectedIndices`
* are valid. Valid elements occur first, then padding.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
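 *
 * A minimal usage sketch; the inputs are illustrative only:
 *
 * ```js
 * const boxes = tf.tensor2d([[0, 0, 1, 1], [0, 0.1, 1, 1.1]], [2, 4]);
 * const scores = tf.tensor1d([0.9, 0.8]);
 * const {selectedIndices, validOutputs} =
 *     await tf.image.nonMaxSuppressionPaddedAsync(
 *         boxes, scores, 2, 0.5, Number.NEGATIVE_INFINITY, true);
 * selectedIndices.print();
 * ```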
*/
function nonMaxSuppressionPaddedAsync_(_x, _x2, _x3, _x4, _x5, _x6) {
return _nonMaxSuppressionPaddedAsync_.apply(this, arguments);
}
function _nonMaxSuppressionPaddedAsync_() {
_nonMaxSuppressionPaddedAsync_ = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
var $boxes, $scores, params, $maxOutputSize, $iouThreshold, $scoreThreshold, _yield$Promise$all, boxesVals, scoresVals, _nonMaxSuppressionV4I, selectedIndices, validOutputs;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (iouThreshold === void 0) {
iouThreshold = 0.5;
}
if (scoreThreshold === void 0) {
scoreThreshold = Number.NEGATIVE_INFINITY;
}
if (padToMaxOutputSize === void 0) {
padToMaxOutputSize = false;
}
$boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync');
$scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync');
params = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold, null
/* softNmsSigma */
);
$maxOutputSize = params.maxOutputSize;
$iouThreshold = params.iouThreshold;
$scoreThreshold = params.scoreThreshold;
_context.next = 11;
return Promise.all([$boxes.data(), $scores.data()]);
case 11:
_yield$Promise$all = _context.sent;
boxesVals = _yield$Promise$all[0];
scoresVals = _yield$Promise$all[1];
// We call a cpu based impl directly with the typedarray data here rather
// than a kernel because all kernels are synchronous (and thus cannot await
// .data()).
_nonMaxSuppressionV4I = nonMaxSuppressionV4Impl(boxesVals, scoresVals, $maxOutputSize, $iouThreshold, $scoreThreshold, padToMaxOutputSize), selectedIndices = _nonMaxSuppressionV4I.selectedIndices, validOutputs = _nonMaxSuppressionV4I.validOutputs;
if ($boxes !== boxes) {
$boxes.dispose();
}
if ($scores !== scores) {
$scores.dispose();
}
return _context.abrupt("return", {
selectedIndices: tensor1d(selectedIndices, 'int32'),
validOutputs: scalar(validOutputs, 'int32')
});
case 18:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _nonMaxSuppressionPaddedAsync_.apply(this, arguments);
}
var nonMaxSuppressionPaddedAsync = nonMaxSuppressionPaddedAsync_;
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Bilinear resize a single 3D image or a batch of 3D images to a new shape.
*
* @param images The images, of rank 4 or rank 3, of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
* @param size The new shape `[newHeight, newWidth]` to resize the
* images to. Each channel is resized individually.
* @param alignCorners Defaults to `false`. If true, rescale
* input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
* corners of images and resized images. If false, rescale by
* `new_height / height`. Treat similarly the width dimension.
* @param halfPixelCenters Defaults to `false`. Whether to assume pixel centers
* are at 0.5, which would make the floating point coordinates of the top
* left pixel 0.5, 0.5.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
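 *
 * A minimal usage sketch; the tensor values are illustrative only:
 *
 * ```js
 * const image = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
 * tf.image.resizeBilinear(image, [4, 4]).print();
 * ```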
*/
function resizeBilinear_(images, size, alignCorners, halfPixelCenters) {
if (alignCorners === void 0) {
alignCorners = false;
}
if (halfPixelCenters === void 0) {
halfPixelCenters = false;
}
var $images = convertToTensor(images, 'images', 'resizeBilinear');
assert($images.rank === 3 || $images.rank === 4, function () {
return "Error in resizeBilinear: x must be rank 3 or 4, but got " + ("rank " + $images.rank + ".");
});
assert(size.length === 2, function () {
return "Error in resizeBilinear: new shape must 2D, but got shape " + (size + ".");
});
assert(halfPixelCenters === false || alignCorners === false, function () {
return "Error in resizeBilinear: If halfPixelCenters is true, " + "alignCorners must be false.";
});
var batchImages = $images;
var reshapedTo4D = false;
if ($images.rank === 3) {
reshapedTo4D = true;
batchImages = reshape($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
}
var inputs = {
images: batchImages
};
var attrs = {
alignCorners: alignCorners,
halfPixelCenters: halfPixelCenters,
size: size
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(ResizeBilinear, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var resizeBilinear = op({
resizeBilinear_: resizeBilinear_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Nearest-neighbor resize a single 3D image or a batch of 3D images to a new shape.
*
* @param images The images, of rank 4 or rank 3, of shape
* `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed.
* @param size The new shape `[newHeight, newWidth]` to resize the
* images to. Each channel is resized individually.
 * @param alignCorners Defaults to `false`. If true, rescale
* input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4
* corners of images and resized images. If false, rescale by
* `new_height / height`. Treat similarly the width dimension.
 * @param halfPixelCenters Defaults to `false`. Whether to assume pixel centers
 *     are at 0.5, which yields more accurate resizes. This flag also makes
 *     the floating point coordinates of the top left pixel 0.5, 0.5.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
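 *
 * A minimal usage sketch; the tensor values are illustrative only:
 *
 * ```js
 * const image = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);
 * tf.image.resizeNearestNeighbor(image, [4, 4]).print();
 * ```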
*/
function resizeNearestNeighbor_(images, size, alignCorners, halfPixelCenters) {
if (alignCorners === void 0) {
alignCorners = false;
}
if (halfPixelCenters === void 0) {
halfPixelCenters = false;
}
var $images = convertToTensor(images, 'images', 'resizeNearestNeighbor');
assert($images.rank === 3 || $images.rank === 4, function () {
return "Error in resizeNearestNeighbor: x must be rank 3 or 4, but got " + ("rank " + $images.rank + ".");
});
assert(size.length === 2, function () {
return "Error in resizeNearestNeighbor: new shape must 2D, but got shape " + (size + ".");
});
assert($images.dtype === 'float32' || $images.dtype === 'int32', function () {
return '`images` must have `int32` or `float32` as dtype';
});
assert(halfPixelCenters === false || alignCorners === false, function () {
return "Error in resizeNearestNeighbor: If halfPixelCenters is true, " + "alignCorners must be false.";
});
var batchImages = $images;
var reshapedTo4D = false;
if ($images.rank === 3) {
reshapedTo4D = true;
batchImages = reshape($images, [1, $images.shape[0], $images.shape[1], $images.shape[2]]);
}
var inputs = {
images: batchImages
};
var attrs = {
alignCorners: alignCorners,
halfPixelCenters: halfPixelCenters,
size: size
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(ResizeNearestNeighbor, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var resizeNearestNeighbor = op({
resizeNearestNeighbor_: resizeNearestNeighbor_
});
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Performs image binarization with a threshold value that depends on the
 * chosen method, creating a binary image from a grayscale one.
 * @param image 3d tensor of shape [imageHeight, imageWidth, depth],
 *     where imageHeight and imageWidth must be positive. The image color
 *     range should be [0, 255].
 * @param method Optional string from `'binary' | 'otsu'`
 *     which specifies the method for thresholding. Defaults to 'binary'.
 * @param inverted Optional boolean which specifies
 *     if colors should be inverted. Defaults to false.
 * @param threshValue Optional number which defines the threshold value, from 0
 *     to 1. Defaults to 0.5.
 * @return A 3d tensor of shape [imageHeight, imageWidth, depth], which
 *     contains the binarized image.
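 *
 * A minimal usage sketch, assuming the op is exposed as `tf.image.threshold`;
 * the pixel values are illustrative only:
 *
 * ```js
 * const image = tf.tensor3d([10, 20, 150, 200, 80, 5], [2, 1, 3], 'int32');
 * tf.image.threshold(image, 'binary', false, 0.5).print();
 * ```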
*/
function threshold_(image, method, inverted, threshValue) {
if (method === void 0) {
method = 'binary';
}
if (inverted === void 0) {
inverted = false;
}
if (threshValue === void 0) {
threshValue = 0.5;
}
var $image = convertToTensor(image, 'image', 'threshold');
  /* 0.2989, 0.5870, 0.1140 represent the luma coefficients in CCIR 601.
  Reference for converting between RGB and grayscale: https://en.wikipedia.org/wiki/Luma_%28video%29 */
var RED_INTENCITY_COEF = 0.2989;
var GREEN_INTENCITY_COEF = 0.5870;
var BLUE_INTENCITY_COEF = 0.1140;
var totalPixelsInImage = $image.shape[0] * $image.shape[1];
var $threshold = mul(tensor1d([threshValue]), 255);
var r, g, b, grayscale;
assert($image.rank === 3, function () {
    return 'Error in threshold: image must be rank 3, ' + ("but got rank " + $image.rank + ".");
});
assert($image.shape[2] === 3 || $image.shape[2] === 1, function () {
    return 'Error in threshold: ' + 'image color channel must be equal to 3 or 1, ' + ("but got " + $image.shape[2] + ".");
});
assert($image.dtype === 'int32' || $image.dtype === 'float32', function () {
    return 'Error in dtype: image dtype must be int32 or float32, ' + ("but got dtype " + $image.dtype + ".");
});
assert(method === 'otsu' || method === 'binary', function () {
return "Method must be binary or otsu, but was " + method;
});
if ($image.shape[2] === 3) {
var _split = split$1($image, [1, 1, 1], -1);
r = _split[0];
g = _split[1];
b = _split[2];
var $r = mul(r, RED_INTENCITY_COEF);
var $g = mul(g, GREEN_INTENCITY_COEF);
var $b = mul(b, BLUE_INTENCITY_COEF);
grayscale = add$1(add$1($r, $g), $b);
} else {
grayscale = image;
}
if (method === 'otsu') {
var $histogram = bincount(cast(round$1(grayscale), 'int32'), tensor([]), 256);
$threshold = otsu($histogram, totalPixelsInImage);
}
var invCondition = inverted ? lessEqual(grayscale, $threshold) : greater(grayscale, $threshold);
var result = cast(mul(invCondition, 255), 'int32');
return result;
}
function otsu(histogram, total) {
var bestThresh = tensor1d([-1]);
var bestInBetVar = tensor1d([0]);
var cInBetVar = tensor1d([0]);
var classFirst, classSecond, meanFirst, meanSec, weightForeground, weightBack;
for (var index = 0; index < histogram.size - 1; index++) {
classFirst = slice$2(histogram, 0, index + 1);
classSecond = slice$2(histogram, index + 1);
weightForeground = div(sum$1(classFirst), total);
weightBack = div(sum$1(classSecond), total);
var meanFirstDivA = sum$1(mul(classFirst, range(0, classFirst.size)));
meanFirst = div(meanFirstDivA, sum$1(classFirst));
var meanSecFill = fill(classSecond.shape, classFirst.size);
var meanSecAdd = add$1(range(0, classSecond.size), meanSecFill);
var meanSecMul = mul(classSecond, meanSecAdd);
meanSec = div(sum$1(meanSecMul), sum$1(classSecond));
var cInBetVarSubA = sub(meanFirst, meanSec);
var cInBetVarSubB = sub(meanFirst, meanSec);
var cInBetVarMul = mul(weightForeground, weightBack);
cInBetVar = mul(mul(cInBetVarMul, cInBetVarSubA), cInBetVarSubB);
var condition = greater(cInBetVar, bestInBetVar);
bestInBetVar = where(condition, cInBetVar, bestInBetVar);
bestThresh = where(condition, tensor1d([index]), bestThresh);
}
return bestThresh;
}
var threshold = op({
threshold_: threshold_
});
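/**
 * Usage sketch for the op above: it assumes the op is exposed on the public
 * API as `tf.image.threshold`, mirroring the other ops in the `image`
 * namespace; the binding name and the expected outputs in the comments are
 * assumptions of this sketch, not library documentation.
 *
 * ```js
 * // A 2x2 single-channel image with values in [0, 255].
 * const img = tf.tensor3d([[[50], [100]], [[200], [250]]], [2, 2, 1], 'int32');
 * // Fixed threshold: pixels above 0.5 * 255 should map to 255, the rest to 0.
 * tf.image.threshold(img, 'binary', false, 0.5).print();
 * // Otsu's method derives the threshold from the image histogram instead.
 * tf.image.threshold(img, 'otsu').print();
 * ```
 */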
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Applies the given transform(s) to the image(s).
*
* @param image 4d tensor of shape `[batch, imageHeight, imageWidth, depth]`.
 * @param transforms Projective transform matrix/matrices. A tensor1d of length
 *     8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0,
 *     b1, b2, c0, c1], then it maps the output point (x, y) to a transformed
* input point (x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k),
* where k = c0 x + c1 y + 1. The transforms are inverted compared to the
* transform mapping input points to output points.
 * @param interpolation Interpolation mode.
 *     Supported values: 'nearest', 'bilinear'. Defaults to 'nearest'.
 * @param fillMode Points outside the boundaries of the input are filled
 *     according to the given mode, one of 'constant', 'reflect', 'wrap',
 *     'nearest'. Defaults to 'constant'.
* 'reflect': (d c b a | a b c d | d c b a ) The input is extended by
* reflecting about the edge of the last pixel.
* 'constant': (k k k k | a b c d | k k k k) The input is extended by
* filling all values beyond the edge with the same constant value k.
* 'wrap': (a b c d | a b c d | a b c d) The input is extended by
* wrapping around to the opposite edge.
* 'nearest': (a a a a | a b c d | d d d d) The input is extended by
* the nearest pixel.
 * @param fillValue A float representing the value to be filled outside the
 *     boundaries when fillMode is 'constant'.
 * @param outputShape Output dimension after the transform, [height, width]. If
 *     undefined, the output is the same size as the input image.
*
* @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'}
*/
function transform_(image, transforms, interpolation, fillMode, fillValue, outputShape) {
if (interpolation === void 0) {
interpolation = 'nearest';
}
if (fillMode === void 0) {
fillMode = 'constant';
}
if (fillValue === void 0) {
fillValue = 0;
}
var $image = convertToTensor(image, 'image', 'transform', 'float32');
var $transforms = convertToTensor(transforms, 'transforms', 'transform', 'float32');
assert($image.rank === 4, function () {
return 'Error in transform: image must be rank 4,' + ("but got rank " + $image.rank + ".");
});
assert($transforms.rank === 2 && ($transforms.shape[0] === $image.shape[0] || $transforms.shape[0] === 1) && $transforms.shape[1] === 8, function () {
return "Error in transform: Input transform should be batch x 8 or 1 x 8";
});
assert(outputShape == null || outputShape.length === 2, function () {
return 'Error in transform: outputShape must be [height, width] or null, ' + ("but got " + outputShape + ".");
});
var inputs = {
image: $image,
transforms: $transforms
};
var attrs = {
interpolation: interpolation,
fillMode: fillMode,
fillValue: fillValue,
outputShape: outputShape
};
return ENGINE.runKernel(Transform, inputs, attrs);
}
var transform = op({
transform_: transform_
});
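/**
 * Usage sketch, assuming the op is exposed as `tf.image.transform` per the
 * @doc namespace above. An identity projective transform
 * [1, 0, 0, 0, 1, 0, 0, 0] should return the input image unchanged; the
 * tensor values here are illustrative only.
 *
 * ```js
 * const img = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 * // One transform row shared by the whole batch: [a0,a1,a2,b0,b1,b2,c0,c1].
 * const transforms = tf.tensor2d([[1, 0, 0, 0, 1, 0, 0, 0]]);
 * tf.image.transform(img, transforms).print(); // same values as `img`
 * ```
 */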
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Copy a tensor setting everything outside a central band in each innermost
* matrix to zero.
*
* The band part is computed as follows: Assume input has `k` dimensions
* `[I, J, K, ..., M, N]`, then the output is a tensor with the same shape where
* `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
 * The indicator function is
 * `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)`
 * `&& (num_upper < 0 || (n-m) <= num_upper)`
*
* ```js
* const x = tf.tensor2d([[ 0, 1, 2, 3],
* [-1, 0, 1, 2],
* [-2, -1, 0, 1],
* [-3, -2, -1, 0]]);
* let y = tf.linalg.bandPart(x, 1, -1);
* y.print(); // [[ 0, 1, 2, 3],
* // [-1, 0, 1, 2],
* // [ 0, -1, 0, 1],
* // [ 0, 0 , -1, 0]]
* let z = tf.linalg.bandPart(x, 2, 1);
* z.print(); // [[ 0, 1, 0, 0],
* // [-1, 0, 1, 0],
* // [-2, -1, 0, 1],
* // [ 0, -2, -1, 0]]
* ```
*
* @param x Rank `k` tensor
* @param numLower Number of subdiagonals to keep.
* If negative, keep entire lower triangle.
 * @param numUpper Number of superdiagonals to keep.
* If negative, keep entire upper triangle.
* @returns Rank `k` tensor of the same shape as input.
* The extracted banded tensor.
*
* @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
*/
function bandPart_(a, numLower, numUpper) {
assert(numLower % 1 === 0, function () {
return "bandPart(): numLower must be an integer, got " + numLower + ".";
});
assert(numUpper % 1 === 0, function () {
return "bandPart(): numUpper must be an integer, got " + numUpper + ".";
});
var $a = convertToTensor(a, 'a', 'bandPart');
assert($a.rank >= 2, function () {
return "bandPart(): Rank must be at least 2, got " + $a.rank + ".";
});
var shape = $a.shape;
var _$a$shape$slice = $a.shape.slice(-2),
M = _$a$shape$slice[0],
N = _$a$shape$slice[1];
if (!(numLower <= M)) {
throw new Error("bandPart(): numLower (" + numLower + ")" + (" must not be greater than the number of rows (" + M + ")."));
}
if (!(numUpper <= N)) {
throw new Error("bandPart(): numUpper (" + numUpper + ")" + (" must not be greater than the number of columns (" + N + ")."));
}
if (numLower < 0) {
numLower = M;
}
if (numUpper < 0) {
numUpper = N;
}
var i = reshape(range(0, M, 1, 'int32'), [-1, 1]);
var j = range(0, N, 1, 'int32');
var ij = sub(i, j);
var inBand = logicalAnd(lessEqual(ij, scalar(+numLower, 'int32')), greaterEqual(ij, scalar(-numUpper, 'int32')));
var zero = zeros([M, N], $a.dtype);
return reshape(stack(unstack(reshape($a, [-1, M, N])).map(function (mat) {
return where(inBand, mat, zero);
})), shape);
}
var bandPart = op({
bandPart_: bandPart_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Gram-Schmidt orthogonalization.
*
* ```js
* const x = tf.tensor2d([[1, 2], [3, 4]]);
* let y = tf.linalg.gramSchmidt(x);
* y.print();
 * console.log('Orthogonalized:');
* y.dot(y.transpose()).print(); // should be nearly the identity matrix.
* console.log('First row direction maintained:');
* const data = await y.array();
* console.log(data[0][1] / data[0][0]); // should be nearly 2.
* ```
*
* @param xs The vectors to be orthogonalized, in one of the two following
* formats:
* - An Array of `tf.Tensor1D`.
* - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows
* of `xs`.
* In each case, all the vectors must have the same length and the length
* must be greater than or equal to the number of vectors.
* @returns The orthogonalized and normalized vectors or matrix.
* Orthogonalization means that the vectors or the rows of the matrix
* are orthogonal (zero inner products). Normalization means that each
* vector or each row of the matrix has an L2 norm that equals `1`.
*
* @doc {heading:'Operations', subheading:'Linear Algebra', namespace:'linalg'}
*/
function gramSchmidt_(xs) {
var inputIsTensor2D;
if (Array.isArray(xs)) {
(function () {
inputIsTensor2D = false;
assert(xs != null && xs.length > 0, function () {
return 'Gram-Schmidt process: input must not be null, undefined, or ' + 'empty';
});
var dim = xs[0].shape[0];
var _loop = function _loop(i) {
assert(xs[i].shape[0] === dim, function () {
return 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' + ("(" + xs[i].shape[0] + " vs. " + dim + ")");
});
};
for (var i = 1; i < xs.length; ++i) {
_loop(i);
}
})();
} else {
inputIsTensor2D = true;
xs = split$1(xs, xs.shape[0], 0).map(function (x) {
return squeeze(x, [0]);
});
}
assert(xs.length <= xs[0].shape[0], function () {
return "Gram-Schmidt: Number of vectors (" + xs.length + ") exceeds " + ("number of dimensions (" + xs[0].shape[0] + ").");
});
var ys = [];
var xs1d = xs;
var _loop2 = function _loop2(i) {
ys.push(ENGINE.tidy(function () {
var x = xs1d[i];
if (i > 0) {
for (var j = 0; j < i; ++j) {
var proj = mul(sum$1(mul(ys[j], x)), ys[j]);
x = sub(x, proj);
}
}
return div(x, norm(x, 'euclidean'));
}));
};
for (var i = 0; i < xs.length; ++i) {
_loop2(i);
}
if (inputIsTensor2D) {
return stack(ys, 0);
} else {
return ys;
}
}
var gramSchmidt = op({
gramSchmidt_: gramSchmidt_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Compute QR decomposition of m-by-n matrix using Householder transformation.
*
* Implementation based on
* [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf]
* (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf)
*
* ```js
* const a = tf.tensor2d([[1, 2], [3, 4]]);
* let [q, r] = tf.linalg.qr(a);
* console.log('Q');
* q.print();
* console.log('R');
* r.print();
* console.log('Orthogonalized');
* q.dot(q.transpose()).print() // should be nearly the identity matrix.
* console.log('Reconstructed');
* q.dot(r).print(); // should be nearly [[1, 2], [3, 4]];
* ```
*
* @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose
* it has the shape `[..., M, N]`.
* @param fullMatrices An optional boolean parameter. Defaults to `false`.
* If `true`, compute full-sized `Q`. If `false` (the default),
* compute only the leading N columns of `Q` and `R`.
* @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix,
* i.e., its columns all have unit norm and are mutually orthogonal.
* If `M >= N`,
* If `fullMatrices` is `false` (default),
* - `Q` has a shape of `[..., M, N]`,
* - `R` has a shape of `[..., N, N]`.
 *     If `fullMatrices` is `true`,
* - `Q` has a shape of `[..., M, M]`,
* - `R` has a shape of `[..., M, N]`.
* If `M < N`,
* - `Q` has a shape of `[..., M, M]`,
* - `R` has a shape of `[..., M, N]`.
* @throws If the rank of `x` is less than 2.
*
* @doc {heading:'Operations',
* subheading:'Linear Algebra',
* namespace:'linalg'}
*/
function qr_(x, fullMatrices) {
if (fullMatrices === void 0) {
fullMatrices = false;
}
assert(x.rank >= 2, function () {
return "qr() requires input tensor to have a rank >= 2, but got rank " + x.rank;
});
if (x.rank === 2) {
return qr2d(x, fullMatrices);
} else {
// Rank > 2.
// TODO(cais): Below we split the input into individual 2D tensors,
// perform QR decomposition on them and then stack the results back
// together. We should explore whether this can be parallelized.
var outerDimsProd = x.shape.slice(0, x.shape.length - 2).reduce(function (value, prev) {
return value * prev;
});
var x2ds = unstack(reshape(x, [outerDimsProd, x.shape[x.shape.length - 2], x.shape[x.shape.length - 1]]), 0);
var q2ds = [];
var r2ds = [];
x2ds.forEach(function (x2d) {
var _qr2d = qr2d(x2d, fullMatrices),
q2d = _qr2d[0],
r2d = _qr2d[1];
q2ds.push(q2d);
r2ds.push(r2d);
});
var q = reshape(stack(q2ds, 0), x.shape);
var r = reshape(stack(r2ds, 0), x.shape);
return [q, r];
}
}
function qr2d(x, fullMatrices) {
if (fullMatrices === void 0) {
fullMatrices = false;
}
return ENGINE.tidy(function () {
assert(x.shape.length === 2, function () {
return "qr2d() requires a 2D Tensor, but got a " + x.shape.length + "D Tensor.";
});
var m = x.shape[0];
var n = x.shape[1];
var q = eye(m); // Orthogonal transform so far.
var r = clone(x); // Transformed matrix so far.
var one2D = tensor2d([[1]], [1, 1]);
var w = clone(one2D);
var iters = m >= n ? n : m;
var _loop = function _loop(j) {
// This tidy within the for-loop ensures we clean up temporary
// tensors as soon as they are no longer needed.
var rTemp = r;
var wTemp = w;
var qTemp = q;
var _ENGINE$tidy = ENGINE.tidy(function () {
// Find H = I - tau * w * w', to put zeros below R(j, j).
var rjEnd1 = slice$2(r, [j, j], [m - j, 1]);
var normX = norm(rjEnd1);
var rjj = slice$2(r, [j, j], [1, 1]); // The sign() function returns 0 on 0, which causes division by zero.
var s = where(greater(rjj, 0), tensor2d([[-1]]), tensor2d([[1]]));
var u1 = sub(rjj, mul(s, normX));
var wPre = div(rjEnd1, u1);
if (wPre.shape[0] === 1) {
w = clone(one2D);
} else {
w = concat([one2D, slice$2(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]])], 0);
}
var tau = neg(div(matMul(s, u1), normX)); // -- R := HR, Q := QH.
var rjEndAll = slice$2(r, [j, 0], [m - j, n]);
var tauTimesW = mul(tau, w);
var wT = transpose(w);
if (j === 0) {
r = sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
} else {
var rTimesTau = sub(rjEndAll, matMul(tauTimesW, matMul(wT, rjEndAll)));
r = concat([slice$2(r, [0, 0], [j, n]), rTimesTau], 0);
}
var tawTimesWT = transpose(tauTimesW);
var qAllJEnd = slice$2(q, [0, j], [m, q.shape[1] - j]);
if (j === 0) {
q = sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tawTimesWT));
} else {
var qTimesTau = sub(qAllJEnd, matMul(matMul(qAllJEnd, w), tawTimesWT));
q = concat([slice$2(q, [0, 0], [m, j]), qTimesTau], 1);
}
return [w, r, q];
});
w = _ENGINE$tidy[0];
r = _ENGINE$tidy[1];
q = _ENGINE$tidy[2];
dispose([rTemp, wTemp, qTemp]);
};
for (var j = 0; j < iters; ++j) {
_loop(j);
}
if (!fullMatrices && m > n) {
q = slice$2(q, [0, 0], [m, n]);
r = slice$2(r, [0, 0], [n, n]);
}
return [q, r];
});
}
var qr = op({
qr_: qr_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
(function (Reduction) {
Reduction[Reduction["NONE"] = 0] = "NONE";
Reduction[Reduction["MEAN"] = 1] = "MEAN";
Reduction[Reduction["SUM"] = 2] = "SUM";
Reduction[Reduction["SUM_BY_NONZERO_WEIGHTS"] = 3] = "SUM_BY_NONZERO_WEIGHTS";
})(exports.Reduction || (exports.Reduction = {}));
/**
* Computes the weighted loss between two tensors.
*
* @param losses Tensor of shape `[batch_size, d1, ... dN]`.
 * @param weights Tensor whose rank is either 0, or the same rank as
 *    `losses`, and must be broadcastable to `losses` (i.e., all
 *    dimensions must be either `1`, or the same as the corresponding
 *    `losses` dimension).
 * @param reduction Type of reduction to apply to loss. Should be of type
 *    `Reduction`. Defaults to `Reduction.SUM_BY_NONZERO_WEIGHTS`.
*
* @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
*/
function computeWeightedLoss_(losses, weights, reduction) {
if (reduction === void 0) {
reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
}
var $losses = convertToTensor(losses, 'losses', 'computeWeightedLoss');
var $weights = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'computeWeightedLoss');
}
var weightedLoss = $weights == null ? $losses : mul($losses, $weights);
if (reduction === exports.Reduction.NONE) {
return weightedLoss;
}
if (reduction === exports.Reduction.SUM) {
return sum$1(weightedLoss);
}
if (reduction === exports.Reduction.MEAN) {
if ($weights == null) {
return mean(weightedLoss);
} else {
var broadcastFactor = $losses.size / $weights.size;
var result = div(sum$1(weightedLoss), sum$1($weights));
return broadcastFactor > 1 ? div(result, scalar(broadcastFactor)) : result;
}
}
if (reduction === exports.Reduction.SUM_BY_NONZERO_WEIGHTS) {
if ($weights == null) {
return div(sum$1(weightedLoss), scalar($losses.size));
} else {
var broadcastedWeights = mul($weights, ones$1($losses.shape));
var numNonZeros = cast(sum$1(notEqual(broadcastedWeights, scalar(0))), 'float32');
return div(sum$1(weightedLoss), numNonZeros);
}
}
throw Error("Unknown reduction: " + reduction);
}
var computeWeightedLoss = op({
computeWeightedLoss_: computeWeightedLoss_
});
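/**
 * Usage sketch, assuming the op is exposed as `tf.losses.computeWeightedLoss`
 * per the @doc namespace above; the values noted in the comments are
 * hand-computed expectations.
 *
 * ```js
 * const losses = tf.tensor1d([1, 2, 3, 4]);
 * const weights = tf.tensor1d([1, 0, 2, 1]);
 * // Default SUM_BY_NONZERO_WEIGHTS: (1 + 0 + 6 + 4) / 3 => about 3.6667
 * tf.losses.computeWeightedLoss(losses, weights).print();
 * // No reduction: the element-wise weighted losses [1, 0, 6, 4].
 * tf.losses.computeWeightedLoss(losses, weights, tf.Reduction.NONE).print();
 * ```
 */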
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the absolute difference loss between two tensors.
*
* @param labels The ground truth output tensor, same dimensions as
* 'predictions'.
* @param predictions The predicted outputs.
* @param weights Tensor whose rank is either 0, or the same rank as
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
* must be either `1`, or the same as the corresponding `losses`
* dimension).
* @param reduction Type of reduction to apply to loss. Should be of type
* `Reduction`
*
* @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
*/
function absoluteDifference_(labels, predictions, weights, reduction) {
if (reduction === void 0) {
reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
}
var $labels = convertToTensor(labels, 'labels', 'absoluteDifference');
var $predictions = convertToTensor(predictions, 'predictions', 'absoluteDifference');
var $weights = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'absoluteDifference');
}
assertShapesMatch($labels.shape, $predictions.shape, 'Error in absoluteDifference: ');
var losses = abs$8(sub($labels, $predictions));
return computeWeightedLoss(losses, $weights, reduction);
}
var absoluteDifference = op({
absoluteDifference_: absoluteDifference_
});
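/**
 * Usage sketch, assuming the op is exposed as `tf.losses.absoluteDifference`
 * per the @doc namespace above; the expected value is hand-computed.
 *
 * ```js
 * const labels = tf.tensor1d([1, 2, 3]);
 * const predictions = tf.tensor1d([1.5, 2, 2]);
 * // |1-1.5| + |2-2| + |3-2| = 1.5, averaged over 3 elements => 0.5
 * tf.losses.absoluteDifference(labels, predictions).print();
 * ```
 */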
/**
* Computes the cosine distance loss between two tensors.
*
* @param labels The ground truth output tensor, same dimensions as
* 'predictions'.
* @param predictions The predicted outputs.
* @param axis The dimension along which the cosine distance is computed.
* @param weights Tensor whose rank is either 0, or the same rank as
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
* must be either `1`, or the same as the corresponding `losses`
* dimension).
* @param reduction Type of reduction to apply to loss. Should be of type
* `Reduction`
*
* @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
*/
function cosineDistance_(labels, predictions, axis, weights, reduction) {
if (reduction === void 0) {
reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
}
var $labels = convertToTensor(labels, 'labels', 'cosineDistance');
var $predictions = convertToTensor(predictions, 'predictions', 'cosineDistance');
var $weights = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'cosineDistance');
}
assertShapesMatch($labels.shape, $predictions.shape, 'Error in cosineDistance: ');
var one = scalar(1);
var losses = sub(one, sum$1(mul($labels, $predictions), axis, true));
return computeWeightedLoss(losses, $weights, reduction);
}
var cosineDistance = op({
cosineDistance_: cosineDistance_
});
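/**
 * Usage sketch, assuming the op is exposed as `tf.losses.cosineDistance` per
 * the @doc namespace above. The loss is 1 - sum(labels * predictions) along
 * `axis`, so identical unit vectors should give a loss of 0.
 *
 * ```js
 * const labels = tf.tensor2d([[1, 0], [0, 1]]);
 * const predictions = tf.tensor2d([[1, 0], [0, 1]]);
 * tf.losses.cosineDistance(labels, predictions, 1).print(); // expected 0
 * ```
 */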
/**
* Computes the Hinge loss between two tensors.
*
* @param labels The ground truth output tensor, same dimensions as
* 'predictions'.
* @param predictions The predicted outputs.
* @param weights Tensor whose rank is either 0, or the same rank as
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
* must be either `1`, or the same as the corresponding `losses`
* dimension).
* @param reduction Type of reduction to apply to loss. Should be of type
* `Reduction`
*
* @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
*/
function hingeLoss_(labels, predictions, weights, reduction) {
if (reduction === void 0) {
reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
}
var $labels = convertToTensor(labels, 'labels', 'hingeLoss');
var $predictions = convertToTensor(predictions, 'predictions', 'hingeLoss');
var $weights = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'hingeLoss');
}
assertShapesMatch($labels.shape, $predictions.shape, 'Error in hingeLoss: ');
var one = scalar(1); // Convert binary labels to (-1, 1)
$labels = sub(mul(scalar(2), $labels), one);
var losses = relu(sub(one, mul($labels, $predictions)));
return computeWeightedLoss(losses, $weights, reduction);
}
var hingeLoss = op({
hingeLoss_: hingeLoss_
});
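/**
 * Usage sketch, assuming the op is exposed as `tf.losses.hingeLoss` per the
 * @doc namespace above. Labels are 0/1 and are mapped internally to -1/1, so
 * confident correct predictions should contribute no loss.
 *
 * ```js
 * const labels = tf.tensor1d([1, 0]);
 * const predictions = tf.tensor1d([2, -2]);
 * // max(0, 1 - 1*2) + max(0, 1 - (-1)*(-2)) = 0 + 0 => expected 0
 * tf.losses.hingeLoss(labels, predictions).print();
 * ```
 */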
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the huber loss between two tensors.
*
* @param labels The ground truth output tensor, same dimensions as
* 'predictions'.
* @param predictions The predicted outputs.
* @param weights Tensor whose rank is either 0, or the same rank as
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
* must be either `1`, or the same as the corresponding `losses`
* dimension).
* @param delta Point where huber loss changes from quadratic to linear.
* @param reduction Type of reduction to apply to loss. Should be of type
* `Reduction`.
*
* @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
*/
function huberLoss_(labels, predictions, weights, delta, reduction) {
if (delta === void 0) {
delta = 1.0;
}
if (reduction === void 0) {
reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
}
var $labels = convertToTensor(labels, 'labels', 'huberLoss');
var $predictions = convertToTensor(predictions, 'predictions', 'huberLoss');
var $weights = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'huberLoss');
}
assertShapesMatch($labels.shape, $predictions.shape, 'Error in huberLoss: ');
var deltaScalar = scalar(delta);
var error = abs$8(sub($predictions, $labels));
var quadratic = minimum(error, deltaScalar);
var linear = sub(error, quadratic);
var losses = add$1(mul(scalar(0.5), square(quadratic)), mul(deltaScalar, linear));
return computeWeightedLoss(losses, $weights, reduction);
}
var huberLoss = op({
huberLoss_: huberLoss_
});
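/**
 * Usage sketch, assuming the op is exposed as `tf.losses.huberLoss` per the
 * @doc namespace above. Errors below `delta` are penalized quadratically,
 * larger errors linearly; the expected value is hand-computed.
 *
 * ```js
 * const labels = tf.tensor1d([0, 0]);
 * const predictions = tf.tensor1d([0.5, 3]);
 * // error 0.5 (< delta = 1): 0.5 * 0.5^2             = 0.125
 * // error 3   (>= delta):    0.5 * 1^2 + 1 * (3 - 1) = 2.5
 * // mean over 2 elements => 1.3125
 * tf.losses.huberLoss(labels, predictions).print();
 * ```
 */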
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the log loss between two tensors.
*
* @param labels The ground truth output tensor, same dimensions as
* 'predictions'.
* @param predictions The predicted outputs.
* @param weights Tensor whose rank is either 0, or the same rank as
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
* must be either `1`, or the same as the corresponding `losses`
* dimension).
* @param epsilon A small increment to avoid taking log of zero
* @param reduction Type of reduction to apply to loss. Should be of type
* `Reduction`
*
* @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
*/
function logLoss_(labels, predictions, weights, epsilon, reduction) {
if (epsilon === void 0) {
epsilon = 1e-7;
}
if (reduction === void 0) {
reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
}
var $labels = convertToTensor(labels, 'labels', 'logLoss');
var $predictions = convertToTensor(predictions, 'predictions', 'logLoss');
var $weights = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'logLoss');
}
assertShapesMatch($labels.shape, $predictions.shape, 'Error in logLoss: ');
var one = scalar(1);
var epsilonScalar = scalar(epsilon);
var l1 = neg(mul($labels, log$a(add$1($predictions, epsilonScalar))));
var l2 = mul(sub(one, $labels), log$a(add$1(sub(one, $predictions), epsilonScalar)));
var losses = sub(l1, l2);
return computeWeightedLoss(losses, $weights, reduction);
}
var logLoss = op({
logLoss_: logLoss_
});
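/**
 * Usage sketch, assuming the op is exposed as `tf.losses.logLoss` per the
 * @doc namespace above; `epsilon` keeps the op from taking log(0). The
 * expected value is hand-computed.
 *
 * ```js
 * const labels = tf.tensor1d([1, 0]);
 * const predictions = tf.tensor1d([0.9, 0.1]);
 * // Both elements contribute -log(0.9), averaged => about 0.105
 * tf.losses.logLoss(labels, predictions).print();
 * ```
 */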
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the mean squared error between two tensors.
*
* @param labels The ground truth output tensor, same dimensions as
* 'predictions'.
* @param predictions The predicted outputs.
* @param weights Tensor whose rank is either 0, or the same rank as
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
* must be either `1`, or the same as the corresponding `losses`
* dimension).
* @param reduction Type of reduction to apply to loss. Should be of type
* `Reduction`
*
* @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'}
*/
function meanSquaredError_(labels, predictions, weights, reduction) {
if (reduction === void 0) {
reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
}
var $labels = convertToTensor(labels, 'labels', 'meanSquaredError');
var $predictions = convertToTensor(predictions, 'predictions', 'meanSquaredError');
var $weights = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'meanSquaredError');
}
assertShapesMatch($labels.shape, $predictions.shape, 'Error in meanSquaredError: ');
var losses = squaredDifference($labels, $predictions);
return computeWeightedLoss(losses, $weights, reduction);
}
var meanSquaredError = op({
meanSquaredError_: meanSquaredError_
});
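/**
 * Usage sketch, assuming the op is exposed as `tf.losses.meanSquaredError`
 * per the @doc namespace above; the expected value is hand-computed.
 *
 * ```js
 * const labels = tf.tensor1d([0, 1, 2]);
 * const predictions = tf.tensor1d([0, 1, 4]);
 * // ((0-0)^2 + (1-1)^2 + (2-4)^2) / 3 => about 1.3333
 * tf.losses.meanSquaredError(labels, predictions).print();
 * ```
 */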
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sigmoidCrossEntropyWithLogits_(labels, logits) {
var $labels = convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits');
var $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits');
assertShapesMatch($labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: ');
/**
* Implementation Details:
*
* For brevity, let `x = logits`, `z = labels`. The logistic loss is
* z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
* = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
* = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
   *     = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
* = (1 - z) * x + log(1 + exp(-x))
* = x - x * z + log(1 + exp(-x))
*
* For x < 0, to avoid overflow in exp(-x), we reformulate the above
* x - x * z + log(1 + exp(-x))
* = log(exp(x)) - x * z + log(1 + exp(-x))
* = - x * z + log(1 + exp(x))
*
* Hence, to ensure stability and avoid overflow, the implementation uses
* this equivalent formulation:
* max(x, 0) - x * z + log(1 + exp(-abs(x)))
*/
var maxOutput = relu($logits);
var outputXTarget = mul($logits, $labels);
var sigmoidOutput = log1p(exp$3(neg(abs$8($logits))));
return add$1(sub(maxOutput, outputXTarget), sigmoidOutput);
}
/**
* Computes the sigmoid cross entropy loss between two tensors.
*
* If labelSmoothing is nonzero, smooth the labels towards 1/2:
*
* newMulticlassLabels = multiclassLabels * (1 - labelSmoothing)
* + 0.5 * labelSmoothing
*
* @param multiClassLabels The ground truth output tensor of shape
* [batch_size, num_classes], same dimensions as 'predictions'.
* @param logits The predicted outputs.
* @param weights Tensor whose rank is either 0, or the same rank as
* `labels`, and must be broadcastable to `labels` (i.e., all dimensions
* must be either `1`, or the same as the corresponding `losses`
* dimension).
* @param labelSmoothing If greater than 0, then smooth the labels.
* @param reduction Type of reduction to apply to loss. Should be of type
* `Reduction`
*
* @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
*/
function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing, reduction) {
if (labelSmoothing === void 0) {
labelSmoothing = 0;
}
if (reduction === void 0) {
reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
}
var $multiClassLabels = convertToTensor(multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy');
var $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy');
var $weights = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'sigmoidCrossEntropy');
}
assertShapesMatch($multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: ');
if (labelSmoothing > 0) {
var labelSmoothingScalar = scalar(labelSmoothing);
var one = scalar(1);
var half = scalar(0.5);
$multiClassLabels = add$1(mul($multiClassLabels, sub(one, labelSmoothingScalar)), mul(half, labelSmoothingScalar));
}
var losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits);
return computeWeightedLoss(losses, $weights, reduction);
}
var sigmoidCrossEntropy = op({
sigmoidCrossEntropy_: sigmoidCrossEntropy_
});
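/**
 * Usage sketch, assuming the op is exposed as `tf.losses.sigmoidCrossEntropy`
 * per the @doc namespace above. It expects raw (unscaled) logits and treats
 * each class independently through a sigmoid.
 *
 * ```js
 * const labels = tf.tensor2d([[1, 0], [0, 1]]);
 * const logits = tf.tensor2d([[10, -10], [-10, 10]]);
 * // Confident, correct logits => loss very close to 0.
 * tf.losses.sigmoidCrossEntropy(labels, logits).print();
 * ```
 */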
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes softmax cross entropy between logits and labels.
*
* Measures the probability error in discrete classification tasks in which
* the classes are mutually exclusive (each entry is in exactly one class).
* For example, each CIFAR-10 image is labeled with one and only one label: an
* image can be a dog or a truck, but not both.
*
* `NOTE`: While the classes are mutually exclusive, their probabilities need
* not be. All that is required is that each row of labels is a valid
* probability distribution. If they are not, the computation of the gradient
* will be incorrect.
*
* `WARNING`: This op expects unscaled logits, since it performs a softmax on
* logits internally for efficiency. Do not call this op with the output of
* softmax, as it will produce incorrect results.
*
* logits and labels must have the same shape, e.g. [batch_size, num_classes]
* and the same dtype.
* @param labels The labels array.
* @param logits The logits array.
* @param dim The dimension softmax would be performed on. Defaults to `-1`
* which indicates the last dimension.
*/
function softmaxCrossEntropyWithLogits_(labels, logits, dim) {
if (dim === void 0) {
dim = -1;
}
if (dim === -1) {
dim = logits.rank - 1;
}
if (dim !== logits.rank - 1) {
throw Error("Softmax cross entropy along a non-last dimension is not yet " + ("supported. Labels / logits was rank " + logits.rank + " ") + ("and dim was " + dim));
} // Use a custom gradient for numerical stability.
var customOp = customGrad(function (labels, logits, save) {
// Reference:
// 1. http://cs231n.github.io/linear-classify/#softmax
// 2. https://blog.feedly.com/tricks-of-the-trade-logsumexp/
var keepDims = true;
var lse = logSumExp(logits, [dim], keepDims);
var logResult = sub(cast(logits, 'float32'), lse);
save([labels, logResult]);
var costVector = neg(mul(logResult, labels));
var value = sum$1(costVector, [dim]);
var gradFunc = function gradFunc(dy, saved) {
var labels = saved[0],
logResult = saved[1];
var dyShape = expandShapeToKeepDim(dy.shape, [dim]);
return [mul(reshape(dy, dyShape), sub(cast(labels, 'float32'), exp$3(logResult))), mul(reshape(dy, dyShape), sub(exp$3(logResult), cast(labels, 'float32')))];
};
return {
value: value,
gradFunc: gradFunc
};
});
return customOp(labels, logits);
}
/**
* Computes the softmax cross entropy loss between two tensors.
*
* If labelSmoothing is nonzero, smooth the labels towards 1/2:
*
* newOnehotLabels = onehotLabels * (1 - labelSmoothing)
* + labelSmoothing / numClasses
*
* @param onehotLabels One hot encoded labels
* [batch_size, num_classes], same dimensions as 'predictions'.
* @param logits The predicted outputs.
* @param weights Tensor whose rank is either 0, or 1, and must be
* broadcastable to `loss` of shape [batch_size]
* @param labelSmoothing If greater than 0, then smooth the labels.
* @param reduction Type of reduction to apply to loss. Should be of type
* `Reduction`
*
* @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' }
*/
function softmaxCrossEntropy_(onehotLabels, logits, weights, labelSmoothing, reduction) {
if (labelSmoothing === void 0) {
labelSmoothing = 0;
}
if (reduction === void 0) {
reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS;
}
var $onehotLabels = convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy');
var $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy');
var $weights = null;
if (weights != null) {
$weights = convertToTensor(weights, 'weights', 'softmaxCrossEntropy');
}
assertShapesMatch($onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: ');
if (labelSmoothing > 0) {
var labelSmoothingScalar = scalar(labelSmoothing);
var one = scalar(1);
var numClasses = scalar($onehotLabels.shape[1]);
$onehotLabels = add$1(mul($onehotLabels, sub(one, labelSmoothingScalar)), div(labelSmoothingScalar, numClasses));
}
var losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits);
return computeWeightedLoss(losses, $weights, reduction);
}
var softmaxCrossEntropy = op({
softmaxCrossEntropy_: softmaxCrossEntropy_
});
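/**
 * Usage sketch, assuming the op is exposed as `tf.losses.softmaxCrossEntropy`
 * per the @doc namespace above. Labels are one-hot and logits are unscaled;
 * softmax is applied internally. The tensor values are illustrative only.
 *
 * ```js
 * const onehotLabels = tf.tensor2d([[0, 1, 0], [1, 0, 0]]);
 * const logits = tf.tensor2d([[1, 5, 1], [4, 1, 1]]);
 * // Each row contributes -log(softmax(logits)[trueClass]); rows are then
 * // averaged by the default reduction.
 * tf.losses.softmaxCrossEntropy(onehotLabels, logits).print();
 * ```
 */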
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* The input SparseTensor is represented via the map of inputs {`indices`,
* `values`, `denseShape`}. The output SparseTensor has the same `denseShape`
* but with indices `outputIndices` and values `outputValues`. This op inserts a
* single entry for every row that doesn't have any values. The index is created
* as `[row, 0, ..., 0]` and the inserted value is `defaultValue`.
*
* For example, suppose `spInput` has shape [5, 6] and non-empty values:
* [0, 1]: a
* [0, 3]: b
* [2, 0]: c
* [3, 1]: d
*
* Rows 1 and 4 are empty, so the output will be of shape [5, 6] with values:
* [0, 1]: a
* [0, 3]: b
* [1, 0]: `defaultValue`
* [2, 0]: c
* [3, 1]: d
* [4, 0]: `defaultValue`
*
* The output SparseTensor will be in row-major order and will have the same
* shape as the input.
*
* This op also returns an indicator vector shaped [dense_shape[0]] such that
* emptyRowIndicator[i] = True iff row i was an empty row.
*
 * And a reverse index map vector shaped [indices.shape[0]] that is used during
 * backpropagation: reverseIndexMap[i] = outi such that indices[i, j] ==
 * outputIndices[outi, j] for all j.
*
* ```js
* const result = tf.sparse.sparseFillEmptyRows(
* [[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]],
* [0, 10, 13, 14, 32, 33], [5, 6], -1);
* console.log(result);
* result['outputIndices'].print(); // [[0, 0], [1, 0], [1, 3], [1, 4],
* // [2, 0], [3, 2], [3, 3], [4, 0]]
* result['outputValues'].print(); // [0, 10, 13, 14,-1, 32, 33, -1]
* result['emptyRowIndicator'].print(); // [false, false, true, false, true]
* result['reverseIndexMap'].print(); // [0, 1, 2, 3, 5, 6]
* ```
* @param indices: 2-D. the indices of the sparse tensor.
* @param values: 1-D. the values of the sparse tensor.
* @param denseShape: 1-D. the shape of the sparse tensor.
* @param defaultValue: 0-D. default value to insert into location [row, 0, ...,
* 0] for rows missing from the input sparse tensor.
* @return A map with the following properties:
 *     - outputIndices: 2-D. The indices of the filled sparse tensor.
* - outputValues: 1-D. the values of the filled sparse tensor.
* - emptyRowIndicator: 1-D. whether the dense row was missing in the input
* sparse tensor.
* - reverseIndexMap: 1-D. a map from the input indices to the output
* indices.
* @doc {heading: 'Operations', subheading: 'Sparse'}
*/
function sparseFillEmptyRows_(indices, values, denseShape, defaultValue) {
var $indices = convertToTensor(indices, 'indices', 'sparseFillEmptyRows');
var $values = convertToTensor(values, 'values', 'sparseFillEmptyRows');
var $denseShape = convertToTensor(denseShape, 'denseShape', 'sparseFillEmptyRows');
var $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseFillEmptyRows', $values.dtype);
if ($indices.rank !== 2) {
throw new Error("Indices should be Tensor2D but received shape\n " + $indices.shape);
}
if ($values.rank !== 1) {
throw new Error("Values should be Tensor1D but received shape " + $values.shape);
}
if ($denseShape.rank !== 1) {
throw new Error("Dense shape should be Tensor1D but received shape " + $denseShape.shape);
}
if ($defaultValue.rank !== 0) {
throw new Error("Default value should be a scalar but received shape " + $defaultValue.shape);
}
var inputs = {
indices: $indices,
values: $values,
denseShape: $denseShape,
defaultValue: $defaultValue
};
var result = ENGINE.runKernel(SparseFillEmptyRows, inputs);
return {
outputIndices: result[0],
outputValues: result[1],
emptyRowIndicator: result[2],
reverseIndexMap: result[3]
};
}
var sparseFillEmptyRows = op({
sparseFillEmptyRows_: sparseFillEmptyRows_
});
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* This operation has the same semantics as reshape on the represented dense
* tensor. The `inputIndices` are recomputed based on the requested `newShape`.
* If one component of `newShape` is the special value -1, the size of that
* dimension is computed so that the total dense size remains constant. At most
* one component of `newShape` can be -1. The number of dense elements implied
* by `newShape` must be the same as the number of dense elements originally
* implied by `inputShape`. Reshaping does not affect the order of values in the
* SparseTensor. If the input tensor has rank R_in and N non-empty values, and
* `newShape` has length R_out, then `inputIndices` has shape [N, R_in],
* `inputShape` has length R_in, `outputIndices` has shape [N, R_out], and
* `outputShape` has length R_out.
*
* ```js
* const result = tf.sparse.sparseReshape(
* [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 2, 3]],
* [2, 3, 6], [9, -1]);
* console.log(result);
* result['outputIndices'].print(); //[[0, 0], [0, 1], [1, 2], [4, 2], [8, 1]]
* result['outputShape'].print(); // [9, 4]
* ```
* @param inputIndices: 2-D. N x R_in matrix with the indices of non-empty
* values in a SparseTensor.
* @param inputShape: 1-D. R_in Tensor1D with the input SparseTensor's dense
* shape.
* @param newShape: 1-D. R_out Tensor1D with the requested new dense shape.
* @return A map with the following properties:
* - outputIndices: 2-D. N x R_out matrix with the updated indices of
* non-empty values in the output SparseTensor.
* - outputShape: 1-D. R_out vector with the full dense shape of the output
* SparseTensor. This is the same as newShape but with any -1 dimensions
* filled in.
* @doc {heading: 'Operations', subheading: 'Sparse'}
*/
function sparseReshape_(inputIndices, inputShape, newShape) {
var $inputIndices = convertToTensor(inputIndices, 'inputIndices', 'sparseReshape');
var $inputShape = convertToTensor(inputShape, 'inputShape', 'sparseReshape');
var $newShape = convertToTensor(newShape, 'newShape', 'sparseReshape');
if ($inputIndices.rank !== 2) {
throw new Error("Input indices should be Tensor2D but received shape\n " + $inputIndices.shape);
}
if ($inputShape.rank !== 1) {
throw new Error("Input shape should be Tensor1D but received shape " + $inputShape.shape);
}
if ($newShape.rank !== 1) {
throw new Error("New shape should be Tensor1D but received shape " + $newShape.shape);
}
var inputs = {
inputIndices: $inputIndices,
inputShape: $inputShape,
newShape: $newShape
};
var result = ENGINE.runKernel(SparseReshape, inputs);
return {
outputIndices: result[0],
outputShape: result[1]
};
}
var sparseReshape = op({
sparseReshape_: sparseReshape_
});
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the mean along sparse segments of a tensor.
*
* ```js
* const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [6,7,8,9]]);
* // Select two rows, one segment.
* const result1 = tf.sparse.sparseSegmentMean(c,
* tf.tensor1d([0, 1], 'int32'),
* tf.tensor1d([0, 0], 'int32'));
* result1.print(); // [[0, 0, 0, 0]]
*
* // Select two rows, two segments.
* const result2 = tf.sparse.sparseSegmentMean(c,
* tf.tensor1d([0, 1], 'int32'),
* tf.tensor1d([0, 1], 'int32'));
* result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
*
* // Select all rows, two segments.
* const result3 = tf.sparse.sparseSegmentMean(c,
* tf.tensor1d([0, 1, 2], 'int32'),
* tf.tensor1d([0, 1, 1], 'int32'));
* result3.print(); // [[1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]]
* ```
* @param data: A Tensor of at least one dimension with data that will be
* assembled in the output.
* @param indices: A 1-D Tensor with indices into data. Has same rank as
* segmentIds.
* @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
* should be sorted and can be repeated.
 * @return Has same shape as data, except for dimension 0, whose size is equal
 *         to the number of segments.
*
* @doc {heading: 'Operations', subheading: 'Sparse'}
*/
function sparseSegmentMean_(data, indices, segmentIds) {
var $data = convertToTensor(data, 'data', 'sparseSegmentMean');
var $indices = convertToTensor(indices, 'indices', 'sparseSegmentMean');
var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentMean');
if ($data.rank < 1) {
throw new Error("Data should be at least 1 dimensional but received scalar");
}
if ($indices.rank !== 1) {
throw new Error("Indices should be Tensor1D but received shape\n " + $indices.shape);
}
if ($segmentIds.rank !== 1) {
throw new Error("Segment ids should be Tensor1D but received shape\n " + $segmentIds.shape);
}
var inputs = {
data: $data,
indices: $indices,
segmentIds: $segmentIds
};
return ENGINE.runKernel(SparseSegmentMean, inputs);
}
var sparseSegmentMean = op({
sparseSegmentMean_: sparseSegmentMean_
});
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the sum along sparse segments of a tensor.
*
* ```js
* const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]]);
* // Select two rows, one segment.
* const result1 = tf.sparse.sparseSegmentSum(c,
* tf.tensor1d([0, 1], 'int32'),
* tf.tensor1d([0, 0], 'int32'));
* result1.print(); // [[0, 0, 0, 0]]
*
 * // Select two rows, two segments.
* const result2 = tf.sparse.sparseSegmentSum(c,
* tf.tensor1d([0, 1], 'int32'),
* tf.tensor1d([0, 1], 'int32'));
* result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
*
* // Select all rows, two segments.
* const result3 = tf.sparse.sparseSegmentSum(c,
* tf.tensor1d([0, 1, 2], 'int32'),
* tf.tensor1d([0, 0, 1], 'int32'));
* result3.print(); // [[0, 0, 0, 0], [5, 6, 7, 8]]
* ```
* @param data: A Tensor of at least one dimension with data that will be
* assembled in the output.
* @param indices: A 1-D Tensor with indices into data. Has same rank as
* segmentIds.
* @param segmentIds: A 1-D Tensor with indices into the output Tensor. Values
* should be sorted and can be repeated.
 * @return Has same shape as data, except for dimension 0, whose size is equal
 *         to the number of segments.
*
* @doc {heading: 'Operations', subheading: 'Sparse'}
*/
function sparseSegmentSum_(data, indices, segmentIds) {
var $data = convertToTensor(data, 'data', 'sparseSegmentSum');
var $indices = convertToTensor(indices, 'indices', 'sparseSegmentSum');
var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentSum');
if ($data.rank < 1) {
throw new Error("Data should be at least 1 dimensional but received scalar");
}
if ($indices.rank !== 1) {
throw new Error("Indices should be Tensor1D but received shape\n " + $indices.shape);
}
if ($segmentIds.rank !== 1) {
throw new Error("Segment ids should be Tensor1D but received shape\n " + $segmentIds.shape);
}
var inputs = {
data: $data,
indices: $indices,
segmentIds: $segmentIds
};
return ENGINE.runKernel(SparseSegmentSum, inputs);
}
var sparseSegmentSum = op({
sparseSegmentSum_: sparseSegmentSum_
});
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Creates ngrams from ragged string data.
*
* This op accepts a ragged tensor with 1 ragged dimension containing only
* strings and outputs a ragged tensor with 1 ragged dimension containing ngrams
* of that string, joined along the innermost axis.
*
* ```js
* const result = tf.string.stringNGrams(
* ['a', 'b', 'c', 'd'], tf.tensor1d([0, 2, 4], 'int32'),
* '|', [1, 2], 'LP', 'RP', -1, false);
* result['nGrams'].print(); // ['a', 'b', 'LP|a', 'a|b', 'b|RP',
* // 'c', 'd', 'LP|c', 'c|d', 'd|RP']
* result['nGramsSplits'].print(); // [0, 5, 10]
* ```
* @param data: The values tensor of the ragged string tensor to make ngrams out
* of. Must be a 1D string tensor.
* @param dataSplits: The splits tensor of the ragged string tensor to make
* ngrams out of.
* @param separator: The string to append between elements of the token. Use ""
* for no separator.
* @param nGramWidths: The sizes of the ngrams to create.
* @param leftPad: The string to use to pad the left side of the ngram sequence.
* Only used if pad_width !== 0.
* @param rightPad: The string to use to pad the right side of the ngram
* sequence. Only used if pad_width !== 0.
 * @param padWidth: The number of padding elements to add to each side of each
 *     sequence. Note that padding will never be greater than `nGramWidths` - 1
 *     regardless of this value. If `padWidth` = -1, then add
 *     `max(nGramWidths) - 1` elements.
* @param preserveShortSequences: If true, then ensure that at least one ngram
* is generated for each input sequence. In particular, if an input sequence
* is shorter than min(ngramWidth) + 2*padWidth, then generate a single
* ngram containing the entire sequence. If false, then no ngrams are
* generated for these short input sequences.
* @return A map with the following properties:
* - nGrams: The values tensor of the output ngrams ragged tensor.
* - nGramsSplits: The splits tensor of the output ngrams ragged tensor.
*
* @doc {heading: 'Operations', subheading: 'String'}
*/
function stringNGrams_(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
var $data = convertToTensor(data, 'data', 'stringNGrams', 'string');
if ($data.dtype !== 'string') {
throw new Error('Data must be of datatype string');
}
if ($data.shape.length !== 1) {
throw new Error("Data must be a vector, saw: " + $data.shape);
}
var $dataSplits = convertToTensor(dataSplits, 'dataSplits', 'stringNGrams');
if ($dataSplits.dtype !== 'int32') {
throw new Error('Data splits must be of datatype int32');
}
var attrs = {
separator: separator,
nGramWidths: nGramWidths,
leftPad: leftPad,
rightPad: rightPad,
padWidth: padWidth,
preserveShortSequences: preserveShortSequences
};
var inputs = {
data: $data,
dataSplits: $dataSplits
};
var result = ENGINE.runKernel(StringNGrams, inputs, attrs);
return {
nGrams: result[0],
nGramsSplits: result[1]
};
}
var stringNGrams = op({
stringNGrams_: stringNGrams_
});
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Split elements of `input` based on `delimiter` into a SparseTensor.
 *
 * Let N be the size of `input` (typically N will be the batch size). Split each
 * element of `input` based on `delimiter` and return a SparseTensor containing
 * the split tokens. Empty tokens are ignored if `skipEmpty` is set to true.
*
* `delimiter` can be empty, or a string of split characters. If `delimiter` is
* an empty string, each element of `input` is split into individual
* character strings. Otherwise every character of `delimiter` is a potential
* split point.
*
* ```js
* const result = tf.string.stringSplit(['hello world', 'a b c'], ' ');
* result['indices'].print(); // [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
* result['values'].print(); // ['hello', 'world', 'a', 'b', 'c']
* result['shape'].print(); // [2, 3]
* ```
* @param input: 1-D. Strings to split.
* @param delimiter: 0-D. Delimiter characters, or empty string.
* @param skipEmpty: Optional. If true, skip the empty strings from the result.
* Defaults to true.
* @return A map with the following properties:
* - indices: A dense matrix of int32 representing the indices of the sparse
* tensor.
 *   - values: A vector of strings corresponding to the split values.
* - shape: a length-2 vector of int32 representing the shape of the sparse
* tensor, where the first value is N and the second value is the maximum number
* of tokens in a single input entry.
*
* @doc {heading: 'Operations', subheading: 'String'}
*/
function stringSplit_(input, delimiter, skipEmpty) {
if (skipEmpty === void 0) {
skipEmpty = true;
}
var $input = convertToTensor(input, 'input', 'stringSplit', 'string');
var $delimiter = convertToTensor(delimiter, 'delimiter', 'stringSplit', 'string');
if ($input.rank !== 1) {
throw new Error("Input should be Tensor1D but received shape " + $input.shape);
}
if ($delimiter.rank !== 0) {
throw new Error("Delimiter should be a scalar but received shape " + $delimiter.shape);
}
var attrs = {
skipEmpty: skipEmpty
};
var inputs = {
input: $input,
delimiter: $delimiter
};
var result = ENGINE.runKernel(StringSplit, inputs, attrs);
return {
indices: result[0],
values: result[1],
shape: result[2]
};
}
var stringSplit = op({
stringSplit_: stringSplit_
});
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Converts each string in the input Tensor to its hash modulo the number of
 * buckets.
*
* The hash function is deterministic on the content of the string within the
* process and will never change. However, it is not suitable for cryptography.
* This function may be used when CPU time is scarce and inputs are trusted or
* unimportant. There is a risk of adversaries constructing inputs that all hash
* to the same bucket.
*
* ```js
* const result = tf.string.stringToHashBucketFast(
* ['Hello', 'TensorFlow', '2.x'], 3);
* result.print(); // [0, 2, 2]
* ```
* @param input: The strings to assign a hash bucket.
* @param numBuckets: The number of buckets.
* @return A Tensor of the same shape as the input tensor.
*
* @doc {heading: 'Operations', subheading: 'String'}
*/
function stringToHashBucketFast_(input, numBuckets) {
var $input = convertToTensor(input, 'input', 'stringToHashBucketFast', 'string');
var attrs = {
numBuckets: numBuckets
};
if (numBuckets <= 0) {
throw new Error("Number of buckets must be at least 1");
}
var inputs = {
input: $input
};
return ENGINE.runKernel(StringToHashBucketFast, inputs, attrs);
}
var stringToHashBucketFast = op({
stringToHashBucketFast_: stringToHashBucketFast_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var spectral = {
fft: fft,
ifft: ifft,
rfft: rfft,
irfft: irfft
};
var signal = {
hammingWindow: hammingWindow,
hannWindow: hannWindow,
frame: frame,
stft: stft
}; // Image Ops namespace
var image = {
flipLeftRight: flipLeftRight,
grayscaleToRGB: grayscaleToRGB,
resizeNearestNeighbor: resizeNearestNeighbor,
resizeBilinear: resizeBilinear,
rotateWithOffset: rotateWithOffset,
cropAndResize: cropAndResize,
nonMaxSuppression: nonMaxSuppression,
nonMaxSuppressionAsync: nonMaxSuppressionAsync,
nonMaxSuppressionWithScore: nonMaxSuppressionWithScore,
nonMaxSuppressionWithScoreAsync: nonMaxSuppressionWithScoreAsync,
nonMaxSuppressionPadded: nonMaxSuppressionPadded,
nonMaxSuppressionPaddedAsync: nonMaxSuppressionPaddedAsync,
threshold: threshold,
transform: transform
}; // linalg namespace
var linalg = {
bandPart: bandPart,
gramSchmidt: gramSchmidt,
qr: qr
  }; // losses namespace
var losses = {
absoluteDifference: absoluteDifference,
computeWeightedLoss: computeWeightedLoss,
cosineDistance: cosineDistance,
hingeLoss: hingeLoss,
huberLoss: huberLoss,
logLoss: logLoss,
meanSquaredError: meanSquaredError,
sigmoidCrossEntropy: sigmoidCrossEntropy,
softmaxCrossEntropy: softmaxCrossEntropy
};
var sparse = {
sparseFillEmptyRows: sparseFillEmptyRows,
sparseReshape: sparseReshape,
sparseSegmentMean: sparseSegmentMean,
sparseSegmentSum: sparseSegmentSum
};
var string = {
stringNGrams: stringNGrams,
stringSplit: stringSplit,
stringToHashBucketFast: stringToHashBucketFast
}; // Second level exports.
/** @doc {heading: 'Training', subheading: 'Classes', namespace: 'train'} */
var Optimizer = /*#__PURE__*/function (_Serializable) {
_inheritsLoose(Optimizer, _Serializable);
function Optimizer() {
return _Serializable.apply(this, arguments) || this;
}
var _proto = Optimizer.prototype;
/**
   * Executes `f()` and minimizes the scalar output of `f()` by computing
   * gradients of that output with respect to the list of trainable variables
   * provided by `varList`. If no list is provided, it defaults to all trainable
   * variables.
*
* @param f The function to execute and whose output to minimize.
* @param returnCost Whether to return the scalar cost value produced by
* executing `f()`.
* @param varList An optional list of variables to update. If specified, only
* the trainable variables in varList will be updated by minimize. Defaults to
* all trainable variables.
*
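   * A minimal example (illustrative):
   *
   * ```js
   * // Take one optimization step on f(w) = (w - 2)^2 and read back the loss
   * // by passing returnCost = true.
   * const w = tf.scalar(0).variable();
   * const optimizer = tf.train.sgd(0.1);
   * const cost = optimizer.minimize(() => w.sub(2).square(), true);
   * cost.print(); // the scalar cost at this step
   * cost.dispose();
   * w.print();
   * ```
   *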
* @doc {heading: 'Training', subheading: 'Optimizers'}
*/
_proto.minimize = function minimize(f, returnCost, varList) {
if (returnCost === void 0) {
returnCost = false;
}
var _this$computeGradient = this.computeGradients(f, varList),
value = _this$computeGradient.value,
grads = _this$computeGradient.grads;
if (varList != null) {
var gradArray = varList.map(function (v) {
return {
name: v.name,
tensor: grads[v.name]
};
});
this.applyGradients(gradArray);
} else {
this.applyGradients(grads);
} // Dispose gradients.
dispose(grads);
if (returnCost) {
return value;
} else {
value.dispose();
return null;
}
}
/**
* The number of iterations that this optimizer instance has been invoked for.
*/
;
_proto.incrementIterations = function incrementIterations() {
this.iterations_ = this.iterations + 1;
}
/**
* Executes f() and computes the gradient of the scalar output of f() with
* respect to the list of trainable variables provided by `varList`. If no
* list is provided, it defaults to all trainable variables.
*
* @param f The function to execute and whose output to use for computing
* gradients with respect to variables.
   * @param varList An optional list of variables to compute gradients with
   * respect to. If specified, gradients are computed only for the trainable
   * variables in varList. Defaults to all trainable variables.
*
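   * A minimal example (illustrative):
   *
   * ```js
   * // Compute the gradient of w^2 with respect to w, then release the tensors.
   * const w = tf.scalar(3).variable();
   * const optimizer = tf.train.sgd(0.1);
   * const {value, grads} = optimizer.computeGradients(() => w.square());
   * value.print();          // w^2
   * grads[w.name].print();  // d(w^2)/dw = 2 * w
   * tf.dispose([value, grads]);
   * ```
   *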
* @doc {heading: 'Training', subheading: 'Optimizers'}
*/
;
_proto.computeGradients = function computeGradients(f, varList) {
return variableGrads(f, varList);
}
/**
* Dispose the variables (if any) owned by this optimizer instance.
*/
;
_proto.dispose = function dispose$1() {
if (this.iterations_ != null) {
dispose(this.iterations_);
}
};
_proto.saveIterations = /*#__PURE__*/function () {
var _saveIterations = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (this.iterations_ == null) {
this.iterations_ = 0;
}
return _context.abrupt("return", {
name: 'iter',
// TODO(cais): Use 'int64' type when available.
tensor: scalar(this.iterations_, 'int32')
});
case 2:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function saveIterations() {
return _saveIterations.apply(this, arguments);
}
return saveIterations;
}();
_proto.getWeights = /*#__PURE__*/function () {
var _getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
throw new Error('getWeights() is not implemented for this optimizer yet.');
case 1:
case "end":
return _context2.stop();
}
}
}, _callee2);
}));
function getWeights() {
return _getWeights.apply(this, arguments);
}
return getWeights;
}();
_proto.setWeights = /*#__PURE__*/function () {
var _setWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(weightValues) {
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
throw new Error("setWeights() is not implemented for this optimizer class " + ("" + this.getClassName()));
case 1:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function setWeights(_x) {
return _setWeights.apply(this, arguments);
}
return setWeights;
}()
/**
* Extract the first element of the weight values and set it
   * as the iterations counter variable of this optimizer instance.
*
* @param weightValues
* @returns Weight values with the first element consumed and excluded.
*/
;
_proto.extractIterations =
/*#__PURE__*/
function () {
var _extractIterations = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(weightValues) {
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
_context4.next = 2;
return weightValues[0].tensor.data();
case 2:
this.iterations_ = _context4.sent[0];
return _context4.abrupt("return", weightValues.slice(1));
case 4:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function extractIterations(_x2) {
return _extractIterations.apply(this, arguments);
}
return extractIterations;
}();
_createClass(Optimizer, [{
key: "iterations",
get: function get() {
if (this.iterations_ == null) {
this.iterations_ = 0;
}
return this.iterations_;
}
}]);
return Optimizer;
}(Serializable);
Object.defineProperty(Optimizer, Symbol.hasInstance, {
value: function value(instance) {
return instance.minimize != null && instance.computeGradients != null && instance.applyGradients != null;
}
});
/** @doclink Optimizer */
var AdadeltaOptimizer = /*#__PURE__*/function (_Optimizer) {
_inheritsLoose(AdadeltaOptimizer, _Optimizer);
function AdadeltaOptimizer(learningRate, rho, epsilon) {
var _this;
if (epsilon === void 0) {
epsilon = null;
}
_this = _Optimizer.call(this) || this;
_this.learningRate = learningRate;
_this.rho = rho;
_this.epsilon = epsilon;
_this.accumulatedGrads = [];
_this.accumulatedUpdates = [];
if (epsilon == null) {
_this.epsilon = ENGINE.backend.epsilon();
}
return _this;
}
var _proto = AdadeltaOptimizer.prototype;
_proto.applyGradients = function applyGradients(variableGradients) {
var _this2 = this;
var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) {
return item.name;
}) : Object.keys(variableGradients);
variableNames.forEach(function (name, i) {
var value = ENGINE.registeredVariables[name];
var trainable = false;
if (_this2.accumulatedGrads[i] == null) {
_this2.accumulatedGrads[i] = {
originalName: name + "/accum_grad",
variable: tidy(function () {
return zerosLike(value).variable(trainable);
})
};
}
if (_this2.accumulatedUpdates[i] == null) {
_this2.accumulatedUpdates[i] = {
originalName: name + "/accum_var",
variable: tidy(function () {
return zerosLike(value).variable(trainable);
})
};
}
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
if (gradient == null) {
return;
}
var accumulatedGrad = _this2.accumulatedGrads[i].variable;
var accumulatedUpdate = _this2.accumulatedUpdates[i].variable;
tidy(function () {
var newAccumulatedGrad = add$1(mul(accumulatedGrad, _this2.rho), mul(square(gradient), 1 - _this2.rho));
var updates = mul(div(sqrt$3(add$1(accumulatedUpdate, _this2.epsilon)), sqrt$3(add$1(accumulatedGrad, _this2.epsilon))), gradient);
var newAccumulatedUpdate = add$1(mul(accumulatedUpdate, _this2.rho), mul(square(updates), 1 - _this2.rho));
accumulatedGrad.assign(newAccumulatedGrad);
accumulatedUpdate.assign(newAccumulatedUpdate);
var newValue = add$1(mul(updates, -_this2.learningRate), value);
value.assign(newValue);
});
});
this.incrementIterations();
};
_proto.dispose = function dispose$1() {
if (this.accumulatedUpdates != null) {
dispose(this.accumulatedGrads.map(function (v) {
return v.variable;
}));
dispose(this.accumulatedUpdates.map(function (v) {
return v.variable;
}));
}
};
_proto.getWeights = /*#__PURE__*/function () {
var _getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var variables;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
// Order matters for Python compatibility.
variables = [].concat(this.accumulatedGrads, this.accumulatedUpdates);
_context.next = 3;
return this.saveIterations();
case 3:
_context.t0 = _context.sent;
return _context.abrupt("return", [_context.t0].concat(variables.map(function (v) {
return {
name: v.originalName,
tensor: v.variable
};
})));
case 5:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function getWeights() {
return _getWeights.apply(this, arguments);
}
return getWeights;
}();
_proto.setWeights = /*#__PURE__*/function () {
var _setWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(weightValues) {
var variableCount, trainable;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.extractIterations(weightValues);
case 2:
weightValues = _context2.sent;
variableCount = weightValues.length / 2;
trainable = false;
this.accumulatedGrads = weightValues.slice(0, variableCount).map(function (v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
this.accumulatedUpdates = weightValues.slice(variableCount, variableCount * 2).map(function (v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
case 7:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function setWeights(_x) {
return _setWeights.apply(this, arguments);
}
return setWeights;
}();
_proto.getConfig = function getConfig() {
return {
'learningRate': this.learningRate,
'rho': this.rho,
'epsilon': this.epsilon
};
}
/** @nocollapse */
;
AdadeltaOptimizer.fromConfig = function fromConfig(cls, config) {
return new cls(config['learningRate'], config['rho'], config['epsilon']);
};
return AdadeltaOptimizer;
}(Optimizer);
/** @nocollapse */
AdadeltaOptimizer.className = 'Adadelta'; // Name matters for Python compatibility.
registerClass(AdadeltaOptimizer);
/** @doclink Optimizer */
var AdagradOptimizer = /*#__PURE__*/function (_Optimizer) {
_inheritsLoose(AdagradOptimizer, _Optimizer);
function AdagradOptimizer(learningRate, initialAccumulatorValue) {
var _this;
if (initialAccumulatorValue === void 0) {
initialAccumulatorValue = 0.1;
}
_this = _Optimizer.call(this) || this;
_this.learningRate = learningRate;
_this.initialAccumulatorValue = initialAccumulatorValue;
_this.accumulatedGrads = [];
return _this;
}
var _proto = AdagradOptimizer.prototype;
_proto.applyGradients = function applyGradients(variableGradients) {
var _this2 = this;
var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) {
return item.name;
}) : Object.keys(variableGradients);
variableNames.forEach(function (name, i) {
var value = ENGINE.registeredVariables[name];
if (_this2.accumulatedGrads[i] == null) {
var trainable = false;
_this2.accumulatedGrads[i] = {
originalName: name + "/accumulator",
variable: tidy(function () {
return fill(value.shape, _this2.initialAccumulatorValue).variable(trainable);
})
};
}
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
if (gradient == null) {
return;
}
var accumulatedGrad = _this2.accumulatedGrads[i].variable;
tidy(function () {
var newAccumulatedGrad = add$1(accumulatedGrad, square(gradient));
accumulatedGrad.assign(newAccumulatedGrad);
var newValue = add$1(mul(div(gradient, sqrt$3(add$1(newAccumulatedGrad, ENGINE.backend.epsilon()))), -_this2.learningRate), value);
value.assign(newValue);
});
});
this.incrementIterations();
};
_proto.dispose = function dispose$1() {
if (this.accumulatedGrads != null) {
dispose(this.accumulatedGrads.map(function (v) {
return v.variable;
}));
}
};
_proto.getWeights = /*#__PURE__*/function () {
var _getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.next = 2;
return this.saveIterations();
case 2:
_context.t0 = _context.sent;
return _context.abrupt("return", [_context.t0].concat(this.accumulatedGrads.map(function (v) {
return {
name: v.originalName,
tensor: v.variable
};
})));
case 4:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function getWeights() {
return _getWeights.apply(this, arguments);
}
return getWeights;
}();
_proto.setWeights = /*#__PURE__*/function () {
var _setWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(weightValues) {
var trainable;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.extractIterations(weightValues);
case 2:
weightValues = _context2.sent;
trainable = false;
this.accumulatedGrads = weightValues.map(function (v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
case 5:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function setWeights(_x) {
return _setWeights.apply(this, arguments);
}
return setWeights;
}();
_proto.getConfig = function getConfig() {
return {
'learningRate': this.learningRate,
'initialAccumulatorValue': this.initialAccumulatorValue
};
}
/** @nocollapse */
;
AdagradOptimizer.fromConfig = function fromConfig(cls, config) {
return new cls(config['learningRate'], config['initialAccumulatorValue']);
};
return AdagradOptimizer;
}(Optimizer);
/** @nocollapse */
AdagradOptimizer.className = 'Adagrad'; // Note: Name matters for Python compatibility.
registerClass(AdagradOptimizer);
var AdamOptimizer = /*#__PURE__*/function (_Optimizer) {
_inheritsLoose(AdamOptimizer, _Optimizer);
function AdamOptimizer(learningRate, beta1, beta2, epsilon) {
var _this;
if (epsilon === void 0) {
epsilon = null;
}
_this = _Optimizer.call(this) || this;
_this.learningRate = learningRate;
_this.beta1 = beta1;
_this.beta2 = beta2;
_this.epsilon = epsilon;
_this.accumulatedFirstMoment = [];
_this.accumulatedSecondMoment = [];
tidy(function () {
// accB* will be updated by batch.
_this.accBeta1 = scalar(beta1).variable();
_this.accBeta2 = scalar(beta2).variable();
});
if (epsilon == null) {
_this.epsilon = ENGINE.backend.epsilon();
}
return _this;
}
var _proto = AdamOptimizer.prototype;
_proto.applyGradients = function applyGradients(variableGradients) {
var _this2 = this;
var varNames = Array.isArray(variableGradients) ? variableGradients.map(function (v) {
return v.name;
}) : Object.keys(variableGradients);
tidy(function () {
var oneMinusAccBeta1 = sub(1, _this2.accBeta1);
var oneMinusAccBeta2 = sub(1, _this2.accBeta2);
varNames.forEach(function (name, i) {
var value = ENGINE.registeredVariables[name];
var trainable = false;
if (_this2.accumulatedFirstMoment[i] == null) {
_this2.accumulatedFirstMoment[i] = {
originalName: name + "/m",
variable: tidy(function () {
return zerosLike(value).variable(trainable);
})
};
}
if (_this2.accumulatedSecondMoment[i] == null) {
_this2.accumulatedSecondMoment[i] = {
originalName: name + "/v",
variable: tidy(function () {
return zerosLike(value).variable(trainable);
})
};
}
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
if (gradient == null) {
return;
}
var firstMoment = _this2.accumulatedFirstMoment[i].variable;
var secondMoment = _this2.accumulatedSecondMoment[i].variable;
var newFirstMoment = add$1(mul(firstMoment, _this2.beta1), mul(gradient, 1 - _this2.beta1));
var newSecondMoment = add$1(mul(secondMoment, _this2.beta2), mul(square(gradient), 1 - _this2.beta2));
var biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1);
var biasCorrectedSecondMoment = div(newSecondMoment, oneMinusAccBeta2);
firstMoment.assign(newFirstMoment);
secondMoment.assign(newSecondMoment);
var newValue = add$1(mul(div(biasCorrectedFirstMoment, add$1(sqrt$3(biasCorrectedSecondMoment), _this2.epsilon)), -_this2.learningRate), value);
value.assign(newValue);
});
_this2.accBeta1.assign(mul(_this2.accBeta1, _this2.beta1));
_this2.accBeta2.assign(mul(_this2.accBeta2, _this2.beta2));
});
this.incrementIterations();
};
_proto.dispose = function dispose$1() {
this.accBeta1.dispose();
this.accBeta2.dispose();
if (this.accumulatedFirstMoment != null) {
dispose(this.accumulatedFirstMoment.map(function (v) {
return v.variable;
}));
}
if (this.accumulatedSecondMoment != null) {
dispose(this.accumulatedSecondMoment.map(function (v) {
return v.variable;
}));
}
};
_proto.getWeights = /*#__PURE__*/function () {
var _getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var variables;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
// Order matters for Python compatibility.
variables = [].concat(this.accumulatedFirstMoment, this.accumulatedSecondMoment);
_context.next = 3;
return this.saveIterations();
case 3:
_context.t0 = _context.sent;
return _context.abrupt("return", [_context.t0].concat(variables.map(function (v) {
return {
name: v.originalName,
tensor: v.variable
};
})));
case 5:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function getWeights() {
return _getWeights.apply(this, arguments);
}
return getWeights;
}();
_proto.setWeights = /*#__PURE__*/function () {
var _setWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(weightValues) {
var _this3 = this;
var variableCount, trainable;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.extractIterations(weightValues);
case 2:
weightValues = _context2.sent;
tidy(function () {
_this3.accBeta1.assign(pow$5(_this3.beta1, _this3.iterations_ + 1));
_this3.accBeta2.assign(pow$5(_this3.beta2, _this3.iterations_ + 1));
});
variableCount = weightValues.length / 2;
trainable = false;
this.accumulatedFirstMoment = weightValues.slice(0, variableCount).map(function (v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
this.accumulatedSecondMoment = weightValues.slice(variableCount, variableCount * 2).map(function (v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
case 8:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function setWeights(_x) {
return _setWeights.apply(this, arguments);
}
return setWeights;
}();
_proto.getConfig = function getConfig() {
return {
'learningRate': this.learningRate,
'beta1': this.beta1,
'beta2': this.beta2,
'epsilon': this.epsilon
};
}
/** @nocollapse */
;
AdamOptimizer.fromConfig = function fromConfig(cls, config) {
return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']);
};
return AdamOptimizer;
}(Optimizer);
/** @nocollapse */
AdamOptimizer.className = 'Adam'; // Note: Name matters for Python compatibility.
registerClass(AdamOptimizer);
var AdamaxOptimizer = /*#__PURE__*/function (_Optimizer) {
_inheritsLoose(AdamaxOptimizer, _Optimizer);
function AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay) {
var _this;
if (epsilon === void 0) {
epsilon = null;
}
if (decay === void 0) {
decay = 0.0;
}
_this = _Optimizer.call(this) || this;
_this.learningRate = learningRate;
_this.beta1 = beta1;
_this.beta2 = beta2;
_this.epsilon = epsilon;
_this.decay = decay;
_this.accumulatedFirstMoment = [];
_this.accumulatedWeightedInfNorm = [];
tidy(function () {
_this.iteration = scalar(0).variable();
_this.accBeta1 = scalar(beta1).variable();
});
if (epsilon == null) {
_this.epsilon = ENGINE.backend.epsilon();
}
return _this;
}
var _proto = AdamaxOptimizer.prototype;
_proto.applyGradients = function applyGradients(variableGradients) {
var _this2 = this;
var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) {
return item.name;
}) : Object.keys(variableGradients);
tidy(function () {
var oneMinusAccBeta1 = sub(1, _this2.accBeta1);
var lr = div(-_this2.learningRate, add$1(mul(_this2.iteration, _this2.decay), 1));
variableNames.forEach(function (name, i) {
var value = ENGINE.registeredVariables[name];
var trainable = false;
if (_this2.accumulatedFirstMoment[i] == null) {
_this2.accumulatedFirstMoment[i] = {
originalName: name + "/m",
variable: zerosLike(value).variable(trainable)
};
}
if (_this2.accumulatedWeightedInfNorm[i] == null) {
_this2.accumulatedWeightedInfNorm[i] = {
originalName: name + "/v",
variable: zerosLike(value).variable(trainable)
};
}
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
if (gradient == null) {
return;
}
var firstMoment = _this2.accumulatedFirstMoment[i].variable;
var weightedInfNorm = _this2.accumulatedWeightedInfNorm[i].variable;
var newFirstMoment = add$1(mul(firstMoment, _this2.beta1), mul(gradient, 1 - _this2.beta1));
var ut0 = mul(weightedInfNorm, _this2.beta2);
var ut1 = abs$8(gradient);
var newWeightedInfNorm = maximum(ut0, ut1);
firstMoment.assign(newFirstMoment);
weightedInfNorm.assign(newWeightedInfNorm);
var newValue = add$1(mul(div(lr, oneMinusAccBeta1), div(newFirstMoment, add$1(newWeightedInfNorm, _this2.epsilon))), value);
value.assign(newValue);
});
_this2.iteration.assign(add$1(_this2.iteration, 1));
_this2.accBeta1.assign(mul(_this2.accBeta1, _this2.beta1));
});
this.incrementIterations();
};
_proto.dispose = function dispose$1() {
this.accBeta1.dispose();
this.iteration.dispose();
if (this.accumulatedFirstMoment != null) {
dispose(this.accumulatedFirstMoment.map(function (v) {
return v.variable;
}));
}
if (this.accumulatedWeightedInfNorm != null) {
dispose(this.accumulatedWeightedInfNorm.map(function (v) {
return v.variable;
}));
}
};
_proto.getWeights = /*#__PURE__*/function () {
var _getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
throw new Error('getWeights() is not implemented for Adamax yet.');
case 1:
case "end":
return _context.stop();
}
}
}, _callee);
}));
function getWeights() {
return _getWeights.apply(this, arguments);
}
return getWeights;
}();
_proto.setWeights = /*#__PURE__*/function () {
var _setWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(weightValues) {
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
throw new Error('setWeights() is not implemented for Adamax yet.');
case 1:
case "end":
return _context2.stop();
}
}
}, _callee2);
}));
function setWeights(_x) {
return _setWeights.apply(this, arguments);
}
return setWeights;
}();
_proto.getConfig = function getConfig() {
return {
'learningRate': this.learningRate,
'beta1': this.beta1,
'beta2': this.beta2,
'epsilon': this.epsilon,
'decay': this.decay
};
}
/** @nocollapse */
;
AdamaxOptimizer.fromConfig = function fromConfig(cls, config) {
return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon'], config['decay']);
};
return AdamaxOptimizer;
}(Optimizer);
/** @nocollapse */
  AdamaxOptimizer.className = 'Adamax'; // Note: Name matters for Python compatibility.
registerClass(AdamaxOptimizer);
/** @doclink Optimizer */
var SGDOptimizer = /*#__PURE__*/function (_Optimizer) {
_inheritsLoose(SGDOptimizer, _Optimizer);
function SGDOptimizer(learningRate) {
var _this;
_this = _Optimizer.call(this) || this;
_this.learningRate = learningRate;
_this.setLearningRate(learningRate);
return _this;
}
var _proto = SGDOptimizer.prototype;
_proto.applyGradients = function applyGradients(variableGradients) {
var _this2 = this;
var varNames = Array.isArray(variableGradients) ? variableGradients.map(function (v) {
return v.name;
}) : Object.keys(variableGradients);
varNames.forEach(function (name, i) {
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
if (gradient == null) {
return;
}
var value = ENGINE.registeredVariables[name];
tidy(function () {
var newValue = add$1(mul(_this2.c, gradient), value);
value.assign(newValue);
});
});
this.incrementIterations();
}
/**
* Sets the learning rate of the optimizer.
*/
;
_proto.setLearningRate = function setLearningRate(learningRate) {
this.learningRate = learningRate;
if (this.c != null) {
this.c.dispose();
}
this.c = keep(scalar(-learningRate));
};
_proto.dispose = function dispose() {
this.c.dispose();
};
_proto.getWeights = /*#__PURE__*/function () {
var _getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.next = 2;
return this.saveIterations();
case 2:
_context.t0 = _context.sent;
return _context.abrupt("return", [_context.t0]);
case 4:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function getWeights() {
return _getWeights.apply(this, arguments);
}
return getWeights;
}();
_proto.setWeights = /*#__PURE__*/function () {
var _setWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(weightValues) {
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.extractIterations(weightValues);
case 2:
weightValues = _context2.sent;
if (!(weightValues.length !== 0)) {
_context2.next = 5;
break;
}
throw new Error('SGD optimizer does not have settable weights.');
case 5:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function setWeights(_x) {
return _setWeights.apply(this, arguments);
}
return setWeights;
}();
_proto.getConfig = function getConfig() {
return {
'learningRate': this.learningRate
};
}
/** @nocollapse */
;
SGDOptimizer.fromConfig = function fromConfig(cls, config) {
return new cls(config['learningRate']);
};
return SGDOptimizer;
}(Optimizer);
/** @nocollapse */
SGDOptimizer.className = 'SGD'; // Note: Name matters for Python compatibility.
registerClass(SGDOptimizer);
/** @doclink Optimizer */
var MomentumOptimizer = /*#__PURE__*/function (_SGDOptimizer) {
_inheritsLoose(MomentumOptimizer, _SGDOptimizer);
function MomentumOptimizer(learningRate, momentum, useNesterov) {
var _this;
if (useNesterov === void 0) {
useNesterov = false;
}
_this = _SGDOptimizer.call(this, learningRate) || this;
_this.learningRate = learningRate;
_this.momentum = momentum;
_this.useNesterov = useNesterov;
_this.accumulations = [];
_this.m = scalar(_this.momentum);
return _this;
}
var _proto = MomentumOptimizer.prototype;
_proto.applyGradients = function applyGradients(variableGradients) {
var _this2 = this;
var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) {
return item.name;
}) : Object.keys(variableGradients);
variableNames.forEach(function (name, i) {
var value = ENGINE.registeredVariables[name];
if (_this2.accumulations[i] == null) {
var trainable = false;
_this2.accumulations[i] = {
originalName: name + "/momentum",
variable: tidy(function () {
return zerosLike(value).variable(trainable);
})
};
}
var accumulation = _this2.accumulations[i].variable;
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
if (gradient == null) {
return;
}
tidy(function () {
var newValue;
var newAccumulation = add$1(mul(_this2.m, accumulation), gradient);
if (_this2.useNesterov) {
newValue = add$1(mul(_this2.c, add$1(gradient, mul(newAccumulation, _this2.m))), value);
} else {
newValue = add$1(mul(_this2.c, newAccumulation), value);
}
accumulation.assign(newAccumulation);
value.assign(newValue);
});
});
this.incrementIterations();
};
_proto.dispose = function dispose$1() {
this.m.dispose();
if (this.accumulations != null) {
dispose(this.accumulations.map(function (v) {
return v.variable;
}));
}
}
/**
* Sets the momentum of the optimizer.
*
* @param momentum
*/
;
_proto.setMomentum = function setMomentum(momentum) {
this.momentum = momentum;
};
_proto.getWeights = /*#__PURE__*/function () {
var _getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.next = 2;
return this.saveIterations();
case 2:
_context.t0 = _context.sent;
return _context.abrupt("return", [_context.t0].concat(this.accumulations.map(function (v) {
return {
name: v.originalName,
tensor: v.variable
};
})));
case 4:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function getWeights() {
return _getWeights.apply(this, arguments);
}
return getWeights;
}();
_proto.setWeights = /*#__PURE__*/function () {
var _setWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(weightValues) {
var trainable;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.extractIterations(weightValues);
case 2:
weightValues = _context2.sent;
trainable = false;
this.accumulations = weightValues.map(function (v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
case 5:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function setWeights(_x) {
return _setWeights.apply(this, arguments);
}
return setWeights;
}();
_proto.getConfig = function getConfig() {
return {
'learningRate': this.learningRate,
'momentum': this.momentum,
'useNesterov': this.useNesterov
};
}
/** @nocollapse */
;
MomentumOptimizer.fromConfig = function fromConfig(cls, config) {
return new cls(config['learningRate'], config['momentum'], config['useNesterov']);
};
return MomentumOptimizer;
}(SGDOptimizer);
/** @nocollapse */
MomentumOptimizer.className = 'Momentum'; // Name matters for Python compatibility.
registerClass(MomentumOptimizer);
/** @doclink Optimizer */
var RMSPropOptimizer = /*#__PURE__*/function (_Optimizer) {
_inheritsLoose(RMSPropOptimizer, _Optimizer);
function RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered) {
var _this;
if (decay === void 0) {
decay = 0.9;
}
if (momentum === void 0) {
momentum = 0.0;
}
if (epsilon === void 0) {
epsilon = null;
}
if (centered === void 0) {
centered = false;
}
_this = _Optimizer.call(this) || this;
_this.learningRate = learningRate;
_this.decay = decay;
_this.momentum = momentum;
_this.epsilon = epsilon;
_this.accumulatedMeanSquares = [];
_this.accumulatedMoments = [];
_this.accumulatedMeanGrads = [];
_this.centered = centered;
if (epsilon == null) {
_this.epsilon = ENGINE.backend.epsilon();
}
if (learningRate == null) {
throw new Error("learningRate for RMSPropOptimizer must be defined.");
}
return _this;
}
var _proto = RMSPropOptimizer.prototype;
_proto.applyGradients = function applyGradients(variableGradients) {
var _this2 = this;
var variableNames = Array.isArray(variableGradients) ? variableGradients.map(function (item) {
return item.name;
}) : Object.keys(variableGradients);
variableNames.forEach(function (name, i) {
var value = ENGINE.registeredVariables[name];
var trainable = false;
if (_this2.accumulatedMeanSquares[i] == null) {
_this2.accumulatedMeanSquares[i] = {
originalName: name + "/rms",
variable: tidy(function () {
return zerosLike(value).variable(trainable);
})
};
}
if (_this2.accumulatedMoments[i] == null) {
_this2.accumulatedMoments[i] = {
originalName: name + "/momentum",
variable: tidy(function () {
return zerosLike(value).variable(trainable);
})
};
}
if (_this2.accumulatedMeanGrads[i] == null && _this2.centered) {
_this2.accumulatedMeanGrads[i] = {
originalName: name + "/mg",
variable: tidy(function () {
return zerosLike(value).variable(trainable);
})
};
}
var gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name];
if (gradient == null) {
return;
}
var accumulatedMeanSquare = _this2.accumulatedMeanSquares[i].variable;
var accumulatedMoments = _this2.accumulatedMoments[i].variable;
tidy(function () {
var newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, _this2.decay), mul(square(gradient), 1 - _this2.decay));
if (_this2.centered) {
var accumulatedMeanGrad = _this2.accumulatedMeanGrads[i].variable; // Centered gradient
var newAccumulatedMeanGrad = add$1(mul(accumulatedMeanGrad, _this2.decay), mul(gradient, 1 - _this2.decay));
var gradContribution = div(mul(gradient, _this2.learningRate), sqrt$3(sub(newAccumulatedMeanSquare, add$1(square(newAccumulatedMeanGrad), _this2.epsilon))));
var newAccumulatedMoments = add$1(mul(accumulatedMoments, _this2.momentum), gradContribution);
accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
accumulatedMeanGrad.assign(newAccumulatedMeanGrad);
accumulatedMoments.assign(newAccumulatedMoments);
var newValue = sub(value, newAccumulatedMoments);
value.assign(newValue);
} else {
// Plain gradient
var _newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, _this2.decay), mul(square(gradient), 1 - _this2.decay));
var _newAccumulatedMoments = add$1(mul(accumulatedMoments, _this2.momentum), div(mul(gradient, _this2.learningRate), sqrt$3(add$1(_newAccumulatedMeanSquare, _this2.epsilon))));
accumulatedMeanSquare.assign(_newAccumulatedMeanSquare);
accumulatedMoments.assign(_newAccumulatedMoments);
var _newValue = sub(value, _newAccumulatedMoments);
value.assign(_newValue);
}
});
});
this.incrementIterations();
};
_proto.dispose = function dispose$1() {
if (this.accumulatedMeanSquares != null) {
dispose(this.accumulatedMeanSquares.map(function (v) {
return v.variable;
}));
}
if (this.accumulatedMeanGrads != null && this.centered) {
dispose(this.accumulatedMeanGrads.map(function (v) {
return v.variable;
}));
}
if (this.accumulatedMoments != null) {
dispose(this.accumulatedMoments.map(function (v) {
return v.variable;
}));
}
};
_proto.getWeights = /*#__PURE__*/function () {
var _getWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var variables;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
// Order matters for Python compatibility.
variables = [].concat(this.accumulatedMeanSquares, this.accumulatedMoments);
if (this.centered) {
variables.push.apply(variables, this.accumulatedMeanGrads);
}
_context.next = 4;
return this.saveIterations();
case 4:
_context.t0 = _context.sent;
return _context.abrupt("return", [_context.t0].concat(variables.map(function (v) {
return {
name: v.originalName,
tensor: v.variable
};
})));
case 6:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function getWeights() {
return _getWeights.apply(this, arguments);
}
return getWeights;
}();
_proto.setWeights = /*#__PURE__*/function () {
var _setWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(weightValues) {
var variableCount, trainable;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.extractIterations(weightValues);
case 2:
weightValues = _context2.sent;
variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2;
trainable = false;
this.accumulatedMeanSquares = weightValues.slice(0, variableCount).map(function (v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
this.accumulatedMoments = weightValues.slice(variableCount, variableCount * 2).map(function (v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
if (this.centered) {
this.accumulatedMeanGrads = weightValues.slice(variableCount * 2, variableCount * 3).map(function (v) {
return {
originalName: v.name,
variable: v.tensor.variable(trainable)
};
});
}
case 8:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function setWeights(_x) {
return _setWeights.apply(this, arguments);
}
return setWeights;
}();
_proto.getConfig = function getConfig() {
return {
'learningRate': this.learningRate,
'decay': this.decay,
'momentum': this.momentum,
'epsilon': this.epsilon,
'centered': this.centered
};
}
/** @nocollapse */
;
RMSPropOptimizer.fromConfig = function fromConfig(cls, config) {
return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']);
};
return RMSPropOptimizer;
}(Optimizer);
/** @nocollapse */
RMSPropOptimizer.className = 'RMSProp'; // Note: Name matters for Python compatibility.
registerClass(RMSPropOptimizer);
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var OptimizerConstructors = /*#__PURE__*/function () {
function OptimizerConstructors() {}
/**
* Constructs a `tf.SGDOptimizer` that uses stochastic gradient descent.
*
* ```js
* // Fit a quadratic function by learning the coefficients a, b, c.
* const xs = tf.tensor1d([0, 1, 2, 3]);
* const ys = tf.tensor1d([1.1, 5.9, 16.8, 33.9]);
*
* const a = tf.scalar(Math.random()).variable();
* const b = tf.scalar(Math.random()).variable();
* const c = tf.scalar(Math.random()).variable();
*
* // y = a * x^2 + b * x + c.
* const f = x => a.mul(x.square()).add(b.mul(x)).add(c);
* const loss = (pred, label) => pred.sub(label).square().mean();
*
* const learningRate = 0.01;
* const optimizer = tf.train.sgd(learningRate);
*
* // Train the model.
* for (let i = 0; i < 10; i++) {
* optimizer.minimize(() => loss(f(xs), ys));
* }
*
* // Make predictions.
* console.log(
* `a: ${a.dataSync()}, b: ${b.dataSync()}, c: ${c.dataSync()}`);
* const preds = f(xs).dataSync();
* preds.forEach((pred, i) => {
* console.log(`x: ${i}, pred: ${pred}`);
* });
* ```
*
* @param learningRate The learning rate to use for the SGD algorithm.
*
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
*/
OptimizerConstructors.sgd = function sgd(learningRate) {
return new SGDOptimizer(learningRate);
}
/**
* Constructs a `tf.MomentumOptimizer` that uses momentum gradient
* descent.
*
* See
* [http://proceedings.mlr.press/v28/sutskever13.pdf](
* http://proceedings.mlr.press/v28/sutskever13.pdf)
*
* @param learningRate The learning rate to use for the Momentum gradient
* descent algorithm.
* @param momentum The momentum to use for the momentum gradient descent
* algorithm.
*
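     * A minimal example (illustrative):
     *
     * ```js
     * // Drive a scalar variable toward zero with momentum gradient descent.
     * const x = tf.scalar(4).variable();
     * const optimizer = tf.train.momentum(0.1, 0.9);
     * for (let i = 0; i < 10; i++) {
     *   optimizer.minimize(() => x.square());
     * }
     * x.print();
     * ```
     *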
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
*/
;
OptimizerConstructors.momentum = function momentum(learningRate, _momentum, useNesterov) {
if (useNesterov === void 0) {
useNesterov = false;
}
return new MomentumOptimizer(learningRate, _momentum, useNesterov);
}
/**
     * Constructs a `tf.RMSPropOptimizer` that uses RMSProp gradient
     * descent. This implementation uses plain momentum and, unless `centered`
     * is set to true, is not the centered version of RMSProp.
*
* See
* [http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf](
* http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
*
* @param learningRate The learning rate to use for the RMSProp gradient
* descent algorithm.
* @param decay The discounting factor for the history/coming gradient.
* @param momentum The momentum to use for the RMSProp gradient descent
* algorithm.
* @param epsilon Small value to avoid zero denominator.
* @param centered If true, gradients are normalized by the estimated
* variance of the gradient.
*
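     * A minimal example (illustrative):
     *
     * ```js
     * // Fit a single weight w so that w * 3 is close to 6.
     * const w = tf.scalar(0).variable();
     * const optimizer = tf.train.rmsprop(0.1);
     * for (let i = 0; i < 20; i++) {
     *   optimizer.minimize(() => w.mul(3).sub(6).square());
     * }
     * w.print();
     * ```
     *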
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
*/
;
OptimizerConstructors.rmsprop = function rmsprop(learningRate, decay, momentum, epsilon, centered) {
if (decay === void 0) {
decay = .9;
}
if (momentum === void 0) {
momentum = 0.0;
}
if (epsilon === void 0) {
epsilon = null;
}
if (centered === void 0) {
centered = false;
}
return new RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered);
}
/**
* Constructs a `tf.AdamOptimizer` that uses the Adam algorithm.
* See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
*
* @param learningRate The learning rate to use for the Adam gradient
* descent algorithm.
* @param beta1 The exponential decay rate for the 1st moment estimates.
* @param beta2 The exponential decay rate for the 2nd moment estimates.
* @param epsilon A small constant for numerical stability.
*
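     * A minimal example (illustrative):
     *
     * ```js
     * // Minimize the mean squared error between a 1D variable and a fixed target.
     * const target = tf.tensor1d([1, 2, 3]);
     * const w = tf.tensor1d([0, 0, 0]).variable();
     * const optimizer = tf.train.adam(0.1);
     * for (let i = 0; i < 50; i++) {
     *   optimizer.minimize(() => w.sub(target).square().mean());
     * }
     * w.print();
     * ```
     *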
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
*/
;
OptimizerConstructors.adam = function adam(learningRate, beta1, beta2, epsilon) {
if (learningRate === void 0) {
learningRate = 0.001;
}
if (beta1 === void 0) {
beta1 = 0.9;
}
if (beta2 === void 0) {
beta2 = 0.999;
}
if (epsilon === void 0) {
epsilon = null;
}
return new AdamOptimizer(learningRate, beta1, beta2, epsilon);
}
/**
* Constructs a `tf.AdadeltaOptimizer` that uses the Adadelta algorithm.
* See [https://arxiv.org/abs/1212.5701](https://arxiv.org/abs/1212.5701)
*
* @param learningRate The learning rate to use for the Adadelta gradient
* descent algorithm.
* @param rho The learning rate decay over each update.
* @param epsilon A constant epsilon used to better condition the grad
* update.
*
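     * A minimal example (illustrative):
     *
     * ```js
     * // Minimize (x - 5)^2 with Adadelta.
     * const x = tf.scalar(0).variable();
     * const optimizer = tf.train.adadelta(1.0, 0.95);
     * for (let i = 0; i < 100; i++) {
     *   optimizer.minimize(() => x.sub(5).square());
     * }
     * x.print();
     * ```
     *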
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
*/
;
OptimizerConstructors.adadelta = function adadelta(learningRate, rho, epsilon) {
if (learningRate === void 0) {
learningRate = .001;
}
if (rho === void 0) {
rho = .95;
}
if (epsilon === void 0) {
epsilon = null;
}
return new AdadeltaOptimizer(learningRate, rho, epsilon);
}
/**
* Constructs a `tf.AdamaxOptimizer` that uses the Adamax algorithm.
* See [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980)
*
* @param learningRate The learning rate to use for the Adamax gradient
* descent algorithm.
* @param beta1 The exponential decay rate for the 1st moment estimates.
* @param beta2 The exponential decay rate for the 2nd moment estimates.
* @param epsilon A small constant for numerical stability.
* @param decay The learning rate decay over each update.
*
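     * A minimal example (illustrative):
     *
     * ```js
     * // Minimize (x + 1)^2 with Adamax; x moves toward -1.
     * const x = tf.scalar(2).variable();
     * const optimizer = tf.train.adamax(0.1);
     * for (let i = 0; i < 30; i++) {
     *   optimizer.minimize(() => x.add(1).square());
     * }
     * x.print();
     * ```
     *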
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
*/
;
OptimizerConstructors.adamax = function adamax(learningRate, beta1, beta2, epsilon, decay) {
if (learningRate === void 0) {
learningRate = 0.002;
}
if (beta1 === void 0) {
beta1 = 0.9;
}
if (beta2 === void 0) {
beta2 = 0.999;
}
if (epsilon === void 0) {
epsilon = null;
}
if (decay === void 0) {
decay = 0.0;
}
return new AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay);
}
/**
* Constructs a `tf.AdagradOptimizer` that uses the Adagrad algorithm.
* See
* [http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf](
* http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
* or
* [http://ruder.io/optimizing-gradient-descent/index.html#adagrad](
* http://ruder.io/optimizing-gradient-descent/index.html#adagrad)
*
* @param learningRate The learning rate to use for the Adagrad gradient
* descent algorithm.
* @param initialAccumulatorValue Starting value for the accumulators, must be
* positive.
*
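     * A minimal example (illustrative):
     *
     * ```js
     * // Minimize the squared L2 norm of a small variable with Adagrad.
     * const v = tf.tensor1d([3, -2]).variable();
     * const optimizer = tf.train.adagrad(0.5);
     * for (let i = 0; i < 40; i++) {
     *   optimizer.minimize(() => v.square().sum());
     * }
     * v.print();
     * ```
     *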
* @doc {heading: 'Training', subheading: 'Optimizers', namespace: 'train'}
*/
;
OptimizerConstructors.adagrad = function adagrad(learningRate, initialAccumulatorValue) {
if (initialAccumulatorValue === void 0) {
initialAccumulatorValue = 0.1;
}
return new AdagradOptimizer(learningRate, initialAccumulatorValue);
};
return OptimizerConstructors;
}();
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
[MomentumOptimizer, SGDOptimizer, AdadeltaOptimizer, AdagradOptimizer, RMSPropOptimizer, AdamaxOptimizer, AdamOptimizer];
var train = {
sgd: OptimizerConstructors.sgd,
momentum: OptimizerConstructors.momentum,
adadelta: OptimizerConstructors.adadelta,
adagrad: OptimizerConstructors.adagrad,
rmsprop: OptimizerConstructors.rmsprop,
adamax: OptimizerConstructors.adamax,
adam: OptimizerConstructors.adam
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var delayCallback = function () {
if (typeof requestAnimationFrame !== 'undefined') {
return requestAnimationFrame;
} else if (typeof setImmediate !== 'undefined') {
return setImmediate;
}
return function (f) {
return f();
}; // no delays
}();
/**
   * Returns a promise that resolves when a requestAnimationFrame has completed.
*
* On Node.js this uses setImmediate instead of requestAnimationFrame.
*
* This is simply a sugar method so that users can do the following:
* `await tf.nextFrame();`
*
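   * ```js
   * // A minimal sketch: yield to the event loop between iterations of a long
   * // loop so rendering and input handling are not starved.
   * for (let i = 0; i < 3; i++) {
   *   tf.tidy(() => tf.randomNormal([200, 200]).square().sum().print());
   *   await tf.nextFrame();
   * }
   * ```
   *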
* @doc {heading: 'Performance', subheading: 'Timing'}
*/
function nextFrame() {
return new Promise(function (resolve) {
return delayCallback(function () {
return resolve();
});
});
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function assertParamsConsistent(shapes, axis) {
var rank = shapes[0].length;
shapes.forEach(function (shape, i) {
assert(shape.length === rank, function () {
return "Error in concat" + rank + "D: rank of tensors[" + i + "] must be the same " + ("as the rank of the rest (" + rank + ")");
});
});
assert(axis >= 0 && axis < rank, function () {
return "Error in concat" + rank + "D: axis must be between 0 and " + (rank - 1) + ".";
});
var firstShape = shapes[0];
shapes.forEach(function (shape, i) {
for (var r = 0; r < rank; r++) {
assert(r === axis || shape[r] === firstShape[r], function () {
return "Error in concat" + rank + "D: Shape of tensors[" + i + "] (" + shape + ") " + ("does not match the shape of the rest (" + firstShape + ") ") + ("along the non-concatenated axis " + i + ".");
});
}
});
}
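  // Computes the shape of the result of concatenating `shapes` along `axis`:
  // the sizes along `axis` are summed and all other dimensions are taken from
  // the first shape. For example (illustrative),
  // computeOutShape$1([[2, 3], [2, 4]], 1) returns [2, 7].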
function computeOutShape$1(shapes, axis) {
var outputShape = shapes[0].slice();
for (var i = 1; i < shapes.length; i++) {
outputShape[axis] += shapes[i][axis];
}
return outputShape;
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PARALLELIZE_THRESHOLD = 30;
function computeOptimalWindowSize(inSize) {
if (inSize <= PARALLELIZE_THRESHOLD) {
return inSize;
}
return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Returns the image center in pixels.
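  // For example (illustrative): getImageCenter(0.5, 200, 100) returns [50, 100],
  // since centerX = imageWidth * 0.5 = 50 and centerY = imageHeight * 0.5 = 100.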
function getImageCenter(center, imageHeight, imageWidth) {
var centerX = imageWidth * (typeof center === 'number' ? center : center[0]);
var centerY = imageHeight * (typeof center === 'number' ? center : center[1]);
return [centerX, centerY];
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Gets the new shape of the input Tensor after it's been reshaped
* to:
* [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape),
* inputShape[1], ..., inputShape[N-1]]
*
* See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
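   *
   * For example (illustrative): with inputShape = [4, 2, 2, 1],
   * blockShape = [2, 2] and prod = 4, the reshaped shape is [2, 2, 1, 2, 2, 1].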
*/
function getReshaped(inputShape, blockShape, prod, batchToSpace) {
if (batchToSpace === void 0) {
batchToSpace = true;
}
var reshaped = [];
if (batchToSpace) {
reshaped = reshaped.concat(blockShape.slice(0));
reshaped.push(inputShape[0] / prod);
reshaped = reshaped.concat(inputShape.slice(1));
} else {
reshaped = reshaped.concat(inputShape[0]);
var spatialLength = blockShape.length;
for (var i = 0; i < spatialLength; ++i) {
reshaped = reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);
}
reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));
}
return reshaped;
}
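/*
 * Worked example (illustrative values): for a batch-to-space call with
 * inputShape = [4, 2, 2, 1], blockShape = [2, 2] and prod = 4,
 *   getReshaped([4, 2, 2, 1], [2, 2], 4)
 * returns [2, 2, 1, 2, 2, 1], i.e.
 * [blockShape[0], blockShape[1], batch / prod, height, width, channels].
 */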
/**
* Gets the permutation that will transpose the dimensions of the
* reshaped tensor to shape:
*
 * [batch / prod(blockShape), inputShape[1], blockShape[0], ...,
 * inputShape[M], blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
*
* see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
*/
function getPermuted(reshapedRank, blockShapeRank, batchToSpace) {
if (batchToSpace === void 0) {
batchToSpace = true;
}
var permuted = [];
if (batchToSpace) {
permuted.push(blockShapeRank);
for (var i = blockShapeRank + 1; i < reshapedRank; ++i) {
if (i <= 2 * blockShapeRank) {
permuted.push(i);
permuted.push(i - (blockShapeRank + 1));
} else {
permuted.push(i);
}
}
} else {
var permutedBeforeBatch = [];
var permutedAfterBatch = [];
for (var _i = 1; _i < reshapedRank; ++_i) {
if (_i >= blockShapeRank * 2 + 1 || _i % 2 === 1) {
permutedAfterBatch.push(_i);
} else {
permutedBeforeBatch.push(_i);
}
}
permuted.push.apply(permuted, permutedBeforeBatch);
permuted.push(0);
permuted.push.apply(permuted, permutedAfterBatch);
}
return permuted;
}
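/*
 * Worked example (illustrative values): for a rank-6 reshaped tensor and a
 * rank-2 block shape,
 *   getPermuted(6, 2)
 * returns [2, 3, 0, 4, 1, 5], which moves the batch / prod dimension to the
 * front and interleaves each spatial dimension with its block dimension.
 */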
/**
* Gets the shape of the reshaped and permuted input Tensor before any cropping
* is applied. The new shape will be:
*
 * [batch / prod(blockShape), inputShape[1] * blockShape[0], ...,
 * inputShape[M] * blockShape[M-1], inputShape[M+1], ..., inputShape[N-1]]
*
* See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
*/
function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace) {
if (batchToSpace === void 0) {
batchToSpace = true;
}
var reshapedPermuted = [];
if (batchToSpace) {
reshapedPermuted.push(inputShape[0] / prod);
} else {
reshapedPermuted.push(inputShape[0] * prod);
}
for (var i = 1; i < inputShape.length; ++i) {
if (i <= blockShape.length) {
if (batchToSpace) {
reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
} else {
reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
}
} else {
reshapedPermuted.push(inputShape[i]);
}
}
return reshapedPermuted;
}
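/*
 * Worked example (illustrative values): continuing the example above,
 *   getReshapedPermuted([4, 2, 2, 1], [2, 2], 4)
 * returns [1, 4, 4, 1]: the batch of 4 is folded into the two spatial
 * dimensions (2 * 2 each), and the channel dimension is left untouched.
 */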
/**
* Converts the crops argument into the beginning coordinates of a slice
* operation.
*/
function getSliceBeginCoords(crops, blockShape) {
var sliceBeginCoords = [0];
for (var i = 0; i < blockShape; ++i) {
sliceBeginCoords.push(crops[i][0]);
}
return sliceBeginCoords;
}
/**
* Converts the crops argument into the size of a slice operation. When
* combined with getSliceBeginCoords this function allows the reshaped and
* permuted Tensor to be cropped to its final output shape of:
*
 * [batch / prod(blockShape), inputShape[1] * blockShape[0] - crops[0,0] -
 * crops[0,1], ..., inputShape[M] * blockShape[M-1] - crops[M-1,0] -
 * crops[M-1,1], inputShape[M+1], ..., inputShape[N-1]]
*
* See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd
*/
function getSliceSize(uncroppedShape, crops, blockShape) {
var sliceSize = uncroppedShape.slice(0, 1);
for (var i = 0; i < blockShape; ++i) {
sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
}
return sliceSize;
}
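/*
 * Worked example (illustrative values): with crops = [[1, 1], [0, 0]], a
 * rank-2 block shape and an uncropped shape of [1, 6, 4, 1],
 *   getSliceBeginCoords([[1, 1], [0, 0]], 2)        // => [0, 1, 0]
 *   getSliceSize([1, 6, 4, 1], [[1, 1], [0, 0]], 2) // => [1, 4, 4]
 * i.e. one element is cropped from each end of the first spatial dimension.
 */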
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SELU_SCALEALPHA = 1.7580993408473768599402175208123;
var SELU_SCALE = 1.0507009873554804934193349852946;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ERF_P = 0.3275911;
var ERF_A1 = 0.254829592;
var ERF_A2 = -0.284496736;
var ERF_A3 = 1.421413741;
var ERF_A4 = -1.453152027;
var ERF_A5 = 1.061405429;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Merges real and imaginary Float32Arrays into a single complex Float32Array.
*
* The memory layout is interleaved as follows:
* real: [r0, r1, r2]
* imag: [i0, i1, i2]
* complex: [r0, i0, r1, i1, r2, i2]
*
* This is the inverse of splitRealAndImagArrays.
*
* @param real The real values of the complex tensor values.
* @param imag The imag values of the complex tensor values.
* @returns A complex tensor as a Float32Array with merged values.
*/
function mergeRealAndImagArrays(real, imag) {
if (real.length !== imag.length) {
throw new Error("Cannot merge real and imag arrays of different lengths. real:" + (real.length + ", imag: " + imag.length + "."));
}
var result = new Float32Array(real.length * 2);
for (var i = 0; i < result.length; i += 2) {
result[i] = real[i / 2];
result[i + 1] = imag[i / 2];
}
return result;
}
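/*
 * Worked example (illustrative values):
 *   mergeRealAndImagArrays(new Float32Array([1, 2, 3]),
 *                          new Float32Array([4, 5, 6]))
 *   // => Float32Array [1, 4, 2, 5, 3, 6]
 */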
/**
* Splits a complex Float32Array into real and imag parts.
*
* The memory layout is interleaved as follows:
* complex: [r0, i0, r1, i1, r2, i2]
* real: [r0, r1, r2]
* imag: [i0, i1, i2]
*
* This is the inverse of mergeRealAndImagArrays.
*
* @param complex The complex tensor values.
* @returns An object with real and imag Float32Array components of the complex
* tensor.
*/
function splitRealAndImagArrays(complex) {
var real = new Float32Array(complex.length / 2);
var imag = new Float32Array(complex.length / 2);
for (var i = 0; i < complex.length; i += 2) {
real[i / 2] = complex[i];
imag[i / 2] = complex[i + 1];
}
return {
real: real,
imag: imag
};
}
/**
* Extracts even indexed complex values in the given array.
* @param complex The complex tensor values
*/
function complexWithEvenIndex(complex) {
var len = Math.ceil(complex.length / 4);
var real = new Float32Array(len);
var imag = new Float32Array(len);
for (var i = 0; i < complex.length; i += 4) {
real[Math.floor(i / 4)] = complex[i];
imag[Math.floor(i / 4)] = complex[i + 1];
}
return {
real: real,
imag: imag
};
}
/**
 * Extracts odd indexed complex values in the given array.
* @param complex The complex tensor values
*/
function complexWithOddIndex(complex) {
var len = Math.floor(complex.length / 4);
var real = new Float32Array(len);
var imag = new Float32Array(len);
for (var i = 2; i < complex.length; i += 4) {
real[Math.floor(i / 4)] = complex[i];
imag[Math.floor(i / 4)] = complex[i + 1];
}
return {
real: real,
imag: imag
};
}
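/*
 * Worked example (illustrative values): for the interleaved array
 * [r0, i0, r1, i1, r2, i2, r3, i3],
 *   complexWithEvenIndex(...) // => {real: [r0, r2], imag: [i0, i2]}
 *   complexWithOddIndex(...)  // => {real: [r1, r3], imag: [i1, i3]}
 * This is the even/odd decomposition used by a radix-2 FFT.
 */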
/**
* Get the map representing a complex value in the given array.
* @param complex The complex tensor values.
* @param index An index of the target complex value.
*/
function getComplexWithIndex(complex, index) {
var real = complex[index * 2];
var imag = complex[index * 2 + 1];
return {
real: real,
imag: imag
};
}
/**
* Insert a given complex value into the TypedArray.
* @param data The array in which the complex value is inserted.
* @param c The complex value to be inserted.
* @param index An index of the target complex value.
*/
function assignToTypedArray(data, real, imag, index) {
data[index * 2] = real;
data[index * 2 + 1] = imag;
}
/**
* Make the list of exponent terms used by FFT.
*/
function exponents(n, inverse) {
var real = new Float32Array(n / 2);
var imag = new Float32Array(n / 2);
for (var i = 0; i < Math.ceil(n / 2); i++) {
var x = (inverse ? 2 : -2) * Math.PI * (i / n);
real[i] = Math.cos(x);
imag[i] = Math.sin(x);
}
return {
real: real,
imag: imag
};
}
/**
* Make the exponent term used by FFT.
*/
function exponent(k, n, inverse) {
var x = (inverse ? 2 : -2) * Math.PI * (k / n);
var real = Math.cos(x);
var imag = Math.sin(x);
return {
real: real,
imag: imag
};
}
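/*
 * In math terms, exponent(k, n, inverse) computes the FFT twiddle factor
 *   exp(-2 * pi * i * k / n)   for the forward transform, or
 *   exp(+2 * pi * i * k / n)   for the inverse transform,
 * returned as {real: cos(x), imag: sin(x)} with x = (+/-)2 * pi * k / n.
 * exponents(n, inverse) precomputes the first n / 2 of these factors.
 */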
var ARROW = '->';
var ARROW_REGEX = /->/g;
var COMMA = ',';
var ELLIPSIS = '...';
/**
* Parse an equation for einsum.
*
* @param equation The einsum equation (e.g., "ij,jk->ik").
* @param numTensors Number of tensors provided along with `equation`. Used to
* check matching number of input tensors.
* @returns An object consisting of the following fields:
* - allDims: all dimension names as strings.
* - summedDims: a list of all dimensions being summed over, as indices to
* the elements of `allDims`.
* - idDims: indices of the dimensions in each input tensor, as indices to
 *   the elements of `allDims`.
*/
function decodeEinsumEquation(equation, numTensors) {
  equation = equation.replace(/\s/g, ''); // Remove whitespace in equation.
var numArrows = (equation.length - equation.replace(ARROW_REGEX, '').length) / ARROW.length;
if (numArrows < 1) {
throw new Error('Equations without an arrow are not supported.');
} else if (numArrows > 1) {
throw new Error("Equation must contain exactly one arrow (\"" + ARROW + "\").");
}
var _equation$split = equation.split(ARROW),
inputString = _equation$split[0],
outputString = _equation$split[1];
assert(inputString.indexOf(ELLIPSIS) === -1, function () {
return "The ellipsis notation (\"" + ELLIPSIS + "\") is not supported yet.";
});
var inputTerms = inputString.split(COMMA);
var numInputs = inputTerms.length;
if (numTensors !== numInputs) {
throw new Error("Expected " + numInputs + " input tensors, received " + numTensors);
}
if (numInputs > 2) {
throw new Error('Support for more than 2 input tensors is not implemented yet.');
}
var allDims = [];
var _loop = function _loop(i) {
var dimName = outputString[i];
if (!inputTerms.some(function (inputTerm) {
return inputTerm.indexOf(dimName) !== -1;
})) {
throw new Error("Output subscripts contain the label " + dimName + " " + "not present in the input subscripts.");
}
if (allDims.indexOf(dimName) === -1) {
allDims.push(dimName);
}
};
for (var i = 0; i < outputString.length; ++i) {
_loop(i);
}
for (var _i = 0; _i < inputString.length; ++_i) {
var dimName = inputString[_i];
if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) {
allDims.push(dimName);
}
}
var idDims = new Array(inputTerms.length);
for (var _i2 = 0; _i2 < numInputs; ++_i2) {
if (new Set(inputTerms[_i2].split('')).size !== inputTerms[_i2].length) {
throw new Error("Found duplicate axes in input component " + inputTerms[_i2] + ". " + "Support for duplicate axes in input is not implemented yet.");
}
idDims[_i2] = [];
for (var j = 0; j < inputTerms[_i2].length; ++j) {
idDims[_i2].push(allDims.indexOf(inputTerms[_i2][j]));
}
}
var numDims = allDims.length; // Number of unique dimensions.
var numOutDims = outputString.length; // Number of output dimensions.
var summedDims = []; // Dimensions being summed over.
for (var _i3 = numOutDims; _i3 < numDims; ++_i3) {
summedDims.push(_i3);
}
return {
allDims: allDims,
summedDims: summedDims,
idDims: idDims
};
}
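/*
 * Worked example (illustrative values): for a plain matrix multiplication,
 *   decodeEinsumEquation('ij,jk->ik', 2)
 * returns
 *   allDims    = ['i', 'k', 'j']   // output dims first, then summed dims
 *   summedDims = [2]               // 'j' is summed over
 *   idDims     = [[0, 2], [2, 1]]  // dims of 'ij' and 'jk' as indices into allDims
 */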
/**
* Get the permutation for a given input tensor.
*
* @param nDims Total number of dimension of all tensors involved in the einsum
* operation.
* @param idDims Dimension indices involve in the tensor in question.
* @returns An object consisting of the following fields:
* - permutationIndices: Indices to permute the axes of the tensor with.
* - expandDims: Indices to the dimension that need to be expanded from the
* tensor after permutation.
*/
function getEinsumPermutation(nDims, idDims) {
var permutationIndices = new Array(nDims);
permutationIndices.fill(-1);
for (var i = 0; i < idDims.length; ++i) {
permutationIndices[idDims[i]] = i;
}
var expandDims = [];
for (var _i4 = 0; _i4 < nDims; ++_i4) {
if (permutationIndices[_i4] === -1) {
expandDims.push(_i4);
}
}
permutationIndices = permutationIndices.filter(function (d) {
return d !== -1;
});
return {
permutationIndices: permutationIndices,
expandDims: expandDims
};
}
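/*
 * Worked example (illustrative values): continuing the 'ij,jk->ik' example,
 * the first operand has idDims [0, 2] out of nDims = 3, so
 *   getEinsumPermutation(3, [0, 2])
 * returns {permutationIndices: [0, 1], expandDims: [1]}: the operand keeps its
 * axis order and a singleton axis is inserted at position 1 (the 'k' dim).
 */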
/**
* Checks that the dimension sizes from different input tensors match the
* equation.
*/
function checkEinsumDimSizes(nDims, idDims, tensors) {
var dimSizes = new Array(nDims);
var _loop2 = function _loop2(i) {
var shape = tensors[i].shape;
var _loop3 = function _loop3(j) {
if (dimSizes[idDims[i][j]] === undefined) {
dimSizes[idDims[i][j]] = shape[j];
} else {
assert(dimSizes[idDims[i][j]] === shape[j], function () {
return "Expected dimension " + dimSizes[idDims[i][j]] + " at axis " + j + " " + ("of input shaped " + JSON.stringify(shape) + ", ") + ("but got dimension " + shape[j]);
});
}
};
for (var j = 0; j < idDims[i].length; ++j) {
_loop3(j);
}
};
for (var i = 0; i < tensors.length; ++i) {
_loop2(i);
}
}
/**
* Gets path of computation for einsum.
*
* @param summedDims indices to the dimensions being summed over.
* @param idDims A look up table for the dimensions present in each input
 * tensor. Each constituent array contains indices for the dimensions in the
* corresponding input tensor.
*
* @return A map with two fields:
* - path: The path of computation, with each element indicating the dimension
* being summed over after the element-wise multiplication in that step.
* - steps: With the same length as `path`. Each element contains the indices
* to the input tensors being used for element-wise multiplication in the
* corresponding step.
*/
function getEinsumComputePath(summedDims, idDims) {
var path = summedDims;
var steps = [];
var nSteps = 0;
if (summedDims.length === 0) {
    // Einsum that involves no summing: e.g., transpose and outer product.
path.push(-1);
}
nSteps = summedDims.length + 1;
for (var i = 0; i < nSteps; ++i) {
steps.push([]);
}
var computedTermIndices = [];
for (var _i5 = 0; _i5 < path.length; ++_i5) {
var summedDim = path[_i5];
var termIndices = findTermsWithDim(idDims, summedDim);
for (var _iterator = _createForOfIteratorHelperLoose(termIndices), _step; !(_step = _iterator()).done;) {
var termIndex = _step.value;
if (computedTermIndices.indexOf(termIndex) === -1) {
steps[_i5].push(termIndex);
computedTermIndices.push(termIndex);
}
}
}
return {
path: path,
steps: steps
};
}
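/*
 * Worked example (illustrative values): for 'ij,jk->ik', summedDims = [2] and
 * idDims = [[0, 2], [2, 1]], so
 *   getEinsumComputePath([2], [[0, 2], [2, 1]])
 * returns {path: [2], steps: [[0, 1], []]}: both operands enter the
 * element-wise multiplication in the first step, after which dimension 2
 * (the 'j' dim) is summed over.
 */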
/** Determines if an axes permutation is the identity permutation. */
function isIdentityPermutation(perm) {
return perm.every(function (dim, index) {
return dim === index;
});
}
function findTermsWithDim(idDims, dim) {
var termIndices = [];
for (var i = 0; i < idDims.length; ++i) {
if (idDims[i].length === 0 || idDims[i].indexOf(dim) !== -1 || dim === -1) {
termIndices.push(i);
}
}
return termIndices;
}
/**
 * Prepares the split size array. When the input is a number, the axis is
 * divided evenly into that many splits. When the input array contains a
 * negative value (-1), the remaining size of the axis is allocated to that
 * split.
*/
function prepareSplitSize(x, numOrSizeSplits, axis) {
if (axis === void 0) {
axis = 0;
}
var splitSizes = [];
if (typeof numOrSizeSplits === 'number') {
assert(x.shape[axis] % numOrSizeSplits === 0, function () {
return 'Number of splits must evenly divide the axis.';
});
splitSizes = new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits);
} else {
var numOfNegs = numOrSizeSplits.reduce(function (count, value) {
if (value === -1) {
count += 1;
}
return count;
}, 0);
assert(numOfNegs <= 1, function () {
return 'There should be only one negative value in split array.';
});
var negIndex = numOrSizeSplits.indexOf(-1); // Allow the number of split array to be -1, which indicates the rest
// of dimension is allocated to that split.
if (negIndex !== -1) {
var total = numOrSizeSplits.reduce(function (a, b) {
return b > 0 ? a + b : a;
});
numOrSizeSplits[negIndex] = x.shape[axis] - total;
}
assert(x.shape[axis] === numOrSizeSplits.reduce(function (a, b) {
return a + b;
}), function () {
return 'The sum of sizes must match the size of the axis dimension.';
});
splitSizes = numOrSizeSplits;
}
return splitSizes;
}
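/*
 * Worked example (illustrative values): for a tensor x with x.shape = [10],
 *   prepareSplitSize(x, 5)           // => [2, 2, 2, 2, 2]
 *   prepareSplitSize(x, [4, -1, 2])  // => [4, 4, 2] (the -1 absorbs the rest)
 */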
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function segOpComputeOptimalWindowSize(inSize, numSegments) {
var done = false;
var res;
if (inSize <= PARALLELIZE_THRESHOLD) {
res = inSize;
done = true;
} else {
res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
}
while (!done) {
if (res > numSegments || res === inSize) {
done = true;
} else {
res = nearestDivisor(inSize, res + 1);
}
}
return res;
}
function computeOutShape$2(aShape, axis, numSegments) {
var outShape = [];
var rank = aShape.length;
for (var dim = 0; dim < rank; dim++) {
if (dim !== axis) {
outShape.push(aShape[dim]);
} else {
outShape.push(numSegments);
}
}
return outShape;
}
function collectGatherOpShapeInfo(x, indices, axis, batchDims) {
var indicesRank = indices.shape.length;
var xRank = x.shape.length;
if (batchDims !== 0) {
if (batchDims < -indicesRank || batchDims > indicesRank) {
throw new Error("Expect batchDims in the range of [-" + indicesRank + ", " + indicesRank + "], but got " + batchDims);
}
}
if (batchDims < 0) {
batchDims += indicesRank;
}
if (batchDims > xRank) {
throw new Error("batchDims (" + batchDims + ") must be less than rank(x) (\n " + xRank + ").");
}
if (axis < batchDims) {
throw new Error("batchDims (" + batchDims + ") must be less than or equal to axis (" + axis + ").");
}
for (var i = 0; i < batchDims; ++i) {
if (x.shape[i] !== indices.shape[i]) {
throw new Error("x.shape[" + i + "]: " + x.shape[i] + " should be equal to indices.shape[" + i + "]: " + indices.shape[i] + ".");
}
}
var dimSize = x.shape[axis];
var outputShape = [];
var batchSize = 1;
var outerSize = 1;
var sliceSize = 1;
for (var _i = 0; _i < batchDims; ++_i) {
outputShape.push(x.shape[_i]);
batchSize *= x.shape[_i];
}
for (var _i2 = batchDims; _i2 < axis; _i2++) {
outputShape.push(x.shape[_i2]);
outerSize *= x.shape[_i2];
}
for (var _i3 = batchDims; _i3 < indicesRank; _i3++) {
outputShape.push(indices.shape[_i3]);
}
for (var _i4 = axis + 1; _i4 < xRank; _i4++) {
outputShape.push(x.shape[_i4]);
sliceSize *= x.shape[_i4];
}
return {
batchSize: batchSize,
sliceSize: sliceSize,
outerSize: outerSize,
dimSize: dimSize,
outputShape: outputShape
};
}
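/*
 * Worked example (illustrative values): gathering along axis 1 with one batch
 * dimension, where x.shape = [2, 3, 4] and indices.shape = [2, 5],
 *   collectGatherOpShapeInfo(x, indices, 1, 1)
 *   // => {batchSize: 2, sliceSize: 4, outerSize: 1, dimSize: 3,
 *   //     outputShape: [2, 5, 4]}
 */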
var segment_util = {
__proto__: null,
segOpComputeOptimalWindowSize: segOpComputeOptimalWindowSize,
computeOutShape: computeOutShape$2,
collectGatherOpShapeInfo: collectGatherOpShapeInfo
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fromUint8ToStringArray(vals) {
try {
// Decode the bytes into string.
return vals.map(function (val) {
return decodeString(val);
});
} catch (err) {
throw new Error("Failed to decode encoded string bytes into utf-8, error: " + err);
}
}
function fromStringArrayToUint8(strings) {
return strings.map(function (s) {
return encodeString(s);
});
}
var backend_util = {
__proto__: null,
slice_util: slice_util,
segment_util: segment_util,
fromUint8ToStringArray: fromUint8ToStringArray,
fromStringArrayToUint8: fromStringArrayToUint8,
upcastType: upcastType,
axesAreInnerMostDims: axesAreInnerMostDims,
combineLocations: combineLocations,
computeOutAndReduceShapes: computeOutAndReduceShapes,
expandShapeToKeepDim: expandShapeToKeepDim,
assertAxesAreInnerMostDims: assertAxesAreInnerMostDims,
getAxesPermutation: getAxesPermutation,
getUndoAxesPermutation: getUndoAxesPermutation,
getInnerMostAxes: getInnerMostAxes,
getBroadcastDims: getBroadcastDims,
getReductionAxes: getReductionAxes,
assertAndGetBroadcastShape: assertAndGetBroadcastShape,
assertParamsConsistent: assertParamsConsistent,
computeOutShape: computeOutShape$1,
computeDilation2DInfo: computeDilation2DInfo,
computePool2DInfo: computePool2DInfo,
computePool3DInfo: computePool3DInfo,
computeConv2DInfo: computeConv2DInfo,
computeConv3DInfo: computeConv3DInfo,
computeDefaultPad: computeDefaultPad,
tupleValuesAreOne: tupleValuesAreOne,
eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne,
convertConv2DDataFormat: convertConv2DDataFormat,
getFusedDyActivation: getFusedDyActivation,
getFusedBiasGradient: getFusedBiasGradient,
applyActivation: applyActivation,
shouldFuse: shouldFuse,
PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD,
computeOptimalWindowSize: computeOptimalWindowSize,
getImageCenter: getImageCenter,
getReshaped: getReshaped,
getPermuted: getPermuted,
getReshapedPermuted: getReshapedPermuted,
getSliceBeginCoords: getSliceBeginCoords,
getSliceSize: getSliceSize,
prepareAndValidate: prepareAndValidate,
validateUpdateShape: validateUpdateShape,
validateInput: validateInput,
calculateShapes: calculateShapes,
SELU_SCALEALPHA: SELU_SCALEALPHA,
SELU_SCALE: SELU_SCALE,
ERF_P: ERF_P,
ERF_A1: ERF_A1,
ERF_A2: ERF_A2,
ERF_A3: ERF_A3,
ERF_A4: ERF_A4,
ERF_A5: ERF_A5,
warn: warn,
log: log$9,
mergeRealAndImagArrays: mergeRealAndImagArrays,
splitRealAndImagArrays: splitRealAndImagArrays,
complexWithEvenIndex: complexWithEvenIndex,
complexWithOddIndex: complexWithOddIndex,
getComplexWithIndex: getComplexWithIndex,
assignToTypedArray: assignToTypedArray,
exponents: exponents,
exponent: exponent,
decodeEinsumEquation: decodeEinsumEquation,
getEinsumPermutation: getEinsumPermutation,
checkEinsumDimSizes: checkEinsumDimSizes,
getEinsumComputePath: getEinsumComputePath,
isIdentityPermutation: isIdentityPermutation,
prepareSplitSize: prepareSplitSize
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var kernel_impls = {
__proto__: null,
nonMaxSuppressionV3Impl: nonMaxSuppressionV3Impl,
nonMaxSuppressionV4Impl: nonMaxSuppressionV4Impl,
nonMaxSuppressionV5Impl: nonMaxSuppressionV5Impl,
whereImpl: whereImpl
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var absGradConfig = {
kernelName: Abs,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return mul(dy, step(cast(_x, 'float32'), -1));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
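// Gradient of Acos: d/dx acos(x) = -1 / sqrt(1 - x^2), applied to the upstream
// gradient dy via the chain rule.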
var acosGradConfig = {
kernelName: Acos,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
var a = square(cast(_x, 'float32'));
var b = sqrt$3(sub(scalar(1), a));
return neg(div(dy, b));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var acoshGradConfig = {
kernelName: Acosh,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
var a = sqrt$3(sub(square(cast(_x, 'float32')), 1));
return div(dy, a);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
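// Gradient of Add: because Add broadcasts its inputs, each gradient below sums
// dy over the broadcast axes (getReductionAxes) and reshapes the result back to
// the corresponding input shape.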
var addGradConfig = {
kernelName: Add,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
var derA = function derA() {
var res = dy;
var reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(res, a.shape);
};
var derB = function derB() {
var res = dy;
var reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(res, b.shape);
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var addNGradConfig = {
kernelName: AddN,
saveAllInputs: true,
gradFunc: function gradFunc(dy, saved) {
var ders = {};
saved.forEach(function (_, i) {
ders[i] = function () {
return dy.clone();
};
});
return ders;
}
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var argMaxGradConfig = {
kernelName: ArgMax,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return zerosLike(_x);
}
};
}
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var argMinGradConfig = {
kernelName: ArgMin,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return zerosLike(_x);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var asinGradConfig = {
kernelName: Asin,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return div(dy, sqrt$3(sub(scalar(1), square(cast(_x, 'float32')))));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var asinhGradConfig = {
kernelName: Asinh,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
var a = sqrt$3(add$1(scalar(1), square(cast(_x, 'float32'))));
return div(dy, a);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
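// Gradient of Atan2: d/da atan2(a, b) = b / (a^2 + b^2) and
// d/db atan2(a, b) = -a / (a^2 + b^2), with the usual reduction over broadcast
// axes so each result matches its input's shape.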
var atan2GradConfig = {
kernelName: Atan2,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
var derA = function derA() {
var d = add$1(square(a), square(b));
var res = mul(dy, div(b, d));
var reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(res, a.shape);
};
var derB = function derB() {
var d = add$1(square(a), square(b));
var res = neg(mul(dy, div(a, d)));
var reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(res, b.shape);
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var atanGradConfig = {
kernelName: Atan,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return div(dy, add$1(square(cast(_x, 'float32')), 1));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var atanhGradConfig = {
kernelName: Atanh,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return div(dy, sub(scalar(1), square(cast(_x, 'float32'))));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the backprop of a 3d avg pool.
*
 * @param dy The dy error, of rank 5 or rank 4 of shape
 * [batchSize, depth, height, width, channels]. If rank 4, a batch of 1 is
 * assumed.
 * @param input The original input image, of rank 5 or rank 4 of shape
 * [batchSize, depth, height, width, channels]. If rank 4, a batch of 1 is
 * assumed.
 * @param filterSize The filter size:
 * `[filterDepth, filterHeight, filterWidth]`. If
 * `filterSize` is a single number,
 * then `filterDepth == filterHeight == filterWidth`.
 * @param strides The strides of the pooling:
 * `[strideDepth, strideHeight, strideWidth]`. If `strides` is a single
 * number, then `strideDepth == strideHeight == strideWidth`.
* @param pad A string from: 'same', 'valid'. The type of padding algorithm
* used in the forward prop of the op.
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*/
function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) {
var $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad');
var $input = convertToTensor(input, 'input', 'avgPool3dGrad');
var dy5D = $dy;
var input5D = $input;
var reshapedTo5D = false;
if ($input.rank === 4) {
reshapedTo5D = true;
dy5D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
input5D = reshape($input, [1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]]);
}
assert(dy5D.rank === 5, function () {
return "Error in avgPool3dGrad: dy must be rank 5 but got rank " + (dy5D.rank + ".");
});
assert(input5D.rank === 5, function () {
return "Error in avgPool3dGrad: input must be rank 5 but got rank " + (input5D.rank + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in avgPool3dGrad: pad must be an integer when " + ("using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
dy: dy5D,
input: input5D
};
var attrs = {
filterSize: filterSize,
strides: strides,
pad: pad,
dimRoundingMode: dimRoundingMode
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(AvgPool3DGrad, inputs, attrs);
if (reshapedTo5D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
return res;
}
var avgPool3dGrad = op({
avgPool3dGrad_: avgPool3dGrad_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var avgPool3DGradConfig = {
kernelName: AvgPool3D,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
var _x = saved[0];
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
return {
x: function x() {
return avgPool3dGrad(dy, _x, filterSize, strides, pad, dimRoundingMode);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Computes the backprop of a 2D avg pool.
*
* @param dy The dy error, of rank 4 or rank 3 of shape
* [batchSize, height, width, channels]. If rank 3, batch of 1 is
* assumed.
* @param input The input image, of rank 4 or rank 3 of shape
* [batchSize, height, width, channels]. If rank 3, batch of 1 is
* assumed.
* @param filterSize The filter size: `[filterHeight, filterWidth]`. If
* `filterSize` is a single number, then `filterHeight == filterWidth`.
* @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
* `strides` is a single number, then `strideHeight == strideWidth`.
* @param pad The type of padding algorithm used in the forward prop of the op.
* 'same', 'valid', for more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
*/
function avgPoolGrad_(dy, input, filterSize, strides, pad) {
var $dy = convertToTensor(dy, 'dy', 'avgPoolGrad');
var $input = convertToTensor(input, 'input', 'avgPoolGrad');
assert($input.rank === $dy.rank, function () {
return "Rank of input (" + $input.rank + ") does not match rank of dy (" + $dy.rank + ")";
});
var input4D = $input;
var dy4D = $dy;
var reshapedTo4D = false;
if ($input.rank === 3) {
reshapedTo4D = true;
input4D = reshape($input, [1, $input.shape[0], $input.shape[1], $input.shape[2]]);
dy4D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2]]);
}
assert(dy4D.rank === 4, function () {
return "Error in avgPoolGrad: dy must be rank 4 but got rank " + (dy4D.rank + ".");
});
assert(input4D.rank === 4, function () {
return "Error in avgPoolGrad: input must be rank 4 but got rank " + (input4D.rank + ".");
});
var inputs = {
dy: dy4D,
input: input4D
};
var attrs = {
filterSize: filterSize,
strides: strides,
pad: pad
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(AvgPoolGrad, inputs, attrs);
if (reshapedTo4D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
var avgPoolGrad = op({
avgPoolGrad_: avgPoolGrad_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var avgPoolGradConfig = {
kernelName: AvgPool,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
var _x = saved[0];
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad;
return {
x: function x() {
return avgPoolGrad(dy, _x, filterSize, strides, pad);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
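// Gradient of BatchMatMul: for the default (no transpose) case these are the
// standard matrix-calculus results dA = dY * B^T and dB = A^T * dY; the other
// branches apply the same identities with the transpose flags folded in.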
var batchMatMulGradConfig = {
kernelName: BatchMatMul,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved, attrs) {
var a = saved[0],
b = saved[1];
var transposeA = attrs.transposeA,
transposeB = attrs.transposeB;
if (!transposeA && !transposeB) {
return {
a: function a() {
return matMul(dy, b, false, true);
},
b: function b() {
return matMul(a, dy, true, false);
}
};
} else if (!transposeA && transposeB) {
return {
a: function a() {
return matMul(dy, b, false, false);
},
b: function b() {
return matMul(dy, a, true, false);
}
};
} else if (transposeA && !transposeB) {
return {
a: function a() {
return matMul(b, dy, false, true);
},
b: function b() {
return matMul(a, dy, false, false);
}
};
} else {
return {
a: function a() {
return matMul(b, dy, true, true);
},
b: function b() {
return matMul(dy, a, true, true);
}
};
}
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var batchToSpaceNDGradConfig = {
kernelName: BatchToSpaceND,
gradFunc: function gradFunc(dy, saved, attrs) {
var blockShape = attrs.blockShape,
crops = attrs.crops;
return {
x: function x() {
return spaceToBatchND(dy, blockShape, crops);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var broadcastToGradConfig = {
kernelName: BroadcastTo,
gradFunc: function gradFunc(dy, saved, attrs) {
var broadCastToAttrs = attrs;
var inputShape = broadCastToAttrs.inputShape;
var outputShape = broadCastToAttrs.shape;
var reps = Array.from(outputShape);
for (var i = inputShape.length - 1; i >= 0; i--) {
if (inputShape[i] === outputShape[i]) {
reps[i] = 1;
} else if (inputShape[i] !== 1) {
throw new Error("broadcastTo(): [" + inputShape + "] cannot be broadcast to [" + outputShape + "].");
}
}
var axes = [];
for (var _i = 0; _i < reps.length; _i++) {
if (reps[_i] > 1) {
axes.push(_i);
}
}
return {
x: function x() {
return sum$1(dy, axes, true
/* keepDims */
);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var castGradConfig = {
kernelName: Cast,
gradFunc: function gradFunc(dy) {
return {
x: function x() {
return dy.clone();
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ceilGradConfig = {
kernelName: Ceil,
gradFunc: function gradFunc(dy) {
// TODO(manrajgrover): Return null for gradients when backprop supports it.
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var clipByValueGradConfig = {
kernelName: ClipByValue,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
var _x = saved[0];
var clipValueMin = attrs.clipValueMin,
clipValueMax = attrs.clipValueMax;
return {
x: function x() {
return where(logicalAnd(greaterEqual(_x, clipValueMin), lessEqual(_x, clipValueMax)), dy, zerosLike(dy));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var complexAbsGradConfig = {
kernelName: ComplexAbs,
inputsToSave: ['x'],
gradFunc: absGradConfig.gradFunc
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var concatGradConfig = {
kernelName: Concat,
saveAllInputs: true,
gradFunc: function gradFunc(dy, saved, attrs) {
var shapes = saved.map(function (t) {
return t.shape;
});
var axis = attrs.axis;
var $axis = parseAxisParam(axis, saved[0].shape)[0];
var sizeSplits = shapes.map(function (s) {
return s[$axis];
});
var derTensors = split$1(dy, sizeSplits, $axis);
return derTensors.map(function (t) {
return function () {
return t;
};
});
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var conv2DGradConfig = {
kernelName: Conv2D,
inputsToSave: ['x', 'filter'],
gradFunc: function gradFunc(dy, saved, attrs) {
var x4D = saved[0],
$filter = saved[1];
var dilations = attrs.dilations,
strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat;
assert(tupleValuesAreOne(dilations), function () {
return 'Error in gradient of conv2D: dilation rates greater than 1 ' + ("are not yet supported in gradients. Got dilations '" + dilations + "'");
});
return {
x: function x() {
return conv2DBackpropInput(x4D.shape, dy, $filter, strides, pad, dataFormat);
},
filter: function filter() {
return conv2DBackpropFilter(x4D, dy, $filter.shape, strides, pad, dataFormat);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var conv2DBackpropInputGradConfig = {
kernelName: Conv2DBackpropInput,
inputsToSave: ['dy', 'filter'],
gradFunc: function gradFunc(ddx, saved, attrs) {
var dy = saved[0],
_filter = saved[1];
var strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dimRoundingMode = attrs.dimRoundingMode;
return {
dy: function dy() {
return conv2d(ddx, _filter, strides, pad, dataFormat, 1
/* dilations */
, dimRoundingMode);
},
filter: function filter() {
return conv2DBackpropFilter(ddx, dy, _filter.shape, strides, pad, dataFormat, dimRoundingMode);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the derivative of the filter of a 3D convolution.
*
* @param x The input tensor, of rank 5 or rank 4 of shape
* [batch, depth, height, width, inChannels]. If rank 4, batch of 1 is
* assumed.
* @param dy The dy image, of rank 5 or rank 4, of shape
* [batch, depth, height, width, outDepth]. If rank 4, batch of 1 is
* assumed.
* @param filterShape The shape of the filter, length 5,
* [filterDepth, filterHeight, filterWidth, inDepth, outDepth].
* @param strides The strides of the convolution: [strideDepth, strideHeight,
* strideWidth].
* @param pad A string from: 'same', 'valid'. The type of padding algorithm
* used in the forward prop of the op.
*/
function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) {
var x5D = x;
if (x.rank === 4) {
x5D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
}
var dy5D = dy;
if (dy5D.rank === 4) {
dy5D = reshape(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
}
assert(x5D.rank === 5, function () {
return "Error in conv3dDerFilter: input must be rank 5, but got shape " + (x5D.shape + ".");
});
assert(dy5D.rank === 5, function () {
return "Error in conv3dDerFilter: dy must be rank 5, but got shape " + (dy5D.shape + ".");
});
assert(filterShape.length === 5, function () {
return "Error in conv3dDerFilter: filterShape must be length 5, but got " + (filterShape + ".");
});
assert(x5D.shape[4] === filterShape[3], function () {
return "Error in conv3dDerFilter: depth of input " + x5D.shape[4] + ") must " + ("match input depth in filter (" + filterShape[3] + ".");
});
assert(dy5D.shape[4] === filterShape[4], function () {
return "Error in conv3dDerFilter: depth of dy (" + dy5D.shape[4] + ") must " + ("match output depth for filter (" + filterShape[4] + ").");
});
var inputs = {
x: x5D,
dy: dy5D
};
var attrs = {
strides: strides,
pad: pad,
filterShape: filterShape
}; // tslint:disable-next-line: no-unnecessary-type-assertion
return ENGINE.runKernel(Conv3DBackpropFilterV2, inputs, attrs);
}
var conv3DBackpropFilter = op({
conv3DBackpropFilter_: conv3DBackpropFilter_
});
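// Illustrative sketch only (names and shapes below are assumptions chosen for
// demonstration, not executed here): given a forward 3D convolution of an
// input `x` with shape [batch, depth, height, width, inChannels] whose output
// gradient `dy` has shape [batch, outDepth, outHeight, outWidth, outChannels],
// the filter gradient could be obtained with
//   var dFilter = conv3DBackpropFilter(
//       x, dy, [fd, fh, fw, inChannels, outChannels],
//       [1, 1, 1] /* strides */, 'same' /* pad */);
// where `fd`, `fh`, `fw` are the filter's spatial dimensions; the result has
// the same shape as the filter.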
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var conv3DGradConfig = {
kernelName: Conv3D,
inputsToSave: ['x', 'filter'],
gradFunc: function gradFunc(dy, saved, attrs) {
var dilations = attrs.dilations,
strides = attrs.strides,
pad = attrs.pad;
assert(tupleValuesAreOne(dilations), function () {
return 'Error in gradient of conv3D: dilation rates greater than 1 are ' + ("not yet supported in gradients. Got dilations '" + dilations + "'");
});
var x5D = saved[0],
$filter = saved[1];
return {
x: function x() {
return conv3DBackpropInput(x5D.shape, dy, $filter, strides, pad);
},
filter: function filter() {
return conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
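// Gradient of Cos: d/dx cos(x) = -sin(x), so dx = dy * -sin(x), with the saved
// input cast to float32 first.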
var cosGradConfig = {
kernelName: Cos,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return mul(neg(sin(cast(_x, 'float32'))), dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
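// Gradient of Cosh: d/dx cosh(x) = sinh(x), so dx = dy * sinh(x).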
var coshGradConfig = {
kernelName: Cosh,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return mul(sinh(cast(_x, 'float32')), dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
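// Gradient of Cumsum: the gradient of a cumulative sum is the cumulative sum
// of `dy` along the same axis taken in the opposite direction, which is why
// the `reverse` flag is flipped below.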
var cumsumGradConfig = {
kernelName: Cumsum,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
var _x = saved[0];
var axis = attrs.axis,
exclusive = attrs.exclusive,
reverse = attrs.reverse;
return {
x: function x() {
var permutation = getAxesPermutation([axis], _x.rank);
var out = cumsum(dy, axis, exclusive, !reverse);
if (permutation != null) {
out = transpose(out, permutation);
}
return out;
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var depthwiseConv2dNativeGradConfig = {
kernelName: DepthwiseConv2dNative,
inputsToSave: ['x', 'filter'],
gradFunc: function gradFunc(dy, saved, attrs) {
var dilations = attrs.dilations,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
var $dilations = dilations == null ? [1, 1] : dilations;
assert(tupleValuesAreOne($dilations), function () {
return 'Error in gradient of depthwiseConv2dNative: dilation rates ' + "greater than 1 are not yet supported. Got dilations " + ("'" + $dilations + "'");
});
var _x = saved[0],
_filter = saved[1];
assert(_x.rank === 4, function () {
return "Error in gradient of depthwiseConv2dNative: input must be " + ("rank 4, but got rank " + _x.rank + ".");
});
assert(_filter.rank === 4, function () {
return "Error in gradient of depthwiseConv2dNative: filter must be " + ("rank 4, but got rank " + _filter.rank + ".");
});
assert(_x.shape[3] === _filter.shape[2], function () {
return "Error in gradient of depthwiseConv2d: number of input " + ("channels (" + _x.shape[3] + ") must match the inChannels dimension ") + ("in filter " + _filter.shape[2] + ".");
});
assert(eitherStridesOrDilationsAreOne(strides, $dilations), function () {
return 'Error in gradient of depthwiseConv2d: Either strides or ' + ("dilations must be 1. Got strides " + strides + " and dilations ") + ("'" + $dilations + "'.");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in depthwiseConv2d: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
return {
x: function x() {
return depthwiseConv2dNativeBackpropInput(_x.shape, dy, _filter, strides, pad, $dilations, dimRoundingMode);
},
filter: function filter() {
return depthwiseConv2dNativeBackpropFilter(_x, dy, _filter.shape, strides, pad, $dilations, dimRoundingMode);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var dilation2dGradConfig = {
kernelName: Dilation2D,
inputsToSave: ['x', 'filter'],
gradFunc: function gradFunc(dy, saved, attrs) {
var x = saved[0],
filter = saved[1];
var inputInputs = {
x: x,
filter: filter,
dy: dy
};
var filterInputs = {
x: x,
filter: filter,
dy: dy
};
return {
x: function x() {
return ENGINE.runKernel(Dilation2DBackpropInput, inputInputs, attrs);
},
filter: function filter() {
return ENGINE.runKernel(Dilation2DBackpropFilter, filterInputs, attrs);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var eluGradConfig = {
kernelName: Elu,
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved) {
var y = saved[0];
var inputs = {
dy: dy,
y: y
};
return {
x: function x() {
return ENGINE.runKernel(EluGrad, inputs);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
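// Gradient of Erf: d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2), so
// dx = dy * (2 / sqrt(pi)) * exp(-x^2).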
var erfGradConfig = {
kernelName: Erf,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var x = saved[0];
var a = mul(exp$3(neg(square(x))), 2 / Math.sqrt(Math.PI));
return {
x: function x() {
return mul(dy, a);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
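// Gradient of Exp: d/dx e^x = e^x. The forward output y = e^x is saved, so the
// gradient is simply dx = dy * y with no extra exponentiation.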
var expGradConfig = {
kernelName: Exp,
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved) {
var y = saved[0];
return {
x: function x() {
return mul(dy, y);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var expandDimsGradConfig = {
kernelName: ExpandDims,
inputsToSave: ['input'],
gradFunc: function gradFunc(dy, saved) {
var _input = saved[0];
return {
input: function input() {
return reshape(dy, _input.shape);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var expm1GradConfig = {
kernelName: Expm1,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return mul(dy, exp$3(_x));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var floorGradConfig = {
kernelName: Floor,
gradFunc: function gradFunc(dy) {
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
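// Gradient of FloorDiv: treated like true division, i.e. d/da (a / b) = 1 / b
// and d/db (a / b) = -a / b^2, with any broadcast axes summed out and the
// results reshaped back to the input shapes.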
var floorDivGradConfig = {
kernelName: FloorDiv,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
var derA = function derA() {
var res = div(dy, cast(b, 'float32'));
var reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
return reshape(sum$1(res, reduceAxes), a.shape);
}
return res;
};
var derB = function derB() {
var res = mul(dy, cast(a, 'float32'));
var reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = reshape(sum$1(res, reduceAxes), b.shape);
}
var tmp = square(b);
return neg(div(res, cast(tmp, 'float32')));
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
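// Gradient of FusedBatchNorm: for
//   y = scale * (x - mean) / sqrt(variance + eps) + offset
// the chain rule gives
//   dx        = dy * scale / sqrt(variance + eps)
//   dmean     = -dy * scale / sqrt(variance + eps)
//   dvariance = -0.5 * dy * scale * (x - mean) * (variance + eps)^(-3/2)
//   dscale    = dy * (x - mean) / sqrt(variance + eps)
//   doffset   = dy
// each summed over the broadcast (reduction) axes when mean/variance are
// rank 1.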
var fusedBatchNormGradConfig = {
kernelName: FusedBatchNorm,
inputsToSave: ['x', 'mean', 'variance', 'scale'],
gradFunc: function gradFunc(dy, saved, attrs) {
var varianceEpsilon = attrs.varianceEpsilon;
var x = saved[0],
mean = saved[1],
variance = saved[2],
scale = saved[3];
var scaleValue = scale == null ? scalar(1) : scale;
var reductionAxes = getReductionAxes(mean.shape, x.shape);
var tileShape = [];
if (mean.rank === 1) {
for (var i = 0; i < x.shape.length - 1; ++i) {
tileShape.push(x.shape[i]);
}
tileShape.push(1);
}
var xMinusMean = sub(x, mean);
var dyTimesScaleValue = mul(dy, scaleValue);
var oneOverSqrtVariance = rsqrt(add$1(variance, scalar(varianceEpsilon)));
var minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5));
var derX = function derX() {
if (mean.rank === 1) {
return reshape(mul(mul(dy, tile(reshape(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]), tileShape)), scaleValue), x.shape);
} else {
return reshape(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);
}
};
var derMean = function derMean() {
var meanDer = mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);
if (mean.rank === 1) {
meanDer = sum$1(meanDer, reductionAxes);
}
return reshape(meanDer, mean.shape);
};
var derVariance = function derVariance() {
var varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);
if (mean.rank === 1) {
varianceDer = sum$1(varianceDer, reductionAxes);
}
return reshape(varianceDer, mean.shape);
};
var derScale = function derScale() {
var xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);
var scaleDer = mul(dy, xMinusMean2TimesRsqrt);
if (mean.rank === 1) {
scaleDer = sum$1(scaleDer, reductionAxes);
}
return reshape(scaleDer, mean.shape);
};
var derOffset = function derOffset() {
var offsetDer = dy;
if (mean.rank === 1) {
offsetDer = sum$1(offsetDer, reductionAxes);
}
return reshape(offsetDer, mean.shape);
};
return {
x: derX,
mean: derMean,
variance: derVariance,
scale: derScale,
offset: derOffset
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
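// Gradient of GatherV2: `dy` is scattered back to the gathered positions. The
// gathered dimension of `dy` is moved to the front, gradients for equal
// indices are accumulated with unsortedSegmentSum, and the axes are then
// transposed back to their original order. The gradient entry for `indices`
// simply returns the saved indices tensor, since integer indices are not
// differentiable.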
var gatherGradConfig = {
kernelName: GatherV2,
inputsToSave: ['x', 'indices'],
gradFunc: function gradFunc(dy, saved, attrs) {
var x = saved[0],
_indices = saved[1];
var axis = attrs.axis;
var parsedAxis = parseAxisParam(axis, x.shape)[0];
var derX = function derX() {
var paramsShape = x.shape;
var indicesSize = _indices.size;
var outerShape = paramsShape.slice(0, parsedAxis);
var outerDims = outerShape.length;
var innerShape = paramsShape.slice(axis, paramsShape.length).slice(1);
var innerDims = innerShape.length;
var outerAxesIndices = arrayRange(0, outerDims);
var innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims);
var valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]);
var values = reshape(dy, valuesShape);
var reshapedIndices = reshape(_indices, [indicesSize]);
var transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
var valuesTranspose = transpose(values, transposeDims);
var paramsGrad = unsortedSegmentSum(valuesTranspose, reshapedIndices, x.shape[parsedAxis]);
var invertTransposeDims = getUndoAxesPermutation(transposeDims);
paramsGrad = transpose(paramsGrad, invertTransposeDims);
return paramsGrad;
};
return {
x: derX,
indices: function indices() {
return _indices;
}
};
}
};
function arrayRange(start, stop) {
var result = [];
for (var i = start; i < stop; ++i) {
result.push(i);
}
return result;
}
function arrayConcat(arrays) {
var result = [];
for (var i = 0; i < arrays.length; ++i) {
for (var j = 0; j < arrays[i].length; ++j) {
result.push(arrays[i][j]);
}
}
return result;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var greaterEqualGradConfig = {
kernelName: GreaterEqual,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var _a = saved[0],
_b = saved[1];
return {
a: function a() {
return zerosLike(_a);
},
b: function b() {
return zerosLike(_b);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var identityGradConfig = {
kernelName: Identity,
gradFunc: function gradFunc(dy) {
return {
x: function x() {
return cast(dy, 'float32');
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var isFiniteGradConfig = {
kernelName: IsFinite,
gradFunc: function gradFunc(dy) {
// TODO(nsthorat): Let gradients be null for cases where we want to stop
    // backpropagation.
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var isInfGradConfig = {
kernelName: IsInf,
gradFunc: function gradFunc(dy) {
// TODO(nsthorat): Let gradients be null for cases where we want to stop
    // backpropagation.
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var isNanGradConfig = {
kernelName: IsNan,
gradFunc: function gradFunc(dy) {
// TODO(nsthorat): Let gradients be null for cases where we want to stop
    // backpropagation.
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var leakyReluGradConfig = {
kernelName: LeakyRelu,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
var x = saved[0];
var alpha = attrs.alpha;
    // Returns `gradients * (features > 0) + alpha * gradients * (features <= 0)`.
    var mask = greater(x, 0);
return {
x: function x() {
return where(mask, dy, mul(dy, alpha));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var log1pGradConfig = {
kernelName: Log1p,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return div(dy, add$1(_x, 1));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var logGradConfig = {
kernelName: Log,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return div(dy, cast(_x, 'float32'));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
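// Gradient of LogSoftmax: with y = logSoftmax(x) saved, softmax(x) = exp(y)
// and dx = dy - exp(y) * sum(dy, axis, keepDims=true).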
var logSoftmaxGradConfig = {
kernelName: LogSoftmax,
inputsToSave: [],
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved, attrs) {
var value = saved[0];
var axis = attrs.axis;
return {
logits: function logits() {
var keepDims = true;
var softmax = exp$3(value);
return sub(dy, mul(sum$1(dy, axis, keepDims), softmax));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function localResponseNormalizationBackprop_(x, y, dy, depthRadius, bias, alpha, beta) {
if (depthRadius === void 0) {
depthRadius = 5;
}
if (bias === void 0) {
bias = 1;
}
if (alpha === void 0) {
alpha = 1;
}
if (beta === void 0) {
beta = 0.5;
}
var inputs = {
x: x,
y: y,
dy: dy
};
var attrs = {
depthRadius: depthRadius,
bias: bias,
alpha: alpha,
beta: beta
};
return ENGINE.runKernel(LRNGrad, inputs, attrs);
}
var localResponseNormalizationBackprop = op({
localResponseNormalizationBackprop_: localResponseNormalizationBackprop_
});
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var lrnGradConfig = {
kernelName: LRN,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved, attrs) {
var _x = saved[0],
y = saved[1];
var depthRadius = attrs.depthRadius,
bias = attrs.bias,
alpha = attrs.alpha,
beta = attrs.beta;
return {
x: function x() {
return localResponseNormalizationBackprop(_x, y, dy, depthRadius, bias, alpha, beta);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Gradient helper function for the min and max operations.
*/
function gradForMinAndMax(dy, y, xOrig, origAxes) {
if (y.rank < xOrig.rank) {
y = reshape(y, expandShapeToKeepDim(y.shape, origAxes));
}
if (dy.rank < xOrig.rank) {
dy = reshape(dy, expandShapeToKeepDim(dy.shape, origAxes));
}
return {
x: function x() {
var dx = mul(dy, cast(equal(xOrig, y), dy.dtype));
return dx;
}
};
}
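// In other words, dx = dy * 1{x == y}: the gradient flows only to the
// positions that attained the reduced minimum/maximum (ties all receive the
// full gradient), after y and dy are reshaped to keep the reduced dimensions.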
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var maxGradConfig = {
kernelName: Max,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved, attrs) {
var maxAttrs = attrs;
var reductionIndices = maxAttrs.reductionIndices;
var x = saved[0];
var y = saved[1];
var origAxes = parseAxisParam(reductionIndices, x.shape);
var maxGrad = gradForMinAndMax(dy, y, x, origAxes);
return {
x: function x() {
return maxGrad['x']();
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
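// Gradient of Maximum: the gradient is routed to whichever operand won the
// elementwise comparison: dy flows to `a` where a >= b and to `b` where a < b,
// so each position contributes to exactly one input.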
var maximumGradConfig = {
kernelName: Maximum,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var derA = function derA() {
return mul(dy, cast(greaterEqual(a, b), 'float32'));
};
var derB = function derB() {
return mul(dy, cast(less(a, b), 'float32'));
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the backprop of a 3d max pool.
*
 * @param dy The dy error, of rank 5 or rank 4 of shape
 *     [batchSize, depth, height, width, channels]. If rank 4, batch of 1 is
 *     assumed.
* @param input The original input image, of rank 5 or rank 4 of shape
* [batchSize, depth, height, width, channels].
* @param output The original output image, of rank 5 of shape
* [batchSize, outDepth, outHeight, outWidth, channels].
 * @param filterSize The filter size:
 *     `[filterDepth, filterHeight, filterWidth]`. If `filterSize` is a single
 *     number, then `filterDepth == filterHeight == filterWidth`.
 * @param strides The strides of the pooling:
 *     `[strideDepth, strideHeight, strideWidth]`. If `strides` is a single
 *     number, then `strideDepth == strideHeight == strideWidth`.
* @param pad A string from: 'same', 'valid'. The type of padding algorithm
* used in the forward prop of the op.
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*/
function maxPool3dGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
var $dy = convertToTensor(dy, 'dy', 'maxPool3dGrad');
var $input = convertToTensor(input, 'input', 'maxPool3dGrad');
var $output = convertToTensor(output, 'output', 'maxPool3dGrad');
var dy5D = $dy;
var input5D = $input;
var output5D = $output;
var reshapedTo5D = false;
if ($input.rank === 4) {
reshapedTo5D = true;
dy5D = reshape($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
input5D = reshape($input, [1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]]);
output5D = reshape($output, [1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3]]);
}
assert(dy5D.rank === 5, function () {
return "Error in maxPool3dGrad: dy must be rank 5 but got rank " + (dy5D.rank + ".");
});
assert(input5D.rank === 5, function () {
return "Error in maxPool3dGrad: input must be rank 5 but got rank " + (input5D.rank + ".");
});
assert(output5D.rank === 5, function () {
return "Error in maxPool3dGrad: output must be rank 5 but got rank " + (output5D.rank + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in maxPool3dGrad: pad must be an integer when " + ("using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
dy: dy5D,
input: input5D,
output: output5D
};
var attrs = {
filterSize: filterSize,
strides: strides,
pad: pad,
dimRoundingMode: dimRoundingMode
}; // tslint:disable-next-line: no-unnecessary-type-assertion
var res = ENGINE.runKernel(MaxPool3DGrad, inputs, attrs);
if (reshapedTo5D) {
return reshape(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
return res;
}
var maxPool3dGrad = op({
maxPool3dGrad_: maxPool3dGrad_
});
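// Illustrative sketch only (names and shapes are assumptions for
// demonstration, not executed here): given the saved forward input `x` of
// shape [batch, depth, height, width, channels], the forward output `y` and an
// upstream gradient `dy` shaped like `y`, the input gradient could be computed
// as
//   var dx = maxPool3dGrad(dy, x, y, [2, 2, 2] /* filterSize */,
//                          [2, 2, 2] /* strides */, 'same');
// returning a tensor with the same shape as `x`.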
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var maxPool3DGradConfig = {
kernelName: MaxPool3D,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved, attrs) {
var _x = saved[0],
y = saved[1];
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
return {
x: function x() {
return maxPool3dGrad(dy, _x, y, filterSize, strides, pad, dimRoundingMode);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Computes the backprop of a 2D max pool.
*
* @param dy The dy error, of rank 4 or rank 3 of shape
* [batchSize, height, width, channels]. If rank 3, batch of 1 is
* assumed.
* @param input The original input image, of rank 4, of shape
* [batchSize, height, width, channels].
* @param output The original output image, of rank 4, of shape
* [batchSize, outHeight, outWidth, channels].
* @param filterSize The filter size: `[filterHeight, filterWidth]`. If
* `filterSize` is a single number, then `filterHeight == filterWidth`.
* @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If
* `strides` is a single number, then `strideHeight == strideWidth`.
* @param pad The type of padding algorithm used in the forward prop of the op.
* 'same', 'valid', for more info, see this guide:
* [https://www.tensorflow.org/api_docs/python/tf/nn/convolution](
* https://www.tensorflow.org/api_docs/python/tf/nn/convolution)
* @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is
* provided, it will default to truncate.
*/
function maxPoolGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
var $dy = convertToTensor(dy, 'dy', 'maxPoolGrad');
var $input = convertToTensor(input, 'input', 'maxPoolGrad');
var $output = convertToTensor(output, 'output', 'maxPoolGrad');
assert($input.rank === $dy.rank, function () {
return "Rank of input (" + $input.rank + ") does not match rank of dy " + ("(" + $dy.rank + ")");
});
assert($dy.rank === 4, function () {
return "Error in maxPoolGrad: dy must be rank 4 but got rank " + ($dy.rank + ".");
});
assert($input.rank === 4, function () {
return "Error in maxPoolGrad: input must be rank 4 but got rank " + ($input.rank + ".");
});
if (dimRoundingMode != null) {
assert(isInt(pad), function () {
return "Error in maxPoolGrad: pad must be an integer when using, " + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + ".");
});
}
var inputs = {
dy: $dy,
input: $input,
output: $output
};
var attrs = {
filterSize: filterSize,
strides: strides,
pad: pad,
dimRoundingMode: dimRoundingMode
}; // tslint:disable-next-line: no-unnecessary-type-assertion
return ENGINE.runKernel(MaxPoolGrad, inputs, attrs);
}
var maxPoolGrad = op({
maxPoolGrad_: maxPoolGrad_
});
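// Illustrative sketch only (names and shapes are assumptions for
// demonstration, not executed here): with a saved forward input `x` and output
// `y` of a 2x2 'same' max pool and an upstream gradient `dy` shaped like `y`,
//   var dx = maxPoolGrad(dy, x, y, [2, 2] /* filterSize */,
//                        [2, 2] /* strides */, 'same');
// routes each element of `dy` to the position in `x` that produced the
// corresponding maximum.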
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var maxPoolGradConfig = {
kernelName: MaxPool,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved, attrs) {
var _x = saved[0],
y = saved[1];
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad;
return {
x: function x() {
return maxPoolGrad(dy, _x, y, filterSize, strides, pad);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
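// Gradient of Mean: every element that participated in an average receives an
// equal share of the upstream gradient, so dx = broadcast(dy) / N, where N is
// the number of elements reduced over (`reduceSize` below).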
var meanGradConfig = {
kernelName: Mean,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
var x = saved[0];
var axis = attrs.axis;
var axes = parseAxisParam(axis, x.shape);
var shapes = computeOutAndReduceShapes(x.shape, axes);
var reduceShape = shapes[1];
var reduceSize = sizeFromShape(reduceShape);
var derX = function derX() {
var expandedDyShape = x.shape.slice();
axes.forEach(function (axis) {
expandedDyShape[axis] = 1;
});
var expandedDy = reshape(dy, expandedDyShape);
var res = div(mul(expandedDy, ones$1(x.shape, 'float32')), reduceSize);
return res;
};
return {
x: derX
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var minGradConfig = {
kernelName: Min,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved, attrs) {
var minAttrs = attrs;
var axis = minAttrs.axis;
var x = saved[0],
y = saved[1];
var origAxes = parseAxisParam(axis, x.shape);
var minGrad = gradForMinAndMax(dy, y, x, origAxes);
return {
x: function x() {
return minGrad['x']();
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var minimumGradConfig = {
kernelName: Minimum,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var derA = function derA() {
return mul(dy, cast(lessEqual(a, b), 'float32'));
};
var derB = function derB() {
return mul(dy, cast(greater(a, b), 'float32'));
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var mirrorPadGradConfig = {
kernelName: MirrorPad,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
// Pad introduces values around the original tensor, so the gradient
// slices the original shape out of the gradient.
var _x = saved[0];
var paddings = attrs.paddings;
var begin = paddings.map(function (p) {
return p[0];
});
return {
x: function x() {
return slice$2(dy, begin, _x.shape);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
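// Gradient of Mod: since a mod b = a - b * floor(a / b), d/da = 1 and
// d/db = -floor(a / b), with broadcast axes summed out and the results
// reshaped back to the input shapes.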
var modGradConfig = {
kernelName: Mod,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
var derA = function derA() {
var reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
return reshape(sum$1(dy, reduceAxes), a.shape);
}
return dy;
};
var derB = function derB() {
var res = mul(dy, neg(floor$a(div(a, b))));
var reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
return reshape(sum$1(res, reduceAxes), b.shape);
}
return res;
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
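// Gradient of Multiply: by the product rule, d/da (a * b) = b and
// d/db (a * b) = a; each gradient is summed over broadcast axes and reshaped
// back to the corresponding input shape.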
var multiplyGradConfig = {
kernelName: Multiply,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
var derA = function derA() {
var res = mul(dy, cast(b, 'float32'));
var reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
return reshape(sum$1(res, reduceAxes), a.shape);
}
return res;
};
var derB = function derB() {
var res = mul(dy, cast(a, 'float32'));
var reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
return reshape(sum$1(res, reduceAxes), b.shape);
}
return res;
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var negGradConfig = {
kernelName: Neg,
gradFunc: function gradFunc(dy) {
return {
x: function x() {
return neg(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var oneHotGradConfig = {
kernelName: OneHot,
inputsToSave: ['indices'],
gradFunc: function gradFunc(dy, saved) {
var _indices = saved[0];
return {
indices: function indices() {
return zeros(_indices.shape, 'float32');
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var onesLikeGradConfig = {
kernelName: OnesLike,
gradFunc: function gradFunc(dy) {
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var packGradConfig = {
kernelName: Pack,
saveAllInputs: true,
gradFunc: function gradFunc(dy, saved, attrs) {
var axis = attrs.axis;
var derTensors = unstack(dy, axis);
return derTensors.map(function (t) {
return function () {
return t;
};
});
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var padV2GradConfig = {
kernelName: PadV2,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
// Pad introduces values around the original tensor, so the gradient
// slices the original shape out of the gradient.
var _x = saved[0];
var paddings = attrs.paddings;
var begin = paddings.map(function (p) {
return p[0];
});
return {
x: function x() {
return slice$2(dy, begin, _x.shape);
}
};
}
};
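// Worked example (illustrative values, not executed): for a 1-D `x` of length
// 3 padded with paddings = [[2, 1]], the forward output has length 6, so the
// gradient is slice(dy, [2], [3]) -- the two leading and one trailing padded
// positions are simply dropped.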
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
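// Gradient of Pow: for y = a^b, d/da = b * a^(b-1) and d/db = y * ln(a); ln(a)
// is replaced by 0 where a <= 0 to avoid NaNs, and broadcast axes are summed
// out before reshaping back to the input shapes.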
var powGradConfig = {
kernelName: Pow,
inputsToSave: ['a', 'b'],
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1],
y = saved[2];
var base = a;
var exp = b;
var outShape = assertAndGetBroadcastShape(base.shape, exp.shape);
var derBase = function derBase() {
var expFloat = cast(exp, 'float32');
var res = mul(dy, mul(expFloat, pow$5(base, sub(expFloat, scalar(1)))));
var reduceAxes = getReductionAxes(base.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(res, base.shape);
};
var derExp = function derExp() {
var condition = greater(base, 0);
var logBase = where(condition, log$a(base), zerosLike(base));
var res = mul(dy, mul(y, logBase));
var reduceAxes = getReductionAxes(exp.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(res, exp.shape);
};
return {
a: derBase,
b: derExp
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
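// Prelu gradient: dL/dx = dy where x > 0 and dy * alpha elsewhere; dL/dalpha is
// dy * x where x <= 0 (zero otherwise), summed over the axes alpha was broadcast
// across and reshaped back to alpha's shape.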
var preluGradConfig = {
kernelName: Prelu,
inputsToSave: ['x', 'alpha'],
gradFunc: function gradFunc(dy, saved) {
var x = saved[0],
_alpha = saved[1];
var mask = greater(x, 0);
return {
x: function x() {
return where(mask, dy, mul(dy, _alpha));
},
alpha: function alpha() {
var res = where(mask, zerosLike(dy), mul(dy, x));
var reduceAxes = getReductionAxes(_alpha.shape, dy.shape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(res, _alpha.shape);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
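// RealDiv gradient: for y = a / b,
//   dL/da = dy / b
//   dL/db = -dy * a / b^2
// with each partial summed over broadcast axes and reshaped to its input's shape.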
var divGradConfig = {
kernelName: RealDiv,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
var derA = function derA() {
var res = div(dy, cast(b, 'float32'));
var reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
return reshape(sum$1(res, reduceAxes), a.shape);
}
return res;
};
var derB = function derB() {
var res = mul(dy, cast(a, 'float32'));
var reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = reshape(sum$1(res, reduceAxes), b.shape);
}
var tmp = square(b);
return neg(div(res, cast(tmp, 'float32')));
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
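// Reciprocal gradient: d/dx (1/x) = -1 / x^2, hence dx = dy / (-x^2) below.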
var reciprocalGradConfig = {
kernelName: Reciprocal,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return div(dy, neg(square(_x)));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
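// Relu6 gradient: the op clamps to [0, 6], so the incoming gradient is passed
// through only where 0 < x <= 6 (the mask mul(lessEqual(x, 6), step(x))) and is
// zero elsewhere.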
var relu6GradConfig = {
kernelName: Relu6,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var x = saved[0];
var mask = mul(lessEqual(x, 6), step(x));
return {
x: function x() {
return mul(dy, cast(mask, 'float32'));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var reluGradConfig = {
kernelName: Relu,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return mul(dy, cast(step(_x), 'float32'));
}
};
}
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var reshapeGradConfig = {
kernelName: Reshape,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return reshape(dy, _x.shape);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var resizeBilinearGradConfig = {
kernelName: ResizeBilinear,
inputsToSave: ['images'],
gradFunc: function gradFunc(dy, saved, attrs) {
var images = saved[0];
var inputs = {
dy: dy,
images: images
};
    var imagesDer = function imagesDer() {
      // tslint:disable-next-line: no-unnecessary-type-assertion
      return ENGINE.runKernel(ResizeBilinearGrad, inputs, attrs);
    };
return {
images: imagesDer
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var resizeNearestNeighborGradConfig = {
kernelName: ResizeNearestNeighbor,
inputsToSave: ['images'],
gradFunc: function gradFunc(dy, saved, attrs) {
var images = saved[0];
var inputs = {
dy: dy,
images: images
};
    var imagesDer = function imagesDer() {
      // tslint:disable-next-line: no-unnecessary-type-assertion
      return ENGINE.runKernel(ResizeNearestNeighborGrad, inputs, attrs);
    };
return {
images: imagesDer
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var reverseGradConfig = {
kernelName: Reverse,
gradFunc: function gradFunc(dy, saved, attrs) {
var dims = attrs.dims;
var axes = parseAxisParam(dims, dy.shape);
return {
x: function x() {
return reverse(dy, axes);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var roundGradConfig = {
kernelName: Round,
gradFunc: function gradFunc(dy) {
    // TODO(nsthorat): Let gradients be null for cases where we want to stop
    // backpropagation.
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
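// Rsqrt gradient: d/dx x^(-1/2) = -1 / (2 * x^(3/2)), which is what the negated
// division below computes.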
var rsqrtGradConfig = {
kernelName: Rsqrt,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return neg(div(dy, mul(pow$5(_x, 1.5), 2)));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var selectGradConfig = {
kernelName: Select,
inputsToSave: ['condition'],
gradFunc: function gradFunc(dy, saved) {
var _condition = saved[0];
return {
// TODO(julianoks): Return null for condition gradient
// when backprop supports it.
condition: function condition() {
return cast(zerosLike(_condition), 'float32');
},
t: function t() {
return mul(dy, cast(_condition, dy.dtype));
},
e: function e() {
return mul(dy, cast(logicalNot(_condition), dy.dtype));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
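// Selu gradient: for x > 0 the derivative is SELU_SCALE; for x <= 0 it is
// SELU_SCALEALPHA * exp(x), i.e. the derivative of scale * alpha * (exp(x) - 1).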
var seluGradConfig = {
kernelName: Selu,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
var mask = greater(_x, scalar(0));
var scaleAlpha = scalar(SELU_SCALEALPHA);
var scale = scalar(SELU_SCALE);
var greaterThanZeroDer = mul(dy, scale);
var lessEqualZeroDer = mul(mul(dy, scaleAlpha), exp$3(cast(_x, 'float32')));
return where(mask, greaterThanZeroDer, lessEqualZeroDer);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
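// Sigmoid gradient uses the saved output y = sigmoid(x): dy/dx = y * (1 - y).
// For example, at x = 0 the output is 0.5 and the derivative is 0.25.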
var sigmoidGradConfig = {
kernelName: Sigmoid,
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved) {
var y = saved[0];
return {
x: function x() {
return mul(dy, mul(y, sub(scalar(1), y)));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var signGradConfig = {
kernelName: Sign,
gradFunc: function gradFunc(dy) {
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var sinGradConfig = {
kernelName: Sin,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return mul(cos(cast(_x, 'float32')), dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var sinhGradConfig = {
kernelName: Sinh,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return mul(cosh(cast(_x, 'float32')), dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
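// Slice gradient: the upstream gradient only covers the sliced window, so it is
// zero-padded back to the full input shape (pad is the adjoint of slice).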
var sliceGradConfig = {
kernelName: Slice,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
var x = saved[0];
var begin = attrs.begin,
size = attrs.size;
var inputShape = x.shape;
    var _parseSliceParams = parseSliceParams(x, begin, size),
        begin_ = _parseSliceParams[0],
        size_ = _parseSliceParams[1];
    // Create an Nx2 padding where the first column represents how many
    // zeros are prepended (at start) for each dimension, and the second
    // column indicates how many zeros are appended (at end).
    // The number of zeros to append is the shape of the input
    // elementwise-subtracted by both the begin vector and sizes vector.
    var paddings = [];
for (var i = 0; i < dy.rank; i++) {
paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
}
return {
x: function x() {
return pad(dy, paddings);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
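// Softmax gradient in terms of the saved output y along `dim`:
//   dL/dlogits = y * (dy - sum(dy * y, dim, keepDims))
// i.e. the softmax Jacobian-vector product computed without materializing the
// full Jacobian.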
var softmaxGradConfig = {
kernelName: Softmax,
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved, attrs) {
var y = saved[0];
var dim = attrs.dim;
var keepDims = true;
var dyTimesY = mul(dy, y);
return {
logits: function logits() {
return sub(dyTimesY, mul(sum$1(dyTimesY, [dim], keepDims), y));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var softplusGradConfig = {
kernelName: Softplus,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return mul(dy, sigmoid(_x));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var spaceToBatchNDGradConfig = {
kernelName: SpaceToBatchND,
gradFunc: function gradFunc(dy, saved, attrs) {
var blockShape = attrs.blockShape,
paddings = attrs.paddings;
return {
x: function x() {
return batchToSpaceND(dy, blockShape, paddings);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var splitVGradConfig = {
kernelName: SplitV,
gradFunc: function gradFunc(dy, saved, attrs) {
var axis = attrs.axis;
return {
x: function x() {
return concat(dy, axis);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var sqrtGradConfig = {
kernelName: Sqrt,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return div(dy, mul(sqrt$3(cast(_x, 'float32')), 2));
}
};
}
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var squareGradConfig = {
kernelName: Square,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return mul(dy, mul(cast(_x, 'float32'), 2));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
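// SquaredDifference gradient: for y = (a - b)^2,
//   dL/da = dy * 2 * (a - b)
//   dL/db = dy * 2 * (b - a)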
var squaredDifferenceGradConfig = {
kernelName: SquaredDifference,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var two = scalar(2);
var derA = function derA() {
return mul(dy, mul(two, sub(a, b)));
};
var derB = function derB() {
return mul(dy, mul(two, sub(b, a)));
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var stepGradConfig = {
kernelName: Step,
gradFunc: function gradFunc(dy) {
// TODO(manrajgrover): Return null for gradients when backprop supports
// it.
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
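// Sub gradient: for y = a - b, dL/da = dy and dL/db = -dy, each summed over the
// axes that were broadcast and reshaped back to the corresponding input shape.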
var subGradConfig = {
kernelName: Sub,
inputsToSave: ['a', 'b'],
gradFunc: function gradFunc(dy, saved) {
var a = saved[0],
b = saved[1];
var outShape = assertAndGetBroadcastShape(a.shape, b.shape);
var derA = function derA() {
var res = dy;
var reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(res, a.shape);
};
var derB = function derB() {
var res = dy;
var reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$1(res, reduceAxes);
}
return reshape(neg(res), b.shape);
};
return {
a: derA,
b: derB
};
}
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
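// Sum gradient: the reduced axes are re-inserted as size-1 dimensions on dy, which
// is then broadcast against ones of x.shape so every element that contributed to a
// sum receives that sum's upstream gradient.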
var sumGradConfig = {
kernelName: Sum,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
var x = saved[0];
var expandedDyShape = x.shape.slice();
var axis = attrs.axis;
var axes = parseAxisParam(axis, x.shape);
axes.forEach(function (axis) {
expandedDyShape[axis] = 1;
});
var expandedDy = reshape(dy, expandedDyShape);
var derX = mul(expandedDy, ones$1(x.shape, 'float32'));
return {
x: function x() {
return derX;
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var tanGradConfig = {
kernelName: Tan,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved) {
var _x = saved[0];
return {
x: function x() {
return div(dy, square(cos(_x)));
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
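// Tanh gradient uses the saved output y = tanh(x): dy/dx = 1 - y^2.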
var tanhGradConfig = {
kernelName: Tanh,
outputsToSave: [true],
gradFunc: function gradFunc(dy, saved) {
var y = saved[0];
return {
x: function x() {
return mul(sub(scalar(1), square(y)), dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
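// Tile gradient: every repeated copy of x in the output contributes gradient back
// to x, so dy is cut into x-shaped slices (one per repetition along each tiled
// axis) and accumulated with add. Only ranks 1 through 4 are handled; higher ranks
// throw below.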
var tileGradConfig = {
kernelName: Tile,
inputsToSave: ['x'],
gradFunc: function gradFunc(dy, saved, attrs) {
var x = saved[0];
var reps = attrs.reps;
var derX = function derX() {
      var xGrad = zerosLike(x);
      // TODO(cais): Maybe reduce memory footprint by avoiding repeated
      // slicing.
if (x.rank === 1) {
for (var i = 0; i < reps[0]; ++i) {
xGrad = add$1(xGrad, slice$2(dy, [i * x.shape[0]], [x.shape[0]]));
}
} else if (x.rank === 2) {
for (var _i = 0; _i < reps[0]; ++_i) {
for (var j = 0; j < reps[1]; ++j) {
xGrad = add$1(xGrad, slice$2(dy, [_i * x.shape[0], j * x.shape[1]], [x.shape[0], x.shape[1]]));
}
}
} else if (x.rank === 3) {
for (var _i2 = 0; _i2 < reps[0]; ++_i2) {
for (var _j = 0; _j < reps[1]; ++_j) {
for (var k = 0; k < reps[2]; ++k) {
xGrad = add$1(xGrad, slice$2(dy, [_i2 * x.shape[0], _j * x.shape[1], k * x.shape[2]], [x.shape[0], x.shape[1], x.shape[2]]));
}
}
}
} else if (x.rank === 4) {
for (var _i3 = 0; _i3 < reps[0]; ++_i3) {
for (var _j2 = 0; _j2 < reps[1]; ++_j2) {
for (var _k = 0; _k < reps[2]; ++_k) {
for (var l = 0; l < reps[3]; ++l) {
xGrad = add$1(xGrad, slice$2(dy, [_i3 * x.shape[0], _j2 * x.shape[1], _k * x.shape[2], l * x.shape[3]], [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]));
}
}
}
}
} else {
throw new Error("Gradient for tile operation is not implemented for rank-" + (x.rank + " tensors yet."));
}
return xGrad;
};
return {
x: derX
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var transposeGradConfig = {
kernelName: Transpose,
gradFunc: function gradFunc(dy, saved, attrs) {
var transposeAttrs = attrs;
var perm = transposeAttrs.perm;
var undoPerm = getUndoAxesPermutation(perm);
return {
x: function x() {
return transpose(dy, undoPerm);
}
};
}
};
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var unpackGradConfig = {
kernelName: Unpack,
gradFunc: function gradFunc(dy, saved, attrs) {
var unpackAttrs = attrs;
var axis = unpackAttrs.axis;
return {
value: function value() {
return stack(dy, axis);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
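// UnsortedSegmentSum gradient: each input element receives the gradient of the
// segment it was summed into (a gather on segmentIds); elements with negative
// segment ids were dropped in the forward pass and therefore get zero gradient,
// handled by gatherDropNegatives below.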
var unsortedSegmentSumGradConfig = {
kernelName: UnsortedSegmentSum,
inputsToSave: ['segmentIds'],
gradFunc: function gradFunc(dy, saved) {
var segmentIds = saved[0];
var derX = function derX() {
return gatherDropNegatives(dy, segmentIds);
};
return {
x: derX
};
}
};
function gatherDropNegatives(x, indices) {
// Helper function for unsorted segment ops. Gathers params for
// positive segment ids and gathers 0 for inputs with negative segment id.
// Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py
var zeroClippedIndices = maximum(indices, zerosLike(indices));
var gathered = gather(x, zeroClippedIndices);
var isPositive = greaterEqual(indices, scalar(0, 'int32'));
var numIters = gathered.rank - isPositive.rank;
for (var i = 0; i < numIters; ++i) {
isPositive = expandDims(isPositive, i + 1);
}
isPositive = logicalAnd(isPositive, ones$1(gathered.shape, 'bool'));
var zeroSlice = zerosLike(gathered);
return where(isPositive, gathered, zeroSlice);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var zerosLikeGradConfig = {
kernelName: ZerosLike,
gradFunc: function gradFunc(dy) {
return {
x: function x() {
return zerosLike(dy);
}
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
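// All gradient configs defined above are registered with the engine below. A few
// configs appear twice in this list; re-registering a gradient for the same kernel
// name simply replaces the earlier entry, so the duplicates are harmless.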
var gradConfigs = [absGradConfig, acosGradConfig, acoshGradConfig, addGradConfig, addNGradConfig, argMaxGradConfig, argMinGradConfig, asinGradConfig, asinhGradConfig, atan2GradConfig, atanGradConfig, atanhGradConfig, avgPool3DGradConfig, avgPoolGradConfig, batchMatMulGradConfig, batchToSpaceNDGradConfig, broadcastToGradConfig, castGradConfig, ceilGradConfig, clipByValueGradConfig, complexAbsGradConfig, concatGradConfig, conv2DBackpropInputGradConfig, conv2DGradConfig, conv3DGradConfig, cosGradConfig, coshGradConfig, cumsumGradConfig, depthwiseConv2dNativeGradConfig, dilation2dGradConfig, divGradConfig, eluGradConfig, erfGradConfig, expGradConfig, expandDimsGradConfig, expm1GradConfig, floorDivGradConfig, floorGradConfig, fusedBatchNormGradConfig, gatherGradConfig, greaterEqualGradConfig, identityGradConfig, isFiniteGradConfig, isInfGradConfig, isNanGradConfig, leakyReluGradConfig, log1pGradConfig, logGradConfig, logSoftmaxGradConfig, lrnGradConfig, maxGradConfig, maxGradConfig, maximumGradConfig, maxPool3DGradConfig, maxPoolGradConfig, meanGradConfig, minGradConfig, minimumGradConfig, mirrorPadGradConfig, modGradConfig, multiplyGradConfig, negGradConfig, oneHotGradConfig, onesLikeGradConfig, packGradConfig, padV2GradConfig, padV2GradConfig, powGradConfig, preluGradConfig, reciprocalGradConfig, relu6GradConfig, reluGradConfig, reshapeGradConfig, resizeBilinearGradConfig, resizeNearestNeighborGradConfig, reverseGradConfig, roundGradConfig, rsqrtGradConfig, selectGradConfig, seluGradConfig, sigmoidGradConfig, signGradConfig, sinGradConfig, sinhGradConfig, sliceGradConfig, softmaxGradConfig, softplusGradConfig, spaceToBatchNDGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, splitVGradConfig, sqrtGradConfig, squaredDifferenceGradConfig, squareGradConfig, stepGradConfig, subGradConfig, sumGradConfig, tanGradConfig, tanhGradConfig, tileGradConfig, transposeGradConfig, unpackGradConfig, unsortedSegmentSumGradConfig, zerosLikeGradConfig];
for (var _i = 0, _gradConfigs = gradConfigs; _i < _gradConfigs.length; _i++) {
var gradientConfig = _gradConfigs[_i];
registerGradient(gradientConfig);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
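// The remainder of this section augments the global Tensor prototype so ops can be
// chained fluently; each method just guards against disposed tensors and delegates
// to the corresponding functional op. Illustrative usage through the public API
// (not executed here; assumes a backend is registered):
//   tf.tensor1d([-2, 3]).abs().add(tf.scalar(1)).print(); // [3, 4]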
getGlobalTensorClass().prototype.abs = function () {
this.throwIfDisposed();
return abs$8(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.acos = function () {
this.throwIfDisposed();
return acos(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.acosh = function () {
this.throwIfDisposed();
return acosh(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.add = function (b) {
this.throwIfDisposed();
return add$1(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.all = function (axis, keepDims) {
this.throwIfDisposed();
return all(this, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.any = function (axis, keepDims) {
this.throwIfDisposed();
return any(this, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.argMax = function (axis) {
this.throwIfDisposed();
return argMax(this, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.argMin = function (axis) {
this.throwIfDisposed();
return argMin(this, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Converts a size-1 `tf.Tensor` to a `tf.Scalar`.
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.asScalar = function () {
this.throwIfDisposed();
assert(this.size === 1, function () {
return 'The array must have only 1 element.';
});
return reshape(this, []);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Casts a `tf.Tensor` to a specified dtype.
*
* @param dtype Data-type to cast the tensor to.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.asType = function (dtype) {
this.throwIfDisposed();
return cast(this, dtype);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Converts a `tf.Tensor` to a `tf.Tensor1D`.
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.as1D = function () {
this.throwIfDisposed();
return reshape(this, [this.size]);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Converts a `tf.Tensor` to a `tf.Tensor2D`.
*
* @param rows Number of rows in `tf.Tensor2D`.
* @param columns Number of columns in `tf.Tensor2D`.
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.as2D = function (rows, columns) {
this.throwIfDisposed();
return reshape(this, [rows, columns]);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Converts a `tf.Tensor` to a `tf.Tensor3D`.
*
* @param rows Number of rows in `tf.Tensor3D`.
* @param columns Number of columns in `tf.Tensor3D`.
* @param depth Depth of `tf.Tensor3D`.
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.as3D = function (rows, columns, depth) {
this.throwIfDisposed();
return reshape(this, [rows, columns, depth]);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Converts a `tf.Tensor` to a `tf.Tensor4D`.
*
* @param rows Number of rows in `tf.Tensor4D`.
* @param columns Number of columns in `tf.Tensor4D`.
* @param depth Depth of `tf.Tensor4D`.
* @param depth2 4th dimension of `tf.Tensor4D`.
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.as4D = function (rows, columns, depth, depth2) {
this.throwIfDisposed();
return reshape(this, [rows, columns, depth, depth2]);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Converts a `tf.Tensor` to a `tf.Tensor5D`.
*
* @param rows Number of rows in `tf.Tensor5D`.
* @param columns Number of columns in `tf.Tensor5D`.
* @param depth Depth of `tf.Tensor5D`.
* @param depth2 4th dimension of `tf.Tensor5D`.
 * @param depth3 5th dimension of `tf.Tensor5D`.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.as5D = function (rows, columns, depth, depth2, depth3) {
this.throwIfDisposed();
return reshape(this, [rows, columns, depth, depth2, depth3]);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.asin = function () {
this.throwIfDisposed();
return asin(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.asinh = function () {
this.throwIfDisposed();
return asinh$1(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.atan = function () {
this.throwIfDisposed();
return atan(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.atan2 = function (b) {
this.throwIfDisposed();
return atan2(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.atanh = function () {
this.throwIfDisposed();
return atanh(this);
};
getGlobalTensorClass().prototype.avgPool = function (filterSize, strides, pad, dimRoundingMode) {
this.throwIfDisposed();
return avgPool(this, filterSize, strides, pad, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.batchToSpaceND = function (blockShape, crops) {
this.throwIfDisposed();
return batchToSpaceND(this, blockShape, crops);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.batchNorm = function (mean, variance, offset, scale, varianceEpsilon) {
this.throwIfDisposed();
return batchNorm(this, mean, variance, offset, scale, varianceEpsilon);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.broadcastTo = function (shape) {
this.throwIfDisposed();
return broadcastTo(this, shape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.cast = function (dtype) {
this.throwIfDisposed();
return cast(this, dtype);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.ceil = function () {
this.throwIfDisposed();
return ceil$3(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.clipByValue = function (min, max) {
this.throwIfDisposed();
return clipByValue(this, min, max);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.concat = function (x, axis) {
this.throwIfDisposed();
if (x instanceof Tensor) {
x = [x];
}
return concat([this].concat(x), axis);
};
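/**
 * Usage sketch (illustrative, not part of the original source): `concat`
 * accepts either a single tensor or an array of tensors, and the axis
 * defaults to 0. Assumes the global `tf` namespace.
 *
 * ```js
 * const a = tf.tensor1d([1, 2]);
 * const b = tf.tensor1d([3, 4]);
 * a.concat(b).print(); // [1, 2, 3, 4]
 * ```
 */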
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.conv1d = function (filter, stride, pad, dataFormat, dilation, dimRoundingMode) {
this.throwIfDisposed();
return conv1d(this, filter, stride, pad, dataFormat, dilation, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.conv2dTranspose = function (filter, outputShape, strides, pad, dimRoundingMode) {
this.throwIfDisposed();
return conv2dTranspose(this, filter, outputShape, strides, pad, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.conv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
this.throwIfDisposed();
return conv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.cos = function () {
this.throwIfDisposed();
return cos(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.cosh = function () {
this.throwIfDisposed();
return cosh(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.cumsum = function (axis, exclusive, reverse) {
this.throwIfDisposed();
return cumsum(this, axis, exclusive, reverse);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.depthToSpace = function (blockSize, dataFormat) {
this.throwIfDisposed();
return depthToSpace(this, blockSize, dataFormat);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.depthwiseConv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) {
this.throwIfDisposed();
return depthwiseConv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.dilation2d = function (filter, strides, pad, dilations, dataFormat) {
this.throwIfDisposed();
return dilation2d(this, filter, strides, pad, dilations, dataFormat);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.divNoNan = function (b) {
this.throwIfDisposed();
return divNoNan(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.div = function (b) {
this.throwIfDisposed();
return div(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.dot = function (b) {
this.throwIfDisposed();
return dot(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.elu = function () {
this.throwIfDisposed();
return elu(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.equal = function (b) {
this.throwIfDisposed();
return equal(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.erf = function () {
this.throwIfDisposed();
return erf(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.exp = function () {
this.throwIfDisposed();
return exp$3(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.expandDims = function (axis) {
this.throwIfDisposed();
return expandDims(this, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.expm1 = function () {
this.throwIfDisposed();
return expm1(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.fft = function () {
this.throwIfDisposed();
return fft(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Flattens the tensor into a 1D `tf.Tensor1D` containing all of its values.
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.flatten = function () {
this.throwIfDisposed();
return reshape(this, [this.size]);
};
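/**
 * Usage sketch (illustrative, not part of the original source): flattening a
 * 2x2 tensor yields a length-4 `tf.Tensor1D`. Assumes the global `tf`
 * namespace.
 *
 * ```js
 * const m = tf.tensor2d([[1, 2], [3, 4]]);
 * m.flatten().print(); // [1, 2, 3, 4]
 * ```
 */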
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.floor = function () {
this.throwIfDisposed();
return floor$a(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.floorDiv = function (b) {
this.throwIfDisposed();
return floorDiv(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.gather = function (indices, axis) {
this.throwIfDisposed();
return gather(this, indices, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.greaterEqual = function (b) {
this.throwIfDisposed();
return greaterEqual(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.greater = function (b) {
this.throwIfDisposed();
return greater(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.ifft = function () {
this.throwIfDisposed();
return ifft(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.irfft = function () {
this.throwIfDisposed();
return irfft(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.isFinite = function () {
this.throwIfDisposed();
return isFinite$1(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.isInf = function () {
this.throwIfDisposed();
return isInf(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.isNaN = function () {
this.throwIfDisposed();
return isNaN$1(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.leakyRelu = function (alpha) {
this.throwIfDisposed();
return leakyRelu(this, alpha);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.lessEqual = function (b) {
this.throwIfDisposed();
return lessEqual(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.less = function (b) {
this.throwIfDisposed();
return less(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.localResponseNormalization = function (depthRadius, bias, alpha, beta) {
this.throwIfDisposed();
return localResponseNormalization(this, depthRadius, bias, alpha, beta);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.logSigmoid = function () {
this.throwIfDisposed();
return logSigmoid(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.logSoftmax = function (axis) {
this.throwIfDisposed();
return logSoftmax(this, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.logSumExp = function (axis, keepDims) {
this.throwIfDisposed();
return logSumExp(this, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.log = function () {
this.throwIfDisposed();
return log$a(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.log1p = function () {
this.throwIfDisposed();
return log1p(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.logicalAnd = function (b) {
this.throwIfDisposed();
return logicalAnd(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.logicalNot = function () {
this.throwIfDisposed();
return logicalNot(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.logicalOr = function (b) {
this.throwIfDisposed();
return logicalOr(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.logicalXor = function (b) {
this.throwIfDisposed();
return logicalXor(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.matMul = function (b, transposeA, transposeB) {
this.throwIfDisposed();
return matMul(this, b, transposeA, transposeB);
};
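/**
 * Usage sketch (illustrative, not part of the original source): plain matrix
 * multiplication with `transposeA` and `transposeB` left at their default of
 * false. Assumes the global `tf` namespace.
 *
 * ```js
 * const a = tf.tensor2d([[1, 2], [3, 4]]);
 * const b = tf.tensor2d([[5, 6], [7, 8]]);
 * a.matMul(b).print(); // [[19, 22], [43, 50]]
 * ```
 */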
getGlobalTensorClass().prototype.maxPool = function (filterSize, strides, pad, dimRoundingMode) {
this.throwIfDisposed();
return maxPool(this, filterSize, strides, pad, dimRoundingMode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.max = function (axis, keepDims) {
this.throwIfDisposed();
return max$5(this, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.maximum = function (b) {
this.throwIfDisposed();
return maximum(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.mean = function (axis, keepDims) {
this.throwIfDisposed();
return mean(this, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.min = function (axis, keepDims) {
this.throwIfDisposed();
return min$9(this, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.minimum = function (b) {
this.throwIfDisposed();
return minimum(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.mirrorPad = function (paddings, mode) {
this.throwIfDisposed();
return mirrorPad(this, paddings, mode);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.mod = function (b) {
this.throwIfDisposed();
return mod(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.mul = function (b) {
this.throwIfDisposed();
return mul(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.neg = function () {
this.throwIfDisposed();
return neg(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.norm = function (ord, axis, keepDims) {
this.throwIfDisposed();
return norm(this, ord, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.notEqual = function (b) {
this.throwIfDisposed();
return notEqual(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.oneHot = function (depth, onValue, offValue) {
if (onValue === void 0) {
onValue = 1;
}
if (offValue === void 0) {
offValue = 0;
}
this.throwIfDisposed();
return oneHot(this, depth, onValue, offValue);
};
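/**
 * Usage sketch (illustrative, not part of the original source): with the
 * default `onValue` of 1 and `offValue` of 0, each int32 index becomes a
 * one-hot row of the given depth. Assumes the global `tf` namespace.
 *
 * ```js
 * const indices = tf.tensor1d([0, 2], 'int32');
 * indices.oneHot(3).print(); // [[1, 0, 0], [0, 0, 1]]
 * ```
 */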
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.onesLike = function () {
this.throwIfDisposed();
return onesLike(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.pad = function (paddings, constantValue) {
this.throwIfDisposed();
return pad(this, paddings, constantValue);
};
getGlobalTensorClass().prototype.pool = function (windowShape, poolingType, padding, dilationRate, strides) {
this.throwIfDisposed();
return pool(this, windowShape, poolingType, padding, dilationRate, strides);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.pow = function (exp) {
this.throwIfDisposed();
return pow$5(this, exp);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.prelu = function (alpha) {
this.throwIfDisposed();
return prelu(this, alpha);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.prod = function (axis, keepDims) {
this.throwIfDisposed();
return prod(this, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.reciprocal = function () {
this.throwIfDisposed();
return reciprocal(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.relu = function () {
this.throwIfDisposed();
return relu(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.relu6 = function () {
this.throwIfDisposed();
return relu6(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Reshapes the tensor into the shape of the provided tensor.
*
* @param x The tensor of required shape.
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.reshapeAs = function (x) {
this.throwIfDisposed();
return reshape(this, x.shape);
};
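// Illustrative usage of `reshapeAs` (a sketch; assumes the bundled `tf`
// namespace and a registered backend):
//   const t = tf.tensor1d([1, 2, 3, 4, 5, 6]);  // shape [6]
//   const u = tf.zeros([2, 3]);                 // shape [2, 3]
//   const r = t.reshapeAs(u);                   // same data, shape [2, 3]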
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.reshape = function (shape) {
this.throwIfDisposed();
return reshape(this, shape);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.resizeBilinear = function (newShape2D, alignCorners, halfPixelCenters) {
this.throwIfDisposed();
return resizeBilinear(this, newShape2D, alignCorners, halfPixelCenters);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.resizeNearestNeighbor = function (newShape2D, alignCorners, halfFloatCenters) {
this.throwIfDisposed();
return resizeNearestNeighbor(this, newShape2D, alignCorners, halfFloatCenters);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.reverse = function (axis) {
this.throwIfDisposed();
return reverse(this, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.rfft = function () {
this.throwIfDisposed();
return rfft(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.round = function () {
this.throwIfDisposed();
return round$1(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.rsqrt = function () {
this.throwIfDisposed();
return rsqrt(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.selu = function () {
this.throwIfDisposed();
return selu(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.separableConv2d = function (depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) {
this.throwIfDisposed();
return separableConv2d(this, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.sigmoid = function () {
this.throwIfDisposed();
return sigmoid(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.sign = function () {
this.throwIfDisposed();
return sign(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.sin = function () {
this.throwIfDisposed();
return sin(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.sinh = function () {
this.throwIfDisposed();
return sinh(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.slice = function (begin, size) {
this.throwIfDisposed();
return slice$2(this, begin, size);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.softmax = function (dim) {
this.throwIfDisposed();
return softmax(this, dim);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.softplus = function () {
this.throwIfDisposed();
return softplus(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.spaceToBatchND = function (blockShape, paddings) {
this.throwIfDisposed();
return spaceToBatchND(this, blockShape, paddings);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.split = function (numOrSizeSplits, axis) {
this.throwIfDisposed();
return split$1(this, numOrSizeSplits, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.sqrt = function () {
this.throwIfDisposed();
return sqrt$3(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.square = function () {
this.throwIfDisposed();
return square(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.squaredDifference = function (b) {
this.throwIfDisposed();
return squaredDifference(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.squeeze = function (axis) {
this.throwIfDisposed();
return squeeze(this, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.stack = function (x, axis) {
this.throwIfDisposed();
var tensorsToBeStacked = x instanceof Tensor ? [this, x] : [this].concat(x);
return stack(tensorsToBeStacked, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.step = function (alpha) {
this.throwIfDisposed();
return step(this, alpha);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.stridedSlice = function (begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
this.throwIfDisposed();
return stridedSlice(this, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.sub = function (b) {
this.throwIfDisposed();
return sub(this, b);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.sum = function (axis, keepDims) {
this.throwIfDisposed();
return sum$1(this, axis, keepDims);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.tan = function () {
this.throwIfDisposed();
return tan(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.tanh = function () {
this.throwIfDisposed();
return tanh$1(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.tile = function (reps) {
this.throwIfDisposed();
return tile(this, reps);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Casts the array to type `bool`
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.toBool = function () {
this.throwIfDisposed();
return cast(this, 'bool');
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Casts the array to type `float32`
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.toFloat = function () {
this.throwIfDisposed();
return cast(this, 'float32');
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Casts the array to type `int32`
*
* @doc {heading: 'Tensors', subheading: 'Classes'}
*/
getGlobalTensorClass().prototype.toInt = function () {
this.throwIfDisposed();
return cast(this, 'int32');
};
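// Illustrative usage of the cast helpers above (a sketch; assumes the bundled
// `tf` namespace and a registered backend):
//   const x = tf.tensor1d([0, 1, 2]);
//   x.toInt();    // dtype 'int32'
//   x.toFloat();  // dtype 'float32'
//   x.toBool();   // dtype 'bool' (zero -> false, non-zero -> true)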
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.topk = function (k, sorted) {
this.throwIfDisposed();
return topk(this, k, sorted);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.transpose = function (perm) {
this.throwIfDisposed();
return transpose(this, perm);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.unique = function (axis) {
this.throwIfDisposed();
return unique(this, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.unsortedSegmentSum = function (segmentIds, numSegments) {
this.throwIfDisposed();
return unsortedSegmentSum(this, segmentIds, numSegments);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.unstack = function (axis) {
this.throwIfDisposed();
return unstack(this, axis);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.where = function (condition, x) {
this.throwIfDisposed();
return where(condition, this, x);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
getGlobalTensorClass().prototype.zerosLike = function () {
this.throwIfDisposed();
return zerosLike(this);
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var _epsilon;
/**
* Returns the value of the fuzz factor used in numeric expressions.
*/
function epsilon() {
if (_epsilon == null) {
_epsilon = backend().epsilon();
}
return _epsilon;
}
/**
* Sets the value of the fuzz factor used in numeric expressions.
* @param e New value of epsilon.
*/
function setEpsilon(e) {
_epsilon = e;
}
/**
* Returns the default image data format convention.
*/
function imageDataFormat() {
return 'channelsLast';
}
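// Illustrative usage (a sketch; the exact fuzz factor depends on the active
// tfjs-core backend):
//   const eps = epsilon();   // e.g. 1e-7 on float32 backends (cached after first call)
//   setEpsilon(1e-8);        // override the cached value
//   imageDataFormat();       // always 'channelsLast' in TensorFlow.js Layers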
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Explicit error types.
*
* See the following link for more information about why the code includes
* calls to setPrototypeOf:
*
* https://github.com/Microsoft/TypeScript-wiki/blob/master/Breaking-Changes.md#extending-built-ins-like-error-array-and-map-may-no-longer-work
*/
// tslint:enable
/**
* Equivalent of Python's AttributeError.
*/
var AttributeError = /*#__PURE__*/function (_Error) {
_inheritsLoose(AttributeError, _Error);
function AttributeError(message) {
var _this;
_this = _Error.call(this, message) || this; // Set the prototype explicitly.
Object.setPrototypeOf(_assertThisInitialized(_this), AttributeError.prototype);
return _this;
}
return AttributeError;
}( /*#__PURE__*/_wrapNativeSuper(Error));
/**
* Equivalent of Python's RuntimeError.
*/
var RuntimeError = /*#__PURE__*/function (_Error2) {
_inheritsLoose(RuntimeError, _Error2);
function RuntimeError(message) {
var _this2;
_this2 = _Error2.call(this, message) || this; // Set the prototype explicitly.
Object.setPrototypeOf(_assertThisInitialized(_this2), RuntimeError.prototype);
return _this2;
}
return RuntimeError;
}( /*#__PURE__*/_wrapNativeSuper(Error));
/**
* Equivalent of Python's ValueError.
*/
var ValueError = /*#__PURE__*/function (_Error3) {
_inheritsLoose(ValueError, _Error3);
function ValueError(message) {
var _this3;
_this3 = _Error3.call(this, message) || this; // Set the prototype explicitly.
Object.setPrototypeOf(_assertThisInitialized(_this3), ValueError.prototype);
return _this3;
}
return ValueError;
}( /*#__PURE__*/_wrapNativeSuper(Error));
/**
* Equivalent of Python's NotImplementedError.
*/
var NotImplementedError = /*#__PURE__*/function (_Error4) {
_inheritsLoose(NotImplementedError, _Error4);
function NotImplementedError(message) {
var _this4;
_this4 = _Error4.call(this, message) || this; // Set the prototype explicitly.
Object.setPrototypeOf(_assertThisInitialized(_this4), NotImplementedError.prototype);
return _this4;
}
return NotImplementedError;
}( /*#__PURE__*/_wrapNativeSuper(Error));
/**
* Equivalent of Python's AssertionError.
*/
var AssertionError = /*#__PURE__*/function (_Error5) {
_inheritsLoose(AssertionError, _Error5);
function AssertionError(message) {
var _this5;
_this5 = _Error5.call(this, message) || this; // Set the prototype explicitly.
Object.setPrototypeOf(_assertThisInitialized(_this5), AssertionError.prototype);
return _this5;
}
return AssertionError;
}( /*#__PURE__*/_wrapNativeSuper(Error));
/**
* Equivalent of Python's IndexError.
*/
var IndexError = /*#__PURE__*/function (_Error6) {
_inheritsLoose(IndexError, _Error6);
function IndexError(message) {
var _this6;
_this6 = _Error6.call(this, message) || this; // Set the prototype explicitly.
Object.setPrototypeOf(_assertThisInitialized(_this6), IndexError.prototype);
return _this6;
}
return IndexError;
}( /*#__PURE__*/_wrapNativeSuper(Error));
/**
* If `value` is an Array, equivalent to Python's `value * numValues`.
* If `value` is not an Array, equivalent to Python's `[value] * numValues`
*/
// tslint:disable-next-line:no-any
function pyListRepeat(value, numValues) {
if (Array.isArray(value)) {
// tslint:disable-next-line:no-any
var newArray = [];
for (var i = 0; i < numValues; i++) {
newArray = newArray.concat(value);
}
return newArray;
} else {
var _newArray = new Array(numValues);
_newArray.fill(value);
return _newArray;
}
}
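// Illustrative behavior of `pyListRepeat` (a sketch mirroring Python semantics):
//   pyListRepeat([1, 2], 3);  // -> [1, 2, 1, 2, 1, 2]  (like [1, 2] * 3)
//   pyListRepeat(7, 3);       // -> [7, 7, 7]           (like [7] * 3)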
function assert$1(val, message) {
if (!val) {
throw new AssertionError(message);
}
}
/**
* Count the number of elements of the `array` that are equal to `reference`.
*/
function count(array, reference) {
var counter = 0;
for (var _iterator = _createForOfIteratorHelperLoose(array), _step; !(_step = _iterator()).done;) {
var item = _step.value;
if (item === reference) {
counter++;
}
}
return counter;
}
/**
* If an array is of length 1, just return the first element. Otherwise, return
* the full array.
* @param xs The input array.
*/
function singletonOrArray(xs) {
if (xs.length === 1) {
return xs[0];
}
return xs;
}
/**
* Normalizes a list/tensor into a list.
*
* If a tensor is passed, we return
* a list of size 1 containing the tensor.
*
* @param x target object to be normalized.
*/
// tslint:disable-next-line:no-any
function toList(x) {
if (Array.isArray(x)) {
return x;
}
return [x];
}
/**
* Generate a UID for a list
*/
// tslint:disable-next-line:no-any
function objectListUid(objs) {
var objectList = toList(objs);
var retVal = '';
for (var _iterator2 = _createForOfIteratorHelperLoose(objectList), _step2; !(_step2 = _iterator2()).done;) {
var obj = _step2.value;
if (obj.id == null) {
throw new ValueError("Object " + obj + " passed to objectListUid without an id");
}
if (retVal !== '') {
retVal = retVal + ', ';
}
retVal = "" + retVal + Math.abs(obj.id);
}
return retVal;
}
/**
* Converts string to snake-case.
* @param name
*/
function toSnakeCase(name) {
var intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');
var insecure = intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();
/*
If the class is private, the name starts with "_", which is not suitable
for creating scopes. We prefix the name with "private" in this case.
*/
if (insecure[0] !== '_') {
return insecure;
}
return 'private' + insecure;
}
function toCamelCase(identifier) {
// quick return for empty string or single character strings
if (identifier.length <= 1) {
return identifier;
} // Check for the underscore indicating snake_case
if (identifier.indexOf('_') === -1) {
return identifier;
}
return identifier.replace(/[_]+(\w|$)/g, function (m, p1) {
return p1.toUpperCase();
});
} // tslint:disable-next-line:no-any
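// Illustrative behavior of the case-conversion helpers above (a sketch):
//   toSnakeCase('imageDataFormat');    // -> 'image_data_format'
//   toCamelCase('image_data_format');  // -> 'imageDataFormat'
//   toCamelCase('alreadyCamel');       // -> 'alreadyCamel' (no underscores, returned as-is)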
var _GLOBAL_CUSTOM_OBJECTS = {};
function serializeKerasObject(instance) {
if (instance === null || instance === undefined) {
return null;
}
var dict = {};
dict['className'] = instance.getClassName();
dict['config'] = instance.getConfig();
return dict;
}
/**
* Replace ndarray-style scalar objects in serialization objects with numbers.
*
* Background: In some versions of tf.keras, certain scalar values in the HDF5
* model save file can be serialized as: `{'type': 'ndarray', 'value': num}`,
* where `num` is a plain number. This method converts such a serialization
* to a `number`.
*
* @param config The keras-format serialization object to be processed
* (in place).
*/
function convertNDArrayScalarsInConfig(config) {
if (config == null || typeof config !== 'object') {
return;
} else if (Array.isArray(config)) {
config.forEach(function (configItem) {
return convertNDArrayScalarsInConfig(configItem);
});
} else {
var fields = Object.keys(config);
for (var _i = 0, _fields = fields; _i < _fields.length; _i++) {
var field = _fields[_i];
var value = config[field];
if (value != null && typeof value === 'object') {
if (!Array.isArray(value) && value['type'] === 'ndarray' && typeof value['value'] === 'number') {
config[field] = value['value'];
} else {
convertNDArrayScalarsInConfig(value);
}
}
}
}
}
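// Illustrative effect of `convertNDArrayScalarsInConfig` (a sketch; the config
// object is modified in place):
//   const cfg = {units: {type: 'ndarray', value: 32}, activation: 'relu'};
//   convertNDArrayScalarsInConfig(cfg);
//   // cfg is now {units: 32, activation: 'relu'}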
/**
* Deserialize a saved Keras Object
* @param identifier either a string ID or a saved Keras dictionary
* @param moduleObjects a map from Python class names to object constructors
* @param customObjects a map from Python class names to object constructors
* @param printableModuleName debug text for the object being reconstituted
* @param fastWeightInit Optional flag to use fast weight initialization
* during deserialization. This is applicable to cases in which
* the initialization will be immediately overwritten by loaded weight
* values. Default: `false`.
* @returns a TensorFlow.js Layers object
*/
// tslint:disable:no-any
function deserializeKerasObject(identifier, moduleObjects, customObjects, printableModuleName, fastWeightInit) {
if (moduleObjects === void 0) {
moduleObjects = {};
}
if (customObjects === void 0) {
customObjects = {};
}
if (printableModuleName === void 0) {
printableModuleName = 'object';
}
if (fastWeightInit === void 0) {
fastWeightInit = false;
}
// tslint:enable
if (typeof identifier === 'string') {
var functionName = identifier;
var fn;
if (functionName in customObjects) {
fn = customObjects[functionName];
} else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {
fn = _GLOBAL_CUSTOM_OBJECTS[functionName];
} else {
fn = moduleObjects[functionName];
if (fn == null) {
throw new ValueError("Unknown " + printableModuleName + ": " + identifier + ". " + "This may be due to one of the following reasons:\n" + ("1. The " + printableModuleName + " is defined in Python, in which ") + "case it needs to be ported to TensorFlow.js or your JavaScript " + "code.\n" + ("2. The custom " + printableModuleName + " is defined in JavaScript, ") + "but is not registered properly with " + "tf.serialization.registerClass()."); // TODO(cais): Add link to tutorial page on custom layers.
}
}
return fn;
} else {
// In this case we are dealing with a Keras config dictionary.
var config = identifier;
if (config['className'] == null || config['config'] == null) {
throw new ValueError(printableModuleName + ": Improper config format: " + (JSON.stringify(config) + ".\n") + "'className' and 'config' must be set.");
}
var className = config['className'];
var cls, fromConfig;
if (className in customObjects) {
var _customObjects$classN = customObjects[className];
cls = _customObjects$classN[0];
fromConfig = _customObjects$classN[1];
} else if (className in _GLOBAL_CUSTOM_OBJECTS) {
var _GLOBAL_CUSTOM_OBJECT = _GLOBAL_CUSTOM_OBJECTS[className];
cls = _GLOBAL_CUSTOM_OBJECT[0];
fromConfig = _GLOBAL_CUSTOM_OBJECT[1];
} else if (className in moduleObjects) {
var _moduleObjects$classN = moduleObjects[className];
cls = _moduleObjects$classN[0];
fromConfig = _moduleObjects$classN[1];
}
if (cls == null) {
throw new ValueError("Unknown " + printableModuleName + ": " + className + ". " + "This may be due to one of the following reasons:\n" + ("1. The " + printableModuleName + " is defined in Python, in which ") + "case it needs to be ported to TensorFlow.js or your JavaScript " + "code.\n" + ("2. The custom " + printableModuleName + " is defined in JavaScript, ") + "but is not registered properly with " + "tf.serialization.registerClass()."); // TODO(cais): Add link to tutorial page on custom layers.
}
if (fromConfig != null) {
// Porting notes: Instead of checking to see whether fromConfig accepts
// customObjects, we create a customObjects dictionary and tack it on to
// config['config'] as config['config'].customObjects. Objects can use it,
// if they want.
// tslint:disable-next-line:no-any
var customObjectsCombined = {};
for (var _i2 = 0, _Object$keys = Object.keys(_GLOBAL_CUSTOM_OBJECTS); _i2 < _Object$keys.length; _i2++) {
var key = _Object$keys[_i2];
customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];
}
for (var _i3 = 0, _Object$keys2 = Object.keys(customObjects); _i3 < _Object$keys2.length; _i3++) {
var _key = _Object$keys2[_i3];
customObjectsCombined[_key] = customObjects[_key];
} // Add the customObjects to config
var nestedConfig = config['config'];
nestedConfig['customObjects'] = customObjectsCombined;
var backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
for (var _i4 = 0, _Object$keys3 = Object.keys(customObjects); _i4 < _Object$keys3.length; _i4++) {
var _key2 = _Object$keys3[_i4];
_GLOBAL_CUSTOM_OBJECTS[_key2] = customObjects[_key2];
}
convertNDArrayScalarsInConfig(config['config']);
var returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit);
_GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
return returnObj;
} else {
// Then `cls` may be a function returning a class.
// In this case by convention `config` holds
// the kwargs of the function.
var _backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
for (var _i5 = 0, _Object$keys4 = Object.keys(customObjects); _i5 < _Object$keys4.length; _i5++) {
var _key3 = _Object$keys4[_i5];
_GLOBAL_CUSTOM_OBJECTS[_key3] = customObjects[_key3];
} // In python this is **config['config'], for tfjs-layers we require
// classes that use this fall-through construction method to take
// a config interface that mimics the expansion of named parameters.
var _returnObj = new cls(config['config']);
_GLOBAL_CUSTOM_OBJECTS = Object.assign({}, _backupCustomObjects);
return _returnObj;
}
}
}
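// Illustrative usage of `deserializeKerasObject` (a sketch; `myActivationFn` is a
// hypothetical registered function):
//   deserializeKerasObject('myAct', {}, {myAct: myActivationFn});  // -> myActivationFn
//   // A {className, config} dictionary is instead resolved to a
//   // [constructor, fromConfig] tuple and instantiated via fromConfig
//   // (or via `new cls(config.config)` when no fromConfig is registered).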
/**
* Compares two numbers for sorting.
* @param a
* @param b
*/
function numberCompare(a, b) {
return a < b ? -1 : a > b ? 1 : 0;
}
/**
* Comparison of two numbers for reverse sorting.
* @param a
* @param b
*/
function reverseNumberCompare(a, b) {
return -1 * numberCompare(a, b);
}
/**
* Convert a string into the corresponding DType.
* @param dtype
* @returns An instance of DType.
*/
function stringToDType(dtype) {
switch (dtype) {
case 'float32':
return 'float32';
default:
throw new ValueError("Invalid dtype: " + dtype);
}
}
/**
* Test the element-by-element equality of two Arrays of strings.
* @param xs First array of strings.
* @param ys Second array of strings.
* @returns Whether the two arrays are equal, element by element.
*/
function stringsEqual(xs, ys) {
if (xs == null || ys == null) {
return xs === ys;
}
if (xs.length !== ys.length) {
return false;
}
for (var i = 0; i < xs.length; ++i) {
if (xs[i] !== ys[i]) {
return false;
}
}
return true;
}
/**
* Get the unique elements of an array.
* @param xs Array.
* @returns An Array consisting of the unique elements in `xs`.
*/
function unique$1(xs) {
if (xs == null) {
return xs;
}
var out = []; // TODO(cais): Maybe improve performance by sorting.
for (var _iterator3 = _createForOfIteratorHelperLoose(xs), _step3; !(_step3 = _iterator3()).done;) {
var x = _step3.value;
if (out.indexOf(x) === -1) {
out.push(x);
}
}
return out;
}
/**
* Determine if an Object is empty (i.e., does not have own properties).
* @param obj Object
* @returns Whether the Object is empty.
* @throws ValueError: If object is `null` or `undefined`.
*/
function isObjectEmpty(obj) {
if (obj == null) {
throw new ValueError("Invalid value in obj: " + JSON.stringify(obj));
}
for (var key in obj) {
if (obj.hasOwnProperty(key)) {
return false;
}
}
return true;
}
/**
* Helper function used to build type union/enum run-time checkers.
* @param values The list of allowed values.
* @param label A string name for the type
* @param value The value to test.
* @throws ValueError: If the value is neither in `values` nor `undefined`/`null`.
*/
function checkStringTypeUnionValue(values, label, value) {
if (value == null) {
return;
}
if (values.indexOf(value) < 0) {
throw new ValueError(value + " is not a valid " + label + ". Valid values are " + values + " or null/undefined.");
}
}
/**
* Helper function for verifying the types of inputs.
*
* Ensures that the elements of `x` are all of type `expectedType`.
* Also verifies that the length of `x` is within bounds.
*
* @param x Object to test.
* @param expectedType The string expected type of all of the elements in the
* Array.
* @param minLength Return false if x.length is less than this.
* @param maxLength Return false if x.length is greater than this.
* @returns true if and only if `x` is an `Array<expectedType>` with
* length >= `minLength` and <= `maxLength`.
*/
// tslint:disable:no-any
function checkArrayTypeAndLength(x, expectedType, minLength, maxLength) {
if (minLength === void 0) {
minLength = 0;
}
if (maxLength === void 0) {
maxLength = Infinity;
}
assert$1(minLength >= 0);
assert$1(maxLength >= minLength);
return Array.isArray(x) && x.length >= minLength && x.length <= maxLength && x.every(function (e) {
return typeof e === expectedType;
});
} // tslint:enable:no-any
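// Illustrative behavior of `checkArrayTypeAndLength` (a sketch):
//   checkArrayTypeAndLength([1, 2, 3], 'number', 1, 5);  // -> true
//   checkArrayTypeAndLength([1, 'a'], 'number');         // -> false (mixed element types)
//   checkArrayTypeAndLength([], 'number', 1);            // -> false (shorter than minLength)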
/**
* Assert that a value or an array of value are positive integer.
*
* @param value The value being asserted on. May be a single number or an array
* of numbers.
* @param name Name of the value, used to make the error message.
*/
function assertPositiveInteger(value, name) {
if (Array.isArray(value)) {
assert(value.length > 0, function () {
return name + " is unexpectedly an empty array.";
});
value.forEach(function (v, i) {
return assertPositiveInteger(v, "element " + (i + 1) + " of " + name);
});
} else {
assert(Number.isInteger(value) && value > 0, function () {
return "Expected " + name + " to be a positive integer, but got " + (formatAsFriendlyString(value) + ".");
});
}
}
/**
* Format a value into a display-friendly, human-readable fashion.
*
* - `null` is formatted as `'null'`
* - Strings are formatted with a flanking pair of quotes.
* - Arrays are formatted with a flanking pair of square brackets.
*
* @param value The value to display.
* @return Formatted string.
*/
// tslint:disable-next-line:no-any
function formatAsFriendlyString(value) {
if (value === null) {
return 'null';
} else if (Array.isArray(value)) {
return '[' + value.map(function (v) {
return formatAsFriendlyString(v);
}).join(',') + ']';
} else if (typeof value === 'string') {
return "\"" + value + "\"";
} else {
return "" + value;
}
}
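// Illustrative output of `formatAsFriendlyString` (a sketch):
//   formatAsFriendlyString(null);            // -> 'null'
//   formatAsFriendlyString('same');          // -> '"same"'
//   formatAsFriendlyString([1, 'a', null]);  // -> '[1,"a",null]'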
/**
* Returns a function `f2` (decorator) which wraps the original function
* `f`. `f2` guarantees that `f` can be called at most once
* every `waitMs` ms. If `f2` is called more often, it will return
* the last returned result of `f`.
*
* @param f The original function `f` to wrap.
* @param waitMs The time between two consecutive calls to `f` in ms.
*/
function debounce(f, waitMs) {
var lastTime = now();
var lastResult;
var f2 = function f2() {
var now$1 = now();
if (now$1 - lastTime < waitMs) {
return lastResult;
}
lastTime = now$1;
lastResult = f.apply(void 0, arguments);
return lastResult;
};
return f2;
}
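// Illustrative usage of `debounce` (a sketch; `expensiveUpdate` is hypothetical):
//   const throttledUpdate = debounce(expensiveUpdate, 1000);
//   throttledUpdate();  // within 1000 ms of the previous call (or of creation),
//   throttledUpdate();  // the cached result is returned and `expensiveUpdate` is skipped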
/**
* Returns the fusable activation given a layers identifier.
*
* @param activationName The layers identifier string.
* @return The name of the fusable activation.
*/
function mapActivationToFusedKernel(activationName) {
if (activationName === 'relu') {
return 'relu';
}
if (activationName === 'linear') {
return 'linear';
}
if (activationName === 'elu') {
return 'elu';
}
return null;
}
/**
* Returns the cartesian product of sets of values.
* This works the same as itertools.product in Python.
*
* Example:
*
* filters = [128, 256, 512]
* paddings = ['same', 'valid']
*
* product = [ [128, 'same'], [128, 'valid'], [256, 'same'], [256, 'valid'],
* [512, 'same'], [512, 'valid']]
*
* @param arrayOfValues List/array of values.
* @return The cartesian product.
*/
function getCartesianProductOfValues() {
for (var _len = arguments.length, arrayOfValues = new Array(_len), _key4 = 0; _key4 < _len; _key4++) {
arrayOfValues[_key4] = arguments[_key4];
}
assert$1(arrayOfValues.length > 0, 'arrayOfValues is empty');
for (var _i6 = 0, _arrayOfValues = arrayOfValues; _i6 < _arrayOfValues.length; _i6++) {
var values = _arrayOfValues[_i6];
assert$1(Array.isArray(values), 'one of the values is not an array');
assert$1(values.length > 0, 'one of the values is empty');
}
return arrayOfValues.reduce(function (products, values) {
if (products.length === 0) {
return values.map(function (value) {
return [value];
});
}
return values.map(function (value) {
return products.map(function (prevValue) {
return [].concat(prevValue, [value]);
});
}).reduce(function (flattenedProduct, unflattenedProduct) {
return flattenedProduct.concat(unflattenedProduct);
}, []);
}, []);
}
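// Illustrative sketch of getCartesianProductOfValues() above (comment only):
//
//   getCartesianProductOfValues([128, 256], ['same', 'valid']);
//   // -> the four pairs [128, 'same'], [128, 'valid'], [256, 'same'] and
//   //    [256, 'valid'] (the exact ordering follows the reduce above).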
/**
* Helper function used by many of the Constraints to find the L2Norms.
*/
function calcL2Norms(w, axis) {
return tidy(function () {
return sqrt$3(sum$1(mul(w, w), axis, true));
});
}
/**
* Base class for functions that impose constraints on weight values
*
* @doc {
* heading: 'Constraints',
* subheading: 'Classes',
* namespace: 'constraints'
* }
*/
var Constraint = /*#__PURE__*/function (_serialization$Serial) {
_inheritsLoose(Constraint, _serialization$Serial);
function Constraint() {
return _serialization$Serial.apply(this, arguments) || this;
}
var _proto = Constraint.prototype;
_proto.getConfig = function getConfig() {
return {};
};
return Constraint;
}(Serializable);
var MaxNorm = /*#__PURE__*/function (_Constraint) {
_inheritsLoose(MaxNorm, _Constraint);
function MaxNorm(args) {
var _this;
_this = _Constraint.call(this) || this;
_this.defaultMaxValue = 2;
_this.defaultAxis = 0;
_this.maxValue = args.maxValue != null ? args.maxValue : _this.defaultMaxValue;
_this.axis = args.axis != null ? args.axis : _this.defaultAxis;
return _this;
}
var _proto2 = MaxNorm.prototype;
_proto2.apply = function apply(w) {
var _this2 = this;
return tidy(function () {
var norms = calcL2Norms(w, _this2.axis);
var desired = clipByValue(norms, 0, _this2.maxValue);
return mul(w, div(desired, add$1(epsilon(), norms)));
});
};
_proto2.getConfig = function getConfig() {
return {
maxValue: this.maxValue,
axis: this.axis
};
};
return MaxNorm;
}(Constraint);
/** @nocollapse */
MaxNorm.className = 'MaxNorm';
registerClass(MaxNorm);
var UnitNorm = /*#__PURE__*/function (_Constraint2) {
_inheritsLoose(UnitNorm, _Constraint2);
function UnitNorm(args) {
var _this3;
_this3 = _Constraint2.call(this) || this;
_this3.defaultAxis = 0;
_this3.axis = args.axis != null ? args.axis : _this3.defaultAxis;
return _this3;
}
var _proto3 = UnitNorm.prototype;
_proto3.apply = function apply(w) {
var _this4 = this;
return tidy(function () {
return div(w, add$1(epsilon(), calcL2Norms(w, _this4.axis)));
});
};
_proto3.getConfig = function getConfig() {
return {
axis: this.axis
};
};
return UnitNorm;
}(Constraint);
/** @nocollapse */
UnitNorm.className = 'UnitNorm';
registerClass(UnitNorm);
var NonNeg = /*#__PURE__*/function (_Constraint3) {
_inheritsLoose(NonNeg, _Constraint3);
function NonNeg() {
return _Constraint3.apply(this, arguments) || this;
}
var _proto4 = NonNeg.prototype;
_proto4.apply = function apply(w) {
return relu(w);
};
return NonNeg;
}(Constraint);
/** @nocollapse */
NonNeg.className = 'NonNeg';
registerClass(NonNeg);
var MinMaxNorm = /*#__PURE__*/function (_Constraint4) {
_inheritsLoose(MinMaxNorm, _Constraint4);
function MinMaxNorm(args) {
var _this5;
_this5 = _Constraint4.call(this) || this;
_this5.defaultMinValue = 0.0;
_this5.defaultMaxValue = 1.0;
_this5.defaultRate = 1.0;
_this5.defaultAxis = 0;
_this5.minValue = args.minValue != null ? args.minValue : _this5.defaultMinValue;
_this5.maxValue = args.maxValue != null ? args.maxValue : _this5.defaultMaxValue;
_this5.rate = args.rate != null ? args.rate : _this5.defaultRate;
_this5.axis = args.axis != null ? args.axis : _this5.defaultAxis;
return _this5;
}
var _proto5 = MinMaxNorm.prototype;
_proto5.apply = function apply(w) {
var _this6 = this;
return tidy(function () {
var norms = calcL2Norms(w, _this6.axis);
var desired = add$1(mul(_this6.rate, clipByValue(norms, _this6.minValue, _this6.maxValue)), mul(1.0 - _this6.rate, norms));
return mul(w, div(desired, add$1(epsilon(), norms)));
});
};
_proto5.getConfig = function getConfig() {
return {
minValue: this.minValue,
maxValue: this.maxValue,
rate: this.rate,
axis: this.axis
};
};
return MinMaxNorm;
}(Constraint);
/** @nocollapse */
MinMaxNorm.className = 'MinMaxNorm';
registerClass(MinMaxNorm); // Maps the JavaScript-like identifier keys to the corresponding registry
// symbols.
var CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
'maxNorm': 'MaxNorm',
'minMaxNorm': 'MinMaxNorm',
'nonNeg': 'NonNeg',
'unitNorm': 'UnitNorm'
};
function serializeConstraint(constraint) {
return serializeKerasObject(constraint);
}
function deserializeConstraint(config, customObjects) {
if (customObjects === void 0) {
customObjects = {};
}
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'constraint');
}
function getConstraint(identifier) {
if (identifier == null) {
return null;
}
if (typeof identifier === 'string') {
var className = identifier in CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP ? CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier;
var config = {
className: className,
config: {}
};
return deserializeConstraint(config);
} else if (identifier instanceof Constraint) {
return identifier;
} else {
return deserializeConstraint(identifier);
}
}
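// Illustrative sketch of getConstraint() above (comment only): a string
// identifier is mapped through the registry table and deserialized, while an
// existing Constraint instance is returned unchanged.
//
//   getConstraint('nonNeg');                    // -> a NonNeg instance
//   getConstraint(new MaxNorm({maxValue: 3}));  // -> that same MaxNorm instance
//   getConstraint(null);                        // -> null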
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* MaxNorm weight constraint.
*
* Constrains the weights incident to each hidden unit
* to have a norm less than or equal to a desired value.
*
* References
* - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting
* Srivastava, Hinton, et al.
* 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
*
* @doc {heading: 'Constraints',namespace: 'constraints'}
*/
function maxNorm(args) {
return new MaxNorm(args);
}
/**
* Constrains the weights incident to each hidden unit to have unit norm.
*
* @doc {heading: 'Constraints', namespace: 'constraints'}
*/
function unitNorm(args) {
return new UnitNorm(args);
}
/**
* Constrains the weights to be non-negative.
*
* @doc {heading: 'Constraints', namespace: 'constraints'}
*/
function nonNeg() {
return new NonNeg();
}
/** @doc {heading: 'Constraints', namespace: 'constraints'} */
function minMaxNorm(config) {
return new MinMaxNorm(config);
}
var exports_constraints = {
__proto__: null,
maxNorm: maxNorm,
unitNorm: unitNorm,
nonNeg: nonNeg,
minMaxNorm: minMaxNorm
};
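// Illustrative sketch of the constraint factories exported above (comment
// only; assumes the bundled `tf` export, and the dense-layer config is a
// hypothetical example):
//
//   var layer = tf.layers.dense({
//     units: 8,
//     kernelConstraint: tf.constraints.maxNorm({maxValue: 2, axis: 0})
//   });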
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast'];
var VALID_INTERPOLATION_FORMAT_VALUES = ['nearest', 'bilinear'];
var VALID_PADDING_MODE_VALUES = ['valid', 'same', 'causal'];
var VALID_POOL_MODE_VALUES = ['max', 'avg'];
var VALID_BIDIRECTIONAL_MERGE_MODES = ['sum', 'mul', 'concat', 'ave'];
var VALID_SAMPLE_WEIGHT_MODES = ['temporal'];
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// A map from the requested scoped name of a Tensor to the number of Tensors
// wanting that name so far. This allows enforcing name uniqueness by appending
// an incrementing index, e.g. scope/name, scope/name_1, scope/name_2, etc.
var nameMap = new Map();
function checkDataFormat(value) {
checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);
}
function checkInterpolationFormat(value) {
checkStringTypeUnionValue(VALID_INTERPOLATION_FORMAT_VALUES, 'InterpolationFormat', value);
}
function checkPaddingMode(value) {
checkStringTypeUnionValue(VALID_PADDING_MODE_VALUES, 'PaddingMode', value);
}
function checkPoolMode(value) {
checkStringTypeUnionValue(VALID_POOL_MODE_VALUES, 'PoolMode', value);
}
var _nameScopeStack = [];
var _nameScopeDivider = '/';
/**
* Enter namescope, which can be nested.
*/
function nameScope(name, fn) {
_nameScopeStack.push(name);
try {
var val = fn();
_nameScopeStack.pop();
return val;
} catch (e) {
_nameScopeStack.pop();
throw e;
}
}
/**
* Get the current namescope as a flat, concatenated string.
*/
function currentNameScopePrefix() {
if (_nameScopeStack.length === 0) {
return '';
} else {
return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
}
/**
* Get the name a Tensor (or Variable) would have if it were not made unique.
* @param tensorName
* @return Scoped name string.
*/
function getScopedTensorName(tensorName) {
if (!isValidTensorName(tensorName)) {
throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
}
return currentNameScopePrefix() + tensorName;
}
/**
* Get unique names for Tensors and Variables.
* @param scopedName The fully-qualified name of the Tensor, i.e. as produced by
* `getScopedTensorName()`.
* @return A unique version of the given fully scoped name.
* If this is the first time that the scoped name is seen in this session,
* then the given `scopedName` is returned unaltered. If the same name is
* seen again (producing a collision), an incrementing suffix is added to the
* end of the name, so it takes the form 'scope/name_1', 'scope/name_2', etc.
*/
function getUniqueTensorName(scopedName) {
if (!isValidTensorName(scopedName)) {
throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
}
if (!nameMap.has(scopedName)) {
nameMap.set(scopedName, 0);
}
var index = nameMap.get(scopedName);
nameMap.set(scopedName, nameMap.get(scopedName) + 1);
if (index > 0) {
var result = scopedName + "_" + index; // Mark the composed name as used in case someone wants
// to call getUniqueTensorName("name_1").
nameMap.set(result, 1);
return result;
} else {
return scopedName;
}
}
var tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);
/**
* Determine whether a string is a valid tensor name.
* @param name
* @returns A Boolean indicating whether `name` is a valid tensor name.
*/
function isValidTensorName(name) {
return !!name.match(tensorNameRegex);
}
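// Illustrative sketch of the name-scoping helpers above (comment only; the
// scope and tensor names are hypothetical):
//
//   nameScope('block1', function () {
//     var scoped = getScopedTensorName('kernel');  // 'block1/kernel'
//     getUniqueTensorName(scoped);                 // 'block1/kernel'
//     getUniqueTensorName(scoped);                 // 'block1/kernel_1'
//   });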
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Determine if a number is an integer.
*/
function isInteger$1(x) {
return x === parseInt(x.toString(), 10);
}
/**
* Calculate the product of an array of numbers.
* @param array The array to calculate the product over.
* @param begin Beginning index, inclusive.
* @param end Ending index, exclusive.
* @return The product.
*/
function arrayProd(array, begin, end) {
if (begin == null) {
begin = 0;
}
if (end == null) {
end = array.length;
}
var prod = 1;
for (var i = begin; i < end; ++i) {
prod *= array[i];
}
return prod;
}
/**
* Compute minimum value.
* @param array
* @return minimum value.
*/
function min$a(array) {
// same behavior as tf.min()
if (array.length === 0) {
return Number.NaN;
}
var min = Number.POSITIVE_INFINITY;
for (var i = 0; i < array.length; i++) {
var value = array[i];
if (value < min) {
min = value;
}
}
return min;
}
/**
* Compute maximum value.
* @param array
* @return maximum value
*/
function max$6(array) {
// same behavior as tf.max()
if (array.length === 0) {
return Number.NaN;
}
var max = Number.NEGATIVE_INFINITY;
for (var i = 0; i < array.length; i++) {
var value = array[i];
if (value > max) {
max = value;
}
}
return max;
}
/**
* Compute sum of array.
* @param array
* @return The sum.
*/
function sum$2(array) {
var sum = 0;
for (var i = 0; i < array.length; i++) {
var value = array[i];
sum += value;
}
return sum;
}
/**
* Compute mean of array.
* @param array
* @return The mean.
*/
function mean$2(array) {
return sum$2(array) / array.length;
}
/**
* Compute variance of array.
* @param array
* @return The variance.
*/
function variance(array) {
var meanValue = mean$2(array);
var demeaned = array.map(function (value) {
return value - meanValue;
});
var sumSquare = 0;
for (var i = 0; i < demeaned.length; i++) {
var value = demeaned[i];
sumSquare += value * value;
}
return sumSquare / array.length;
}
/**
* Compute median of array.
* @param array
* @return The median value.
*/
function median(array) {
var arraySorted = array.slice().sort(function (a, b) {
return a - b;
});
var lowIdx = Math.floor((arraySorted.length - 1) / 2);
var highIdx = Math.ceil((arraySorted.length - 1) / 2);
if (lowIdx === highIdx) {
return arraySorted[lowIdx];
}
return (arraySorted[lowIdx] + arraySorted[highIdx]) / 2;
}
/**
* Generate an array of integers in [begin, end).
* @param begin Beginning integer, inclusive.
* @param end Ending integer, exclusive.
* @returns Range array.
* @throws ValueError, iff `end` < `begin`.
*/
function range$1(begin, end) {
if (end < begin) {
throw new ValueError("end (" + end + ") < begin (" + begin + ") is forbidden.");
}
var out = [];
for (var i = begin; i < end; ++i) {
out.push(i);
}
return out;
}
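// Illustrative sketch of the pure-JS array helpers above (comment only; the
// `$`-suffixed identifiers are the bundler's renaming of min, max, sum and
// range within this scope):
//
//   arrayProd([2, 3, 4]);    // 24
//   min$a([3, 1, 2]);        // 1
//   median([1, 2, 3, 4]);    // 2.5
//   variance([1, 2, 3, 4]);  // 1.25
//   range$1(2, 5);           // [2, 3, 4]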
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/* Setting and getting backend from deeplearn.js. */
// Default deeplearn.js backend is WebGL (GPU).
var backend$1 = 'webgl';
function setBackend$1(requestedBackend) {
setBackend(requestedBackend);
backend$1 = requestedBackend;
}
function getBackend$1() {
return backend$1;
}
/**
* Indicates whether the backend is operating symbolically.
*
* This function will be used to determine how to interpret user code. If
* it returns true, calls to the backend construct a symbolic graph; if
* it returns false, calls to the backend execute immediately.
*/
function isBackendSymbolic() {
return false;
}
/**
* Get the number of elements in a Tensor.
* @param x The Tensor.
* @return Number of elements in `x`.
*/
function countParams(x) {
var shape = x.shape;
if (shape.length > 0) {
return shape.reduce(function (a, b) {
return a * b;
});
} else {
// Scalar.
return 1;
}
}
/**
* Casts a tensor to a different dtype and returns it.
* @param x Input tensor.
* @param dtype String: 'float32'|'int32'|'bool'.
* @returns Tensor of the specified `dtype`.
*/
function cast$1(x, dtype) {
return cast(x, dtype);
}
/**
* Adds a 1-sized dimension at index "axis".
* @param x Input tensor.
* @param axis Position where to add the new axis.
* @returns Result of the dimension expansion.
*/
function expandDims$1(x, axis) {
if (axis === void 0) {
axis = -1;
}
var outShape = x.shape.slice();
if (axis < 0) {
axis = outShape.length + axis + 1;
}
outShape.splice(axis, 0, 1);
return reshape(x, outShape);
}
/**
* Repeats a 2D tensor.
*
* If `x` has shape `[samples, dim]` and `n` is 2, for example, the output
* will have shape `[samples, 2, dim]`.
*
* @param x Input tensor.
* @param n Integer, number of times to repeat.
* @returns The result of the repeat operation.
* @throws ValueError: If input tensor is not 2D.
*/
function repeat(x, n) {
return tidy(function () {
if (x.shape.length !== 2) {
throw new ValueError("repeat() expects a rank-2 tensor, but received a " + ("rank-" + x.shape.length + " tensor."));
}
var y = expandDims$1(x, 1);
return tile$1(y, [1, n, 1]);
});
}
/**
* Flatten a Tensor into 1D.
* @param x Input tensor.
* @return The result of the flattening `x`.
*/
function flatten$1(x) {
var newShape = [arrayProd(x.shape)];
return reshape(x, newShape);
}
/**
* Turn an nD tensor into a 2D tensor with the same 0th dimension.
* In other words, it flattens each data sample of a batch.
*
* @param x The tensor to flatten. The rank of this tensor is required to be 2
* or higher.
* @return The result of the flattening.
*/
function batchFlatten(x) {
if (x.rank <= 1) {
throw new ValueError("batchFlatten requires a minimum rank of 2. Got rank: " + x.rank + ".");
}
var newShape = [x.shape[0], arrayProd(x.shape, 1)];
return reshape(x, newShape);
}
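// Illustrative shape-level sketch of the reshaping helpers above (comment
// only; assumes the bundled `tf` export for creating tensors):
//
//   var x = tf.zeros([4, 3]);
//   expandDims$1(x).shape;                    // [4, 3, 1] (axis defaults to -1)
//   repeat(x, 2).shape;                       // [4, 2, 3]
//   batchFlatten(tf.zeros([4, 3, 2])).shape;  // [4, 6]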
/**
* Do slicing along the first axis.
* @param array input `tf.Tensor`.
* @param start starting index, inclusive.
* @param size size of the slice along the first axis.
* @returns result of the slicing.
* @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
*/
function sliceAlongFirstAxis(array, start, size) {
return tidy(function () {
switch (array.rank) {
case 1:
return slice1d(array, start, size);
case 2:
return slice2d(array, [start, 0], [size, array.shape[1]]);
case 3:
return slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]);
case 4:
return slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]);
case 5:
return slice$2(array, [start, 0, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3], array.shape[4]]);
case 6:
return slice$2(array, [start, 0, 0, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3], array.shape[4], array.shape[5]]);
default:
throw new ValueError("sliceAlongFirstAxis() received an unsupported tensor rank: " + ("" + array.rank));
}
});
}
/**
* Do slicing along the last axis.
* @param array input `tf.Tensor`.
* @param start starting index, inclusive.
* @param size size of the slice along the last axis.
* @returns result of the slicing.
* @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
*/
function sliceAlongLastAxis(array, start, size) {
return tidy(function () {
switch (array.rank) {
case 1:
return slice1d(array, start, size);
case 2:
return slice2d(array, [0, start], [array.shape[0], size]);
case 3:
return slice3d(array, [0, 0, start], [array.shape[0], array.shape[1], size]);
case 4:
return slice4d(array, [0, 0, 0, start], [array.shape[0], array.shape[1], array.shape[2], size]);
default:
throw new ValueError("sliceAlongLastAxis() received an unsupported tensor rank: " + ("" + array.rank));
}
});
}
/**
* Do slicing along the specified axis.
* @param array input `tf.Tensor`.
* @param start starting index, inclusive.
* @param size size of the slice along the chosen axis.
* @param axis the axis along which to slice (1-based: 1 means the first axis).
* @returns result of the slicing.
* @throws ValueError: If `array` is of an unsupported subtype of `tf.Tensor`.
*/
function sliceAlongAxis(array, start, size, axis) {
return tidy(function () {
switch (array.rank) {
case 1:
return slice1d(array, start, size);
case 2:
switch (axis) {
case 1:
return sliceAlongFirstAxis(array, start, size);
case 2:
return sliceAlongLastAxis(array, start, size);
default:
throw new ValueError("The axis is not within the rank of the tensor " + ("" + axis));
}
case 3:
switch (axis) {
case 1:
return sliceAlongFirstAxis(array, start, size);
case 2:
return slice3d(array, [0, start, 0], [array.shape[0], size, array.shape[2]]);
case 3:
return sliceAlongLastAxis(array, start, size);
default:
throw new ValueError("The axis is not within the rank of the tensor " + ("" + axis));
}
case 4:
switch (axis) {
case 1:
return sliceAlongFirstAxis(array, start, size);
case 2:
return slice4d(array, [0, start, 0, 0], [array.shape[0], size, array.shape[2], array.shape[3]]);
case 3:
return slice4d(array, [0, 0, start, 0], [array.shape[0], array.shape[1], size, array.shape[3]]);
case 4:
return sliceAlongLastAxis(array, start, size);
default:
throw new ValueError("The axis is not within the rank of the tensor " + ("" + axis));
}
default:
throw new ValueError("sliceAlongLastAxis() received an unsupported tensor rank: " + ("" + array.rank));
}
});
}
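// Illustrative shape-level sketch of the slicing helpers above (comment only).
// Note that the `axis` argument of sliceAlongAxis() is 1-based:
//
//   var x = tf.zeros([4, 5, 6]);
//   sliceAlongFirstAxis(x, 1, 2).shape;  // [2, 5, 6]
//   sliceAlongLastAxis(x, 0, 3).shape;   // [4, 5, 3]
//   sliceAlongAxis(x, 2, 3, 2).shape;    // [4, 3, 6]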
/**
* Concatenates a list of tensors alongside the specified axis.
* @param tensors `Array` of tensors to concatenate.
* @param axis Concatenation axis.
* @returns The result of the concatenation.
*/
function concatenate(tensors, axis) {
if (axis === void 0) {
axis = -1;
}
var rank;
if (axis < 0) {
rank = tensors[0].rank;
if (rank !== 0) {
axis = rank;
} else {
axis = 0;
}
}
if (axis === tensors[0].rank) {
// Porting Note: This is necessary because tfc.concat() requires axis to be
// in the interval [-rank, rank).
axis = -1;
} // Porting Note: Sparse concat is not supported yet.
return concat(tensors, axis);
}
/**
* Concatenate two arrays along the first dimension.
* @param a The 1st `tf.Tensor` to concatenate.
* @param b The 2nd `tf.Tensor` to concatenate.
* @returns Result of the concatenation.
* @throws ValueError: If `a` is of an unsupported subtype of `tf.Tensor`.
*/
function concatAlongFirstAxis(a, b) {
switch (a.rank) {
case 1:
return concat1d([a, b]);
case 2:
return concat2d([a, b], 0);
case 3:
return concat3d([a, b], 0);
case 4:
return concat4d([a, b], 0);
default:
throw new ValueError("concatAlongFirstAxis() received an unsupported " + ("tensor rank: " + a.rank));
}
}
/**
* Creates a tensor by tiling `x` by `n`.
* @param x A tensor.
* @param n An Array of integers or a single integer. If an Array, the length
* must be the same as the number of dimensions in `x`. If a single integer,
* it will be treated as an Array of length 1.
*/
function tile$1(x, n) {
if (!Array.isArray(n)) {
n = [n];
}
if (x.rank !== n.length) {
throw new ValueError("The length of input n (" + n.length + ") does not match " + ("the number of dimensions in input x (" + x.rank + ")"));
}
return tile(x, n);
}
/* Creation of random tensors. */
/**
* Get a tensor with normal distribution of values.
*
* @param shape Shape of the tensor.
* @param mean mean value of the normal distribution.
* @param stddev standard deviation of the normal distribution.
* @param dtype
* @param seed
* @return The normal tensor.
*/
function randomNormal$1(shape, mean, stddev, dtype, seed) {
if (mean === void 0) {
mean = 0.0;
}
if (stddev === void 0) {
stddev = 1.0;
}
return randomNormal(shape, mean, stddev, dtype, seed);
}
/* Linear Algebra */
/**
* Multiply two tensors and returns the result as a tensor.
*
* For 2D tensors, this is equivalent to matrix multiplication (matMul).
* For tensors of higher ranks, it follows the Theano behavior,
* (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`). From the Theano documentation:
*
* For N dimensions it is a sum product over the last axis of x and the
* second-to-last of y.
*
* @param a A tensor of at least rank 2.
* @param b A tensor of at least rank 2.
* @param activation (optional) A string identifying the activation
* function.
* @return Result of the dot operation.
*/
function dot$1(a, b, activation, bias) {
if (a.rank < 2 || b.rank < 2) {
throw new NotImplementedError("dot requires both inputs to be rank >= 2" + (" but got x shape = " + a.shape + " and y shape = " + b.shape));
}
if (b.rank >= 3) {
var xLastDim = a.shape.slice(-1)[0];
var ySecondLastDim = b.shape.slice(-2)[0];
if (xLastDim !== ySecondLastDim) {
throw new NotImplementedError("If rank y >= 3, then the second last dim" + (" of y must equal the last dim of x but got x shape = " + a.shape + " and ") + (" y shape = " + b.shape));
}
} // Handle basic 2D x 2D case.
if (a.rank === 2 && b.rank === 2) {
var transposeA = false;
var transposeB = false; // tfc.fused.matMul only fuses certain activation functions. Unsupported
// activation functions are treated as 'linear' activations, which is
// equivalent to a no-op.
return matMul$1({
a: a,
b: b,
transposeA: transposeA,
transposeB: transposeB,
bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
activation: activation
});
} else {
// Reshape x into the analogous 2D Tensor.
var aFirstDims = a.shape.slice(); // Holds all but the last dim of x.
var aLastDim = aFirstDims.pop();
a = reshape(a, [-1, aLastDim]); // Reshape y into the analogous 2D Tensor, and keep track of the
// required dimensions to reproduce the output shape.
var bShape = b.shape.slice();
var bLastDim = bShape.pop();
var _ySecondLastDim = bShape.pop();
var yOtherDims = [].concat(bShape, [bLastDim]); // permutation should be like [r-2, 0, 1, 2, ... r-4, r-3, r-1]
// where r is the rank of y.
var perm = Array.from({
length: b.rank
}, function (_, i) {
if (i === 0) {
return b.rank - 2;
} else if (i <= b.rank - 2) {
return i - 1;
}
return i;
});
b = reshape(transpose(b, perm), [_ySecondLastDim, -1]); // Multiply x and y as 2D Tensors, and then reshape back to original.
var outputShape = [].concat(aFirstDims, yOtherDims);
var _transposeA = false;
var _transposeB = false;
return reshape(matMul$1({
a: a,
b: b,
transposeA: _transposeA,
transposeB: _transposeB,
bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
activation: activation
}), outputShape);
}
}
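// Illustrative shape-level sketch of dot$1() above (comment only), matching
// the Theano-style behavior described in the doc comment:
//
//   dot$1(tf.zeros([2, 3]), tf.zeros([3, 5])).shape;     // [2, 5]
//   dot$1(tf.zeros([2, 3]), tf.zeros([4, 3, 5])).shape;  // [2, 4, 5]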
/**
* Compute the sign Tensor of an input Tensor.
*
* Elements of the input `tf.Tensor` that are === 0 are mapped to 0.
* Elements of the input `tf.Tensor` that are > 0 are mapped to 1.
* Elements of the input `tf.Tensor` that are < 0 are mapped to -1.
*
* @param x Input `tf.Tensor`.
* @return The sign `tf.Tensor`.
*/
function sign$1(x) {
// TODO(cais): Move to the core.
return tidy(function () {
var zerosLikeX = zerosLike(x);
var onesLikeX = onesLike(x);
return where(equal(x, zerosLikeX), zerosLikeX, where(greater(x, zerosLike(x)), onesLikeX, mul(-1, onesLikeX)));
});
}
/**
* Computes the one-hot representation of an integer tensor.
* @param indices nD integer tensor of shape
* `(batch_size, dim1, dim2, ... dim(n-1))`
* @param numClasses Integer, number of classes to consider.
* @returns (n + 1)D one hot representation of the input
* with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
*/
function oneHot$1(indices, numClasses) {
return tidy(function () {
if (indices.rank !== 1) {
throw new Error('Only 1D one-hot tensors are supported in the ' + 'deeplearn backend, at present.');
}
indices = cast(indices, 'int32');
return cast(oneHot(indices, numClasses), 'float32');
});
}
/* Elementary math functions. */
/**
* Retrieves the elements of indices `indices` in the tensor `reference`.
* @param reference A tensor.
* @param indices An integer tensor of indices or an `Array` of integers.
* @param axis Axis along which to perform the gather operation.
* @returns The result of the gathering as a tensor.
*/
function gather$1(reference, indices, axis) {
return tidy(function () {
if (Array.isArray(indices)) {
indices = tensor1d(indices, 'int32');
} else {
indices = cast(indices, 'int32');
}
return gather(reference, indices, axis);
});
}
/**
* Element-wise square.
* @param x Input tensor.
* @return element-wise x^2
*/
function square$1(x) {
return mul(x, x);
}
/**
* Element-wise exponentiation.
*
* Porting Note: In PyKeras, `a` (the exponent) is a Python integer, which
* takes advantage of the backend's (e.g., TensorFlow's) automatic
* conversion to tensor. Here we allow `a` to be either a number or a tensor.
*
* @param x The base tensor.
* @param a The exponent, tensor or number. If a number, it is rounded to the
* nearest integer and converted to a tensor.
* @returns A tensor of the same shape as `x`.
*/
function pow$6(x, a) {
return tidy(function () {
if (typeof a === 'number') {
a = scalar(Math.round(a), 'int32');
}
if (a.dtype !== 'int32') {
throw new NotImplementedError("Non-int32 dtype (" + a.dtype + ") is not supported by pow() yet");
}
return pow$5(x, a);
});
}
/**
* Reshapes bias tensor according to rank of x.
*/
function reshapeBias(xRank, bias, dataFormat) {
var biasShape = bias.shape;
if (bias.rank !== 1 && bias.rank !== xRank) {
throw new ValueError("Unexpected bias dimensions: " + bias.rank + ("; expected it to be 1 or " + xRank));
}
if (xRank === 5) {
if (dataFormat === 'channelsFirst') {
if (biasShape.length === 1) {
return reshape(bias, [1, biasShape[0], 1, 1, 1]);
} else {
return reshape(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);
}
} else if (dataFormat === 'channelsLast') {
if (biasShape.length === 1) {
return reshape(bias, [1, 1, 1, 1, biasShape[0]]);
} else {
return reshape(bias, [1].concat(biasShape));
}
}
} else if (xRank === 4) {
if (dataFormat === 'channelsFirst') {
if (biasShape.length === 1) {
return reshape(bias, [1, biasShape[0], 1, 1]);
} else {
return reshape(bias, [1, biasShape[2], biasShape[0], biasShape[1]]);
}
} else if (dataFormat === 'channelsLast') {
if (biasShape.length === 1) {
return reshape(bias, [1, 1, 1, biasShape[0]]);
} else {
return reshape(bias, [1].concat(biasShape));
}
}
} else if (xRank === 3) {
if (dataFormat === 'channelsFirst') {
if (biasShape.length === 1) {
return reshape(bias, [1, biasShape[0], 1]);
} else {
return reshape(bias, [1, biasShape[1], biasShape[0]]);
}
} else if (dataFormat === 'channelsLast') {
if (biasShape.length === 1) {
return reshape(bias, [1, 1, biasShape[0]]);
} else {
return reshape(bias, [1].concat(biasShape));
}
}
} else if (xRank < 3) {
return bias;
}
throw new ValueError("Unsupported input rank by biasAdd: " + bias.rank);
}
/* Neural-network operations. */
/**
* Add a bias to a tensor.
*
* @param x The tensor to add the bias to.
* @param bias The bias to add to `x`. Must be 1D or the same rank as `x`.
* @return Result of the bias adding.
* @throws ValueError: If the rank of `bias` is incorrect.
*/
function biasAdd(x, bias, dataFormat) {
return tidy(function () {
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
checkDataFormat(dataFormat);
return add$1(x, reshapeBias(x.rank, bias, dataFormat));
});
}
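// Illustrative sketch of biasAdd() above (comment only): under the default
// 'channelsLast' image data format, a rank-1 bias is reshaped so that it
// broadcasts along the last (channel) axis of `x`:
//
//   var x = tf.zeros([2, 8, 8, 16]);  // NHWC
//   var b = tf.zeros([16]);
//   biasAdd(x, b).shape;              // [2, 8, 8, 16]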
/**
* Exponential linear unit (ELU).
* @param x A tensor or variable to compute the activation function for.
* @param alpha: A scalar, a scaling factor for the negative section.
* @return Output of the ELU operation.
*/
function elu$1(x, alpha) {
if (alpha === void 0) {
alpha = 1;
}
// TODO(cais): Add support for alpha values other than 1.
if (alpha !== 1) {
throw new NotImplementedError("Support for alpha values other than 1 (" + alpha + ") is not implemented " + "yet.");
}
return elu(x);
}
/**
* Softsign of a tensor.
*
* Defined as x / (abs(x) + 1), element-wise.
*
* @param x: Input.
* @returns Output.
*/
function softsign(x) {
return tidy(function () {
return div(x, add$1(abs$8(x), 1));
});
}
/**
* Sets entries in `x` to zero at random, while scaling the entire tensor.
*
* @param x input tensor.
* @param level fraction of the entries in the tensor that will be set to 0.
* @param noiseShape shape of randomly generated keep/drop flags, must be
* broadcastable to the shape of `x`. Optional.
* @param seed random seed to ensure determinism. Optional.
* @returns Result of the dropout operation.
*/
function dropout$1(x, level, noiseShape, seed) {
return tidy(function () {
return dropout(x, level, noiseShape, seed);
});
}
/**
* Element-wise, segment-wise linear approximation of sigmoid.
*
* Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
* In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.
*
* @param x Input tensor.
* @returns Output tensor.
*/
function hardSigmoid(x) {
return tidy(function () {
var y = add$1(.5, mul(.2, x));
return clipByValue(y, 0, 1);
});
}
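// Spot-check of hardSigmoid() above (comment only): element-wise,
// hardSigmoid(-3) -> 0, hardSigmoid(0) -> 0.5, hardSigmoid(3) -> 1.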
/**
* Invoke `x` in the training phase, and `alt` otherwise.
*
* Porting Note: We do not create placeholder tensors for the `training`
* boolean flag here, because there is no such thing in the TF.js imperative
* backend.
*
* @param x The function to invoke iff `training` is `true`.
* @param alt The function to invoke iff `training` is `false`.
* @param training Boolean flag for whether training phase is active.
* @returns The return value of `x()` if `training` is `true`, or the return
* value of `alt()` if `training` is `false`.
*/
function inTrainPhase(x, alt, training) {
if (training === void 0) {
training = false;
}
return training ? x() : alt();
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg'];
var VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal']; // We can't easily extract a string[] from the string union type, but we can
// recapitulate the list, enforcing at compile time that the values are valid
// and that we have the right number of them.
/**
* A string array of valid Initializer class names.
*
* This is guaranteed to match the `InitializerClassName` union type.
*/
var initializerClassNames = ['Zeros', 'Ones', 'Constant', 'RandomNormal', 'RandomUniform', 'TruncatedNormal', 'VarianceScaling', 'Orthogonal', 'Identity'];
function checkFanMode(value) {
checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value);
}
function checkDistribution(value) {
checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value);
}
/**
* Initializer base class.
*
* @doc {
* heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
*/
var Initializer = /*#__PURE__*/function (_serialization$Serial) {
_inheritsLoose(Initializer, _serialization$Serial);
function Initializer() {
return _serialization$Serial.apply(this, arguments) || this;
}
var _proto = Initializer.prototype;
_proto.fromConfigUsesCustomObjects = function fromConfigUsesCustomObjects() {
return false;
};
_proto.getConfig = function getConfig() {
return {};
};
return Initializer;
}(Serializable);
var Zeros = /*#__PURE__*/function (_Initializer) {
_inheritsLoose(Zeros, _Initializer);
function Zeros() {
return _Initializer.apply(this, arguments) || this;
}
var _proto2 = Zeros.prototype;
_proto2.apply = function apply(shape, dtype) {
return zeros(shape, dtype);
};
return Zeros;
}(Initializer);
/** @nocollapse */
Zeros.className = 'Zeros';
registerClass(Zeros);
var Ones = /*#__PURE__*/function (_Initializer2) {
_inheritsLoose(Ones, _Initializer2);
function Ones() {
return _Initializer2.apply(this, arguments) || this;
}
var _proto3 = Ones.prototype;
_proto3.apply = function apply(shape, dtype) {
return ones$1(shape, dtype);
};
return Ones;
}(Initializer);
/** @nocollapse */
Ones.className = 'Ones';
registerClass(Ones);
var Constant = /*#__PURE__*/function (_Initializer3) {
_inheritsLoose(Constant, _Initializer3);
function Constant(args) {
var _this;
_this = _Initializer3.call(this) || this;
if (typeof args !== 'object') {
throw new ValueError("Expected argument of type ConstantConfig but got " + args);
}
if (args.value === undefined) {
throw new ValueError("config must have value set but got " + args);
}
_this.value = args.value;
return _this;
}
var _proto4 = Constant.prototype;
_proto4.apply = function apply(shape, dtype) {
var _this2 = this;
return tidy(function () {
return mul(scalar(_this2.value), ones$1(shape, dtype));
});
};
_proto4.getConfig = function getConfig() {
return {
value: this.value
};
};
return Constant;
}(Initializer);
/** @nocollapse */
Constant.className = 'Constant';
registerClass(Constant);
var RandomUniform = /*#__PURE__*/function (_Initializer4) {
_inheritsLoose(RandomUniform, _Initializer4);
function RandomUniform(args) {
var _this3;
_this3 = _Initializer4.call(this) || this;
_this3.DEFAULT_MINVAL = -0.05;
_this3.DEFAULT_MAXVAL = 0.05;
_this3.minval = args.minval || _this3.DEFAULT_MINVAL;
_this3.maxval = args.maxval || _this3.DEFAULT_MAXVAL;
_this3.seed = args.seed;
return _this3;
}
var _proto5 = RandomUniform.prototype;
_proto5.apply = function apply(shape, dtype) {
return randomUniform(shape, this.minval, this.maxval, dtype);
};
_proto5.getConfig = function getConfig() {
return {
minval: this.minval,
maxval: this.maxval,
seed: this.seed
};
};
return RandomUniform;
}(Initializer);
/** @nocollapse */
RandomUniform.className = 'RandomUniform';
registerClass(RandomUniform);
var RandomNormal = /*#__PURE__*/function (_Initializer5) {
_inheritsLoose(RandomNormal, _Initializer5);
function RandomNormal(args) {
var _this4;
_this4 = _Initializer5.call(this) || this;
_this4.DEFAULT_MEAN = 0.;
_this4.DEFAULT_STDDEV = 0.05;
_this4.mean = args.mean || _this4.DEFAULT_MEAN;
_this4.stddev = args.stddev || _this4.DEFAULT_STDDEV;
_this4.seed = args.seed;
return _this4;
}
var _proto6 = RandomNormal.prototype;
_proto6.apply = function apply(shape, dtype) {
dtype = dtype || 'float32';
if (dtype !== 'float32' && dtype !== 'int32') {
throw new NotImplementedError("randomNormal does not support dType " + dtype + ".");
}
return randomNormal$1(shape, this.mean, this.stddev, dtype, this.seed);
};
_proto6.getConfig = function getConfig() {
return {
mean: this.mean,
stddev: this.stddev,
seed: this.seed
};
};
return RandomNormal;
}(Initializer);
/** @nocollapse */
RandomNormal.className = 'RandomNormal';
registerClass(RandomNormal);
var TruncatedNormal = /*#__PURE__*/function (_Initializer6) {
_inheritsLoose(TruncatedNormal, _Initializer6);
function TruncatedNormal(args) {
var _this5;
_this5 = _Initializer6.call(this) || this;
_this5.DEFAULT_MEAN = 0.;
_this5.DEFAULT_STDDEV = 0.05;
_this5.mean = args.mean || _this5.DEFAULT_MEAN;
_this5.stddev = args.stddev || _this5.DEFAULT_STDDEV;
_this5.seed = args.seed;
return _this5;
}
var _proto7 = TruncatedNormal.prototype;
_proto7.apply = function apply(shape, dtype) {
dtype = dtype || 'float32';
if (dtype !== 'float32' && dtype !== 'int32') {
throw new NotImplementedError("truncatedNormal does not support dType " + dtype + ".");
}
return truncatedNormal(shape, this.mean, this.stddev, dtype, this.seed);
};
_proto7.getConfig = function getConfig() {
return {
mean: this.mean,
stddev: this.stddev,
seed: this.seed
};
};
return TruncatedNormal;
}(Initializer);
/** @nocollapse */
TruncatedNormal.className = 'TruncatedNormal';
registerClass(TruncatedNormal);
var Identity$1 = /*#__PURE__*/function (_Initializer7) {
_inheritsLoose(Identity, _Initializer7);
function Identity(args) {
var _this6;
_this6 = _Initializer7.call(this) || this;
_this6.gain = args.gain != null ? args.gain : 1.0;
return _this6;
}
var _proto8 = Identity.prototype;
_proto8.apply = function apply(shape, dtype) {
var _this7 = this;
return tidy(function () {
if (shape.length !== 2 || shape[0] !== shape[1]) {
throw new ValueError('Identity matrix initializer can only be used for' + ' 2D square matrices.');
} else {
return mul(_this7.gain, eye(shape[0]));
}
});
};
_proto8.getConfig = function getConfig() {
return {
gain: this.gain
};
};
return Identity;
}(Initializer);
/** @nocollapse */
Identity$1.className = 'Identity';
registerClass(Identity$1);
/**
* Computes the number of input and output units for a weight shape.
* @param shape Shape of weight.
* @param dataFormat data format to use for convolution kernels.
* Note that all kernels in Keras are standardized on the
* CHANNEL_LAST ordering (even when inputs are set to CHANNEL_FIRST).
* @return A length-2 array: [fanIn, fanOut].
*/
function computeFans(shape, dataFormat) {
if (dataFormat === void 0) {
dataFormat = 'channelsLast';
}
var fanIn;
var fanOut;
checkDataFormat(dataFormat);
if (shape.length === 2) {
fanIn = shape[0];
fanOut = shape[1];
} else if ([3, 4, 5].indexOf(shape.length) !== -1) {
if (dataFormat === 'channelsFirst') {
var receptiveFieldSize = arrayProd(shape, 2);
fanIn = shape[1] * receptiveFieldSize;
fanOut = shape[0] * receptiveFieldSize;
} else if (dataFormat === 'channelsLast') {
var _receptiveFieldSize = arrayProd(shape, 0, shape.length - 2);
fanIn = shape[shape.length - 2] * _receptiveFieldSize;
fanOut = shape[shape.length - 1] * _receptiveFieldSize;
}
} else {
var shapeProd = arrayProd(shape);
fanIn = Math.sqrt(shapeProd);
fanOut = Math.sqrt(shapeProd);
}
return [fanIn, fanOut];
}
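// Illustrative sketch of computeFans() above (comment only) for a conv2d
// kernel of shape [kernelH, kernelW, inChannels, outChannels] under the
// default 'channelsLast' layout:
//
//   computeFans([3, 3, 16, 32]);  // [144, 288], i.e. fanIn = 16 * 9, fanOut = 32 * 9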
var VarianceScaling = /*#__PURE__*/function (_Initializer8) {
_inheritsLoose(VarianceScaling, _Initializer8);
/**
* Constructor of VarianceScaling.
* @throws ValueError for invalid value in scale.
*/
function VarianceScaling(args) {
var _this8;
_this8 = _Initializer8.call(this) || this;
if (args.scale < 0.0) {
throw new ValueError("scale must be a positive float. Got: " + args.scale);
}
_this8.scale = args.scale == null ? 1.0 : args.scale;
_this8.mode = args.mode == null ? 'fanIn' : args.mode;
checkFanMode(_this8.mode);
_this8.distribution = args.distribution == null ? 'normal' : args.distribution;
checkDistribution(_this8.distribution);
_this8.seed = args.seed;
return _this8;
}
var _proto9 = VarianceScaling.prototype;
_proto9.apply = function apply(shape, dtype) {
var fans = computeFans(shape);
var fanIn = fans[0];
var fanOut = fans[1];
var scale = this.scale;
if (this.mode === 'fanIn') {
scale /= Math.max(1, fanIn);
} else if (this.mode === 'fanOut') {
scale /= Math.max(1, fanOut);
} else {
scale /= Math.max(1, (fanIn + fanOut) / 2);
}
if (this.distribution === 'normal') {
var stddev = Math.sqrt(scale);
dtype = dtype || 'float32';
if (dtype !== 'float32' && dtype !== 'int32') {
throw new NotImplementedError(this.getClassName() + " does not support dType " + dtype + ".");
}
return truncatedNormal(shape, 0, stddev, dtype, this.seed);
} else {
var limit = Math.sqrt(3 * scale);
return randomUniform(shape, -limit, limit, dtype);
}
};
_proto9.getConfig = function getConfig() {
return {
scale: this.scale,
mode: this.mode,
distribution: this.distribution,
seed: this.seed
};
};
return VarianceScaling;
}(Initializer);
/** @nocollapse */
VarianceScaling.className = 'VarianceScaling';
registerClass(VarianceScaling);
var GlorotUniform = /*#__PURE__*/function (_VarianceScaling) {
_inheritsLoose(GlorotUniform, _VarianceScaling);
/**
* Constructor of GlorotUniform
* @param scale
* @param mode
* @param distribution
* @param seed
*/
function GlorotUniform(args) {
return _VarianceScaling.call(this, {
scale: 1.0,
mode: 'fanAvg',
distribution: 'uniform',
seed: args == null ? null : args.seed
}) || this;
}
var _proto10 = GlorotUniform.prototype;
_proto10.getClassName = function getClassName() {
// In Python Keras, GlorotUniform is not a class, but a helper method
// that creates a VarianceScaling object. Use 'VarianceScaling' as
// class name to be compatible with that.
return VarianceScaling.className;
};
return GlorotUniform;
}(VarianceScaling);
/** @nocollapse */
GlorotUniform.className = 'GlorotUniform';
registerClass(GlorotUniform);
var GlorotNormal = /*#__PURE__*/function (_VarianceScaling2) {
_inheritsLoose(GlorotNormal, _VarianceScaling2);
/**
* Constructor of GlorotNormal.
* @param scale
* @param mode
* @param distribution
* @param seed
*/
function GlorotNormal(args) {
return _VarianceScaling2.call(this, {
scale: 1.0,
mode: 'fanAvg',
distribution: 'normal',
seed: args == null ? null : args.seed
}) || this;
}
var _proto11 = GlorotNormal.prototype;
_proto11.getClassName = function getClassName() {
// In Python Keras, GlorotNormal is not a class, but a helper method
// that creates a VarianceScaling object. Use 'VarianceScaling' as
// class name to be compatible with that.
return VarianceScaling.className;
};
return GlorotNormal;
}(VarianceScaling);
/** @nocollapse */
GlorotNormal.className = 'GlorotNormal';
registerClass(GlorotNormal);
var HeNormal = /*#__PURE__*/function (_VarianceScaling3) {
_inheritsLoose(HeNormal, _VarianceScaling3);
function HeNormal(args) {
return _VarianceScaling3.call(this, {
scale: 2.0,
mode: 'fanIn',
distribution: 'normal',
seed: args == null ? null : args.seed
}) || this;
}
var _proto12 = HeNormal.prototype;
_proto12.getClassName = function getClassName() {
// In Python Keras, HeNormal is not a class, but a helper method
// that creates a VarianceScaling object. Use 'VarianceScaling' as
// class name to be compatible with that.
return VarianceScaling.className;
};
return HeNormal;
}(VarianceScaling);
/** @nocollapse */
HeNormal.className = 'HeNormal';
registerClass(HeNormal);
var HeUniform = /*#__PURE__*/function (_VarianceScaling4) {
_inheritsLoose(HeUniform, _VarianceScaling4);
function HeUniform(args) {
return _VarianceScaling4.call(this, {
scale: 2.0,
mode: 'fanIn',
distribution: 'uniform',
seed: args == null ? null : args.seed
}) || this;
}
var _proto13 = HeUniform.prototype;
_proto13.getClassName = function getClassName() {
// In Python Keras, HeUniform is not a class, but a helper method
// that creates a VarianceScaling object. Use 'VarianceScaling' as
// class name to be compatible with that.
return VarianceScaling.className;
};
return HeUniform;
}(VarianceScaling);
/** @nocollapse */
HeUniform.className = 'HeUniform';
registerClass(HeUniform);
var LeCunNormal = /*#__PURE__*/function (_VarianceScaling5) {
_inheritsLoose(LeCunNormal, _VarianceScaling5);
function LeCunNormal(args) {
return _VarianceScaling5.call(this, {
scale: 1.0,
mode: 'fanIn',
distribution: 'normal',
seed: args == null ? null : args.seed
}) || this;
}
var _proto14 = LeCunNormal.prototype;
_proto14.getClassName = function getClassName() {
// In Python Keras, LeCunNormal is not a class, but a helper method
// that creates a VarianceScaling object. Use 'VarianceScaling' as
// class name to be compatible with that.
return VarianceScaling.className;
};
return LeCunNormal;
}(VarianceScaling);
/** @nocollapse */
LeCunNormal.className = 'LeCunNormal';
registerClass(LeCunNormal);
var LeCunUniform = /*#__PURE__*/function (_VarianceScaling6) {
_inheritsLoose(LeCunUniform, _VarianceScaling6);
function LeCunUniform(args) {
return _VarianceScaling6.call(this, {
scale: 1.0,
mode: 'fanIn',
distribution: 'uniform',
seed: args == null ? null : args.seed
}) || this;
}
var _proto15 = LeCunUniform.prototype;
_proto15.getClassName = function getClassName() {
// In Python Keras, LeCunUniform is not a class, but a helper method
// that creates a VarianceScaling object. Use 'VarianceScaling' as
// class name to be compatible with that.
return VarianceScaling.className;
};
return LeCunUniform;
}(VarianceScaling);
/** @nocollapse */
LeCunUniform.className = 'LeCunUniform';
registerClass(LeCunUniform);
var Orthogonal = /*#__PURE__*/function (_Initializer9) {
_inheritsLoose(Orthogonal, _Initializer9);
function Orthogonal(args) {
var _this9;
_this9 = _Initializer9.call(this) || this;
_this9.DEFAULT_GAIN = 1;
_this9.gain = args.gain == null ? _this9.DEFAULT_GAIN : args.gain;
_this9.seed = args.seed;
if (_this9.seed != null) {
throw new NotImplementedError('Random seed is not implemented for Orthogonal Initializer yet.');
}
return _this9;
}
var _proto16 = Orthogonal.prototype;
_proto16.apply = function apply(shape, dtype) {
var _this10 = this;
return tidy(function () {
if (shape.length < 2) {
throw new NotImplementedError('Shape must be at least 2D.');
}
if (shape[0] * shape[1] > 2000) {
console.warn("Orthogonal initializer is being called on a matrix with more " + ("than 2000 (" + shape[0] * shape[1] + ") elements: ") + "Slowness may result.");
} // TODO(cais): Add seed support.
var normalizedShape = shape[0] > shape[1] ? [shape[1], shape[0]] : shape;
var a = randomNormal$1(normalizedShape, 0, 1, 'float32');
var q = linalg.gramSchmidt(a);
if (shape[0] > shape[1]) {
q = transpose(q);
}
return mul(_this10.gain, q);
});
};
_proto16.getConfig = function getConfig() {
return {
gain: this.gain,
seed: this.seed
};
};
return Orthogonal;
}(Initializer);
/** @nocollapse */
Orthogonal.className = 'Orthogonal';
registerClass(Orthogonal); // Maps the JavaScript-like identifier keys to the corresponding registry
// symbols.
var INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
'constant': 'Constant',
'glorotNormal': 'GlorotNormal',
'glorotUniform': 'GlorotUniform',
'heNormal': 'HeNormal',
'heUniform': 'HeUniform',
'identity': 'Identity',
'leCunNormal': 'LeCunNormal',
'leCunUniform': 'LeCunUniform',
'ones': 'Ones',
'orthogonal': 'Orthogonal',
'randomNormal': 'RandomNormal',
'randomUniform': 'RandomUniform',
'truncatedNormal': 'TruncatedNormal',
'varianceScaling': 'VarianceScaling',
'zeros': 'Zeros'
};
function deserializeInitializer(config, customObjects) {
if (customObjects === void 0) {
customObjects = {};
}
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'initializer');
}
function serializeInitializer(initializer) {
return serializeKerasObject(initializer);
}
function getInitializer(identifier) {
if (typeof identifier === 'string') {
var className = identifier in INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ? INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier;
/* We have four 'helper' classes for common initializers that
all get serialized as 'VarianceScaling' and shouldn't go through
the deserializeInitializer pathway. */
if (className === 'GlorotNormal') {
return new GlorotNormal();
} else if (className === 'GlorotUniform') {
return new GlorotUniform();
} else if (className === 'HeNormal') {
return new HeNormal();
} else if (className === 'HeUniform') {
return new HeUniform();
} else if (className === 'LeCunNormal') {
return new LeCunNormal();
} else if (className === 'LeCunUniform') {
return new LeCunUniform();
} else {
var config = {};
config['className'] = className;
config['config'] = {};
return deserializeInitializer(config);
}
} else if (identifier instanceof Initializer) {
return identifier;
} else {
return deserializeInitializer(identifier);
}
}
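// Illustrative sketch of getInitializer() above (comment only): string
// identifiers go through the registry table, except for the Glorot/He/LeCun
// helpers, which are constructed directly:
//
//   getInitializer('glorotUniform');  // -> a GlorotUniform instance
//   getInitializer('zeros');          // -> a Zeros instance (via deserialization)
//   getInitializer(new Ones());       // -> that same Ones instance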
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Initializer that generates tensors initialized to 0.
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function zeros$1() {
return new Zeros();
}
/**
* Initializer that generates tensors initialized to 1.
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function ones$2() {
return new Ones();
}
/**
* Initializer that generates values initialized to some constant.
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function constant(args) {
return new Constant(args);
}
/**
* Initializer that generates random values initialized to a uniform
* distribution.
*
* Values will be distributed uniformly between the configured minval and
* maxval.
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function randomUniform$1(args) {
return new RandomUniform(args);
}
/**
* Initializer that generates random values initialized to a normal
* distribution.
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function randomNormal$2(args) {
return new RandomNormal(args);
}
/**
* Initializer that generates random values initialized to a truncated normal
* distribution.
*
* These values are similar to values from a `RandomNormal` except that values
* more than two standard deviations from the mean are discarded and re-drawn.
* This is the recommended initializer for neural network weights and filters.
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function truncatedNormal$1(args) {
return new TruncatedNormal(args);
}
/**
* Initializer that generates the identity matrix.
* Only use for square 2D matrices.
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function identity(args) {
return new Identity$1(args);
}
/**
* Initializer capable of adapting its scale to the shape of weights.
* With distribution=NORMAL, samples are drawn from a truncated normal
* distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
* - number of input units in the weight tensor, if mode = FAN_IN.
* - number of output units, if mode = FAN_OUT.
* - average of the numbers of input and output units, if mode = FAN_AVG.
* With distribution=UNIFORM,
* samples are drawn from a uniform distribution
* within [-limit, limit], with `limit = sqrt(3 * scale / n)`.
*
* @doc {heading: 'Initializers',namespace: 'initializers'}
*/
function varianceScaling(config) {
return new VarianceScaling(config);
}
/**
* Glorot uniform initializer, also called Xavier uniform initializer.
* It draws samples from a uniform distribution within [-limit, limit]
* where `limit` is `sqrt(6 / (fan_in + fan_out))`
* where `fan_in` is the number of input units in the weight tensor
* and `fan_out` is the number of output units in the weight tensor
*
* Reference:
* Glorot & Bengio, AISTATS 2010
* http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf.
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function glorotUniform(args) {
return new GlorotUniform(args);
}
/**
* Glorot normal initializer, also called Xavier normal initializer.
* It draws samples from a truncated normal distribution centered on 0
* with `stddev = sqrt(2 / (fan_in + fan_out))`
* where `fan_in` is the number of input units in the weight tensor
* and `fan_out` is the number of output units in the weight tensor.
*
* Reference:
* Glorot & Bengio, AISTATS 2010
* http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function glorotNormal(args) {
return new GlorotNormal(args);
}
/**
* He normal initializer.
*
* It draws samples from a truncated normal distribution centered on 0
* with `stddev = sqrt(2 / fanIn)`
* where `fanIn` is the number of input units in the weight tensor.
*
* Reference:
* He et al., http://arxiv.org/abs/1502.01852
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function heNormal(args) {
return new HeNormal(args);
}
/**
* He uniform initializer.
*
* It draws samples from a uniform distribution within [-limit, limit]
* where `limit` is `sqrt(6 / fan_in)`
* where `fanIn` is the number of input units in the weight tensor.
*
* Reference:
* He et al., http://arxiv.org/abs/1502.01852
*
* @doc {heading: 'Initializers',namespace: 'initializers'}
*/
function heUniform(args) {
return new HeUniform(args);
}
/**
* LeCun normal initializer.
*
* It draws samples from a truncated normal distribution centered on 0
* with `stddev = sqrt(1 / fanIn)`
* where `fanIn` is the number of input units in the weight tensor.
*
* References:
* [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
* [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function leCunNormal(args) {
return new LeCunNormal(args);
}
/**
* LeCun uniform initializer.
*
* It draws samples from a uniform distribution in the interval
* `[-limit, limit]` with `limit = sqrt(3 / fanIn)`,
* where `fanIn` is the number of input units in the weight tensor.
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function leCunUniform(args) {
return new LeCunUniform(args);
}
/**
* Initializer that generates a random orthogonal matrix.
*
* Reference:
* [Saxe et al., http://arxiv.org/abs/1312.6120](http://arxiv.org/abs/1312.6120)
*
* @doc {heading: 'Initializers', namespace: 'initializers'}
*/
function orthogonal(args) {
return new Orthogonal(args);
}
var exports_initializers = {
__proto__: null,
zeros: zeros$1,
ones: ones$2,
constant: constant,
randomUniform: randomUniform$1,
randomNormal: randomNormal$2,
truncatedNormal: truncatedNormal$1,
identity: identity,
varianceScaling: varianceScaling,
glorotUniform: glorotUniform,
glorotNormal: glorotNormal,
heNormal: heNormal,
heUniform: heUniform,
leCunNormal: leCunNormal,
leCunUniform: leCunUniform,
orthogonal: orthogonal
};
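// Illustrative sketch of the initializer factories exported above (comment
// only; assumes the bundled `tf` export, and the dense-layer config is a
// hypothetical example):
//
//   var layer = tf.layers.dense({
//     units: 4,
//     kernelInitializer: tf.initializers.heNormal({seed: 42})
//   });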
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Utilities related to persistent state in the backend.
*/
/**
* An ID to track `tf.SymbolicTensor`s and derived classes.
* Required in different places in engine/topology.ts to identify unique
* tensors.
*/
var _nextUniqueTensorId = 0;
function getNextUniqueTensorId() {
return _nextUniqueTensorId++;
}
var _uidPrefixes = {};
/**
* Provides a unique UID given a string prefix.
*
* @param prefix
*/
function getUid(prefix) {
if (prefix === void 0) {
prefix = '';
}
if (!(prefix in _uidPrefixes)) {
_uidPrefixes[prefix] = 0;
}
_uidPrefixes[prefix] += 1;
return prefix + _uidPrefixes[prefix].toString();
}
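// Illustrative sketch of getUid() above (comment only; the prefix is
// hypothetical, and the counters are module-global, so the numbers shown
// assume no prior calls with the same prefix):
//
//   getUid('dense');  // 'dense1'
//   getUid('dense');  // 'dense2'
//   getUid();         // '1' for the empty prefix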
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Determine whether the input is an Array of Shapes.
*/
function isArrayOfShapes(x) {
return Array.isArray(x) && Array.isArray(x[0]);
}
/**
* Special case of normalizing shapes to lists.
*
* @param x A shape or list of shapes to normalize into a list of Shapes.
* @return A list of Shapes.
*/
function normalizeShapeList(x) {
if (x.length === 0) {
return [];
}
if (!Array.isArray(x[0])) {
return [x];
}
return x;
}
/**
* Helper function to obtain exactly one Tensor.
* @param xs: A single `tf.Tensor` or an `Array` of `tf.Tensor`s.
* @return A single `tf.Tensor`. If `xs` is an `Array`, return the first one.
* @throws ValueError: If `xs` is an `Array` and its length is not 1.
*/
function getExactlyOneTensor(xs) {
var x;
if (Array.isArray(xs)) {
if (xs.length !== 1) {
throw new ValueError("Expected Tensor length to be 1; got " + xs.length);
}
x = xs[0];
} else {
x = xs;
}
return x;
}
/**
 * Helper function to obtain exactly one instance of Shape.
*
* @param shapes Input single `Shape` or Array of `Shape`s.
* @returns If input is a single `Shape`, return it unchanged. If the input is
* an `Array` containing exactly one instance of `Shape`, return the instance.
* Otherwise, throw a `ValueError`.
* @throws ValueError: If input is an `Array` of `Shape`s, and its length is not
* 1.
*/
function getExactlyOneShape(shapes) {
if (Array.isArray(shapes) && Array.isArray(shapes[0])) {
if (shapes.length === 1) {
shapes = shapes;
return shapes[0];
} else {
throw new ValueError("Expected exactly 1 Shape; got " + shapes.length);
}
} else {
return shapes;
}
}
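/*
 * Sketch of how the two helpers above normalize their inputs (both are
 * module-internal):
 *
 * ```js
 * getExactlyOneTensor(tf.scalar(1));    // returns the scalar unchanged
 * getExactlyOneTensor([tf.scalar(1)]);  // unwraps the single-element array
 * getExactlyOneShape([[null, 4]]);      // [null, 4]
 * getExactlyOneShape([null, 4]);        // [null, 4] (already a single Shape)
 * // Passing arrays of length > 1 to either helper throws a ValueError.
 * ```
 */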
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Count the elements in an Array of LayerVariables.
*
 * @param weights: The LayerVariables whose constituent numbers are to be
 *   counted.
 * @returns A count of the elements in all the LayerVariables.
*/
function countParamsInWeights(weights) {
var count = 0;
for (var _iterator = _createForOfIteratorHelperLoose(weights), _step; !(_step = _iterator()).done;) {
var weight = _step.value;
if (weight.shape.length === 0) {
count += 1;
} else {
count += weight.shape.reduce(function (a, b) {
return a * b;
});
}
}
return count;
}
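/*
 * Numeric sketch of `countParamsInWeights`: a scalar weight (shape [])
 * contributes 1, every other weight contributes the product of its shape.
 * For example, a kernel of shape [4, 3] plus a bias of shape [3] count
 * 4 * 3 + 3 = 15 parameters in total.
 */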
var DEFAULT_VARIABLE_NAME_PREFIX = 'Variable';
/**
* A `tf.layers.LayerVariable` is similar to a `tf.Tensor` in that it has a
* dtype and shape, but its value is mutable. The value is itself represented
 * as a `tf.Tensor`, and can be read with the `read()` method and updated with
* the `write()` method.
*/
var LayerVariable = /*#__PURE__*/function () {
/**
* Construct Variable from a `tf.Tensor`.
*
* If not explicitly named, the Variable will be given a name with the
* prefix 'Variable'. Variable names are unique. In the case of name
 * collision, suffixes '_<num>' will be added to the name.
*
* @param val Initial value of the Variable.
 * @param name Name of the variable. If `null` or `undefined` is provided, it
 *   will default to a name with the prefix 'Variable'.
 * @param constraint Optional projection function to be applied to the
 *   variable after optimizer updates.
*/
function LayerVariable(val, dtype, name, trainable, constraint) {
if (dtype === void 0) {
dtype = 'float32';
}
if (name === void 0) {
name = DEFAULT_VARIABLE_NAME_PREFIX;
}
if (trainable === void 0) {
trainable = true;
}
if (constraint === void 0) {
constraint = null;
}
this.dtype = dtype == null ? 'float32' : dtype;
this.shape = val.shape;
this.id = getNextUniqueTensorId();
name = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name;
this.originalName = getScopedTensorName(name);
this.name = getUniqueTensorName(this.originalName);
this.trainable_ = trainable;
this.constraint = constraint;
this.val = variable(val, this.trainable_, this.name, this.dtype);
}
/**
* Get a snapshot of the Variable's value.
*
* The returned value is a snapshot of the Variable's value at the time of
* the invocation. Future mutations in the value of the tensor will only
* be reflected by future calls to this method.
*/
var _proto = LayerVariable.prototype;
_proto.read = function read() {
this.assertNotDisposed();
return this.val;
}
/**
* Update the value of the Variable.
*
* @param newVal: The new value to update to. Must be consistent with the
* dtype and shape of the Variable.
* @return This Variable.
*/
;
_proto.write = function write(newVal) {
// TODO(cais): Once TF.js Core supports Tensor.dtype, check dtype match.
this.assertNotDisposed();
checkShapesMatch(this.val, newVal); // Skip updating if this is the exact same tensor.
if (this.val.id !== newVal.id) {
this.val.assign(newVal);
if (this.constraint != null) {
this.val.assign(this.constraint.apply(this.val));
}
}
return this;
}
/**
* Dispose this LayersVariable instance from memory.
*/
;
_proto.dispose = function dispose() {
this.assertNotDisposed();
this.val.dispose();
};
_proto.assertNotDisposed = function assertNotDisposed() {
if (this.val.isDisposed) {
throw new Error("LayersVariable " + this.name + " is already disposed.");
}
};
_createClass(LayerVariable, [{
key: "trainable",
get: function get() {
return this.trainable_;
},
set: function set(trainable) {
this.trainable_ = trainable;
this.val.trainable = trainable;
}
}]);
return LayerVariable;
}();
function checkShapesMatch(x, y) {
if (x.shape.toString() !== y.shape.toString()) {
throw new Error('Shape mismatch: ' + JSON.stringify(x.shape) + ' vs. ' + JSON.stringify(y.shape));
}
}
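/*
 * Sketch of the LayerVariable read/write contract defined above (internal
 * API; writes are shape-checked by `checkShapesMatch`):
 *
 * ```js
 * const v = new LayerVariable(tf.zeros([2, 2]), 'float32', 'myWeight');
 * v.read().print();          // [[0, 0], [0, 0]]
 * v.write(tf.ones([2, 2]));  // ok: same shape
 * // v.write(tf.ones([3]));  // throws: 'Shape mismatch: [2,2] vs. [3]'
 * v.dispose();
 * ```
 */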
/**
* Create a Variable.
* @param x The initial value of the `Variable`.
* @param dtype optional, the type of the variable.
* @param name optional, the name of the variable, default provided by
* Variable.
* @param constraint optional, a constraint to be applied after every update.
* @return The newly instantiated `Variable`.
*/
function variable$1(x, dtype, name, constraint) {
return new LayerVariable(x, dtype, name, true, constraint);
}
/**
* Instantiates an all-zeros Variable and returns it.
*
* @param shape Shape of the tensor.
* @param dtype DType of the tensor.
* @param name Name of the tensor.
* @return An all-zero Variable.
*/
function zerosVariable(shape, dtype, name) {
// TODO(cais): Implement logic for dtype.
return new LayerVariable(zeros(shape), dtype, name);
}
/**
* Instantiates an all-zeros tensor of the same shape as another tensor.
*
* @param x The other tensor.
* @param dtype DType of the tensor.
* @param name Name of the tensor.
* @return A newly instantiated Variable.
*/
function zerosLike$1(x, dtype, name) {
return new LayerVariable(zerosLike(x), dtype, name);
}
/**
* Instantiates an all-ones tensor and returns it.
*
* @param shape Shape of the tensor.
* @param dtype DType of the tensor.
* @param name Name of the tensor.
* @return An all-ones Variable.
*/
function onesVariable(shape, dtype, name) {
// TODO(cais): Implement logic for dtype.
var allocated = ones$1(shape);
return new LayerVariable(allocated, dtype, name);
}
/**
* Instantiates an all-ones tensor of the same shape as another tensor.
*
* @param x The other tensor.
* @param dtype DType of the tensor.
* @param name Name of the tensor.
* @return A newly instantiated Variable.
*/
function onesLike$1(x, dtype, name) {
var allocated = onesLike(x);
return new LayerVariable(allocated, dtype, name);
}
/**
 * Instantiates an identity matrix and returns it as a Variable.
*
* @param size Number of rows/columns.
* @param dtype Data type of returned Variable.
* @param name Name of returned Variable.
* @return A Variable, an identity matrix.
*/
function eyeVariable(size, dtype, name) {
return new LayerVariable(eye(size), dtype, name);
}
/**
* Get a Variable with uniform distribution of values.
* @param shape Shape of the tensor.
* @param minval Lower bound of the uniform distribution.
* @param maxval Upper bound of the uniform distribution.
* @param dtype
* @param seed
* @param name Optional name.
* @return The uniform-random Variable.
*/
function randomUniformVariable(shape, minval, maxval, dtype, seed, name) {
if (name === void 0) {
name = 'randomUniform';
}
return new LayerVariable(randomUniform(shape, minval, maxval, dtype), dtype, name);
}
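/*
 * Sketch of the Variable factory helpers above (internal API): each wraps a
 * freshly allocated tensor in a LayerVariable.
 *
 * ```js
 * const b = zerosVariable([4], 'float32', 'bias');       // all-zeros, shape [4]
 * const g = onesVariable([4], 'float32', 'gamma');       // all-ones, shape [4]
 * const w = randomUniformVariable([4, 4], -0.05, 0.05);  // uniform in [-0.05, 0.05]
 * ```
 */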
/**
* Get a Variable with truncated-normal distribution of values.
* @param shape Shape of the tensor.
* @param mean mean value of the normal distribution.
* @param stddev standard deviation of the normal distribution.
* @param dtype
* @param seed
* @param name Optional name.
* @return The truncated-normal-random Variable.
*/
function truncatedNormalVariable(shape, mean, stddev, dtype, seed, name) {
if (mean === void 0) {
mean = 0.0;
}
if (stddev === void 0) {
stddev = 1.0;
}
if (name === void 0) {
name = 'truncatedNormal';
}
// TODO(cais): Implement logic for dtype and seed once they are supported
// by deeplearn.js.
dtype = dtype || 'float32';
if (dtype !== 'float32' && dtype !== 'int32') {
throw new NotImplementedError("randomNormal does not support dType " + dtype + ".");
}
return new LayerVariable(truncatedNormal(shape, mean, stddev, dtype, seed), dtype, name);
}
/**
* Get a Variable with normal distribution of values.
* @param shape Shape of the tensor.
* @param mean mean value of the normal distribution.
* @param stddev standard deviation of the normal distribution.
* @param dtype
* @param seed
* @param name Optional name.
 * @return The normal-random Variable.
*/
function randomNormalVariable(shape, mean, stddev, dtype, seed, name) {
if (mean === void 0) {
mean = 0.0;
}
if (stddev === void 0) {
stddev = 1.0;
}
if (name === void 0) {
name = 'randomNormal';
}
dtype = dtype || 'float32';
if (dtype !== 'float32' && dtype !== 'int32') {
throw new NotImplementedError("randomNormalVariable does not support dType " + dtype + ".");
}
return new LayerVariable(randomNormal(shape, mean, stddev, dtype, seed), dtype, name);
}
/**
* Update the value of a Variable.
* @param x The Variable to be updated.
* @param xNew The new value to update to.
* @return The Variable updated.
*/
function update(x, xNew) {
return x.write(xNew);
}
/**
* Update the value of a Variable by adding an increment.
* @param x The Variable to be updated.
 * @param increment The increment to add to `x`.
* @return The Variable updated.
*/
function updateAdd(x, increment) {
return x.write(add$1(x.read(), increment));
}
/**
* Update the value of a Variable by subtracting a decrement.
* @param x The Variable to be updated.
* @param decrement The decrement to subtract from `x`.
* @return The Variable updated.
*/
function updateSub(x, decrement) {
return x.write(sub(x.read(), decrement));
}
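/*
 * Sketch of the update helpers (internal): all three delegate to
 * `LayerVariable.write()`.
 *
 * ```js
 * const v = zerosVariable([2]);
 * updateAdd(v, tf.tensor1d([1, 2]));  // v now holds [1, 2]
 * updateSub(v, tf.tensor1d([1, 1]));  // v now holds [0, 1]
 * update(v, tf.zeros([2]));           // v reset to [0, 0]
 * ```
 */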
/**
* Get the values of an array of Variables.
*
 * @param xs An `Array` of `Variable`s to get the values of.
 * @return The values of the inputs, as an `Array` of `tf.Tensor`s.
*/
function batchGetValue(xs) {
return xs.map(function (x) {
return x.read();
});
}
/**
* Update the value of multiple Variables at once.
*
* @param variablesAndValues An `Array`, each element is of type
* [Variable, Tensor]. The first item is the
* `Variable` of which the value is to be updated. The second item
* carries the new value.
*/
function batchSetValue(variablesAndValues) {
variablesAndValues.forEach(function (variableAndValue) {
var variable = variableAndValue[0];
variable.write(variableAndValue[1]);
});
}
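/*
 * Sketch of the batched accessors (internal): values are read and written
 * positionally, one [variable, tensor] pair per entry.
 *
 * ```js
 * const vars = [zerosVariable([2]), onesVariable([3])];
 * const values = batchGetValue(vars);               // tensors holding [0, 0] and [1, 1, 1]
 * batchSetValue([[vars[0], tf.tensor1d([5, 5])]]);  // only the first variable is updated
 * ```
 */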
/**
* Returns the gradients of `variables` w.r.t. the return value of `lossFn`.
* @param lossFn A function which returns a Scalar to be used as the function
* value (i.e., numerator) for differentiation.
* @param variables List of variables to be used as the independent variables
* (i.e., denominator) for differentiation.
 * @returns An Array of gradient tensors.
*/
function gradients(lossFn, variables) {
// TODO(cais): The return type signature can be simplified if deeplearn makes
// the corresponding type public.
var variableList = variables.map(function (variable) {
return variable.read();
});
    var valuesAndGrads = variableGrads(lossFn, variableList);
    return variables.map(function (variable) {
      return valuesAndGrads.grads[variable.name];
});
}
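/*
 * Sketch of `gradients` (internal): it differentiates a scalar-returning loss
 * function with respect to a list of LayerVariables via `variableGrads`.
 *
 * ```js
 * const w = randomNormalVariable([2], 0, 1, 'float32', undefined, 'w');
 * const lossFn = () => tf.square(w.read()).sum();  // scalar loss
 * const grads = gradients(lossFn, [w]);            // grads[0] has shape [2], equal to 2 * w
 * ```
 */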
/**
* Specifies the ndim, dtype and shape of every input to a layer.
*
* Every layer should expose (if appropriate) an `inputSpec` attribute:
* a list of instances of InputSpec (one per input tensor).
*
* A null entry in a shape is compatible with any dimension,
* a null shape is compatible with any shape.
*/
var InputSpec = function InputSpec(args) {
this.dtype = args.dtype;
this.shape = args.shape;
/*
TODO(michaelterry): Could throw error if ndim and shape are both defined
(then backport).
*/
if (args.shape != null) {
this.ndim = args.shape.length;
} else {
this.ndim = args.ndim;
}
this.maxNDim = args.maxNDim;
this.minNDim = args.minNDim;
this.axes = args.axes || {};
};
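/*
 * Sketch of an InputSpec instance (internal): a dense-like layer that accepts
 * tensors of rank >= 2 whose last axis must have 64 units could declare:
 *
 * ```js
 * const spec = new InputSpec({minNDim: 2, axes: {'-1': 64}});
 * // spec.ndim is undefined, spec.minNDim is 2, and spec.axes pins axis -1 to 64.
 * ```
 */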
/**
* `tf.SymbolicTensor` is a placeholder for a Tensor without any concrete value.
*
 * They are most often encountered when building a graph of `Layer`s for a
 * `tf.LayersModel`, when the input data's shape, but not its values, is known.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
var SymbolicTensor =
/**
*
* @param dtype
* @param shape
* @param sourceLayer The Layer that produced this symbolic tensor.
* @param inputs The inputs passed to sourceLayer's __call__() method.
* @param nodeIndex
* @param tensorIndex
* @param callArgs The keyword arguments passed to the __call__() method.
* @param name
* @param outputTensorIndex The index of this tensor in the list of outputs
* returned by apply().
*/
function SymbolicTensor(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) {
this.dtype = dtype;
this.shape = shape;
this.sourceLayer = sourceLayer;
this.inputs = inputs;
this.callArgs = callArgs;
this.outputTensorIndex = outputTensorIndex;
this.id = getNextUniqueTensorId();
if (name != null) {
this.originalName = getScopedTensorName(name);
this.name = getUniqueTensorName(this.originalName);
}
this.rank = shape.length;
};
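/*
 * Sketch: `tf.input()` is the usual public way a SymbolicTensor is created;
 * its shape carries `null` for the unknown batch dimension.
 *
 * ```js
 * const x = tf.input({shape: [28, 28, 1]});
 * console.log(x.shape);  // [null, 28, 28, 1]
 * console.log(x.rank);   // 4
 * ```
 */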
var _nextNodeID = 0;
/**
* A `Node` describes the connectivity between two layers.
*
* Each time a layer is connected to some new input,
* a node is added to `layer.inboundNodes`.
*
* Each time the output of a layer is used by another layer,
* a node is added to `layer.outboundNodes`.
*
* `nodeIndices` and `tensorIndices` are basically fine-grained coordinates
* describing the origin of the `inputTensors`, verifying the following:
*
* `inputTensors[i] ==
* inboundLayers[i].inboundNodes[nodeIndices[i]].outputTensors[
* tensorIndices[i]]`
*
* A node from layer A to layer B is added to:
* A.outboundNodes
* B.inboundNodes
*/
var Node = /*#__PURE__*/function () {
function Node(args, // TODO(michaelterry): Define actual type for this.
callArgs) {
this.callArgs = callArgs;
this.id = _nextNodeID++;
/*
Layer instance (NOT a list).
this is the layer that takes a list of input tensors
and turns them into a list of output tensors.
the current node will be added to
the inboundNodes of outboundLayer.
*/
this.outboundLayer = args.outboundLayer;
/*
The following 3 properties describe where
the input tensors come from: which layers,
and for each layer, which node and which
tensor output of each node.
*/
// List of layer instances.
this.inboundLayers = args.inboundLayers; // List of integers, 1:1 mapping with inboundLayers.
this.nodeIndices = args.nodeIndices; // List of integers, 1:1 mapping with inboundLayers.
this.tensorIndices = args.tensorIndices;
/*
Following 2 properties:
tensor inputs and outputs of outboundLayer.
*/
// List of tensors. 1:1 mapping with inboundLayers.
this.inputTensors = args.inputTensors; // List of tensors, created by outboundLayer.call().
this.outputTensors = args.outputTensors;
/*
Following 2 properties: input and output masks.
List of tensors, 1:1 mapping with inputTensor.
*/
this.inputMasks = args.inputMasks; // List of tensors, created by outboundLayer.computeMask().
this.outputMasks = args.outputMasks; // Following 2 properties: input and output shapes.
// List of shape tuples, shapes of inputTensors.
this.inputShapes = args.inputShapes; // List of shape tuples, shapes of outputTensors.
this.outputShapes = args.outputShapes; // Add nodes to all layers involved.
for (var _iterator = _createForOfIteratorHelperLoose(args.inboundLayers), _step; !(_step = _iterator()).done;) {
var layer = _step.value;
if (layer != null) {
layer.outboundNodes.push(this);
}
}
args.outboundLayer.inboundNodes.push(this);
}
var _proto = Node.prototype;
_proto.getConfig = function getConfig() {
var inboundNames = [];
for (var _iterator2 = _createForOfIteratorHelperLoose(this.inboundLayers), _step2; !(_step2 = _iterator2()).done;) {
var layer = _step2.value;
if (layer != null) {
inboundNames.push(layer.name);
} else {
inboundNames.push(null);
}
}
return {
outboundLayer: this.outboundLayer ? this.outboundLayer.name : null,
inboundLayers: inboundNames,
nodeIndices: this.nodeIndices,
tensorIndices: this.tensorIndices
};
};
return Node;
}();
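/*
 * Sketch of how Nodes accumulate (internal bookkeeping): every symbolic
 * apply() call records one connection on both ends.
 *
 * ```js
 * const inp = tf.input({shape: [4]});
 * const dense = tf.layers.dense({units: 2});
 * dense.apply(inp);
 * console.log(dense.inboundNodes.length);             // 1
 * console.log(inp.sourceLayer.outboundNodes.length);  // 1 (the InputLayer now feeds dense)
 * ```
 */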
var _nextLayerID = 0;
/**
* A layer is a grouping of operations and weights that can be composed to
* create a `tf.LayersModel`.
*
* Layers are constructed by using the functions under the
* [tf.layers](#Layers-Basic) namespace.
*
* @doc {heading: 'Layers', subheading: 'Classes', namespace: 'layers'}
*/
var Layer = /*#__PURE__*/function (_serialization$Serial) {
_inheritsLoose(Layer, _serialization$Serial);
function Layer(args) {
var _this;
if (args === void 0) {
args = {};
}
_this = _serialization$Serial.call(this) || this;
_this._callHook = null;
_this._addedWeightNames = []; // Porting Notes: PyKeras does not have this property in this base Layer
  // class. Instead, it lets the Layer subclass set it dynamically and checks the
// value with `hasattr`. In tfjs-layers, we let this be a member of this
// base class.
_this._stateful = false;
_this.id = _nextLayerID++;
_this.activityRegularizer = null;
_this.inputSpec = null;
_this.supportsMasking = false; // These properties will be set upon call of this.build()
_this._trainableWeights = [];
_this._nonTrainableWeights = [];
_this._losses = [];
_this._updates = [];
_this._built = false;
/*
These lists will be filled via successive calls
to this.addInboundNode().
*/
_this.inboundNodes = [];
_this.outboundNodes = [];
var name = args.name;
if (!name) {
var prefix = _this.getClassName();
name = toSnakeCase(prefix) + '_' + getUid(prefix);
}
_this.name = name;
_this.trainable_ = args.trainable == null ? true : args.trainable;
if (args.inputShape != null || args.batchInputShape != null) {
/*
In this case we will later create an input layer
to insert before the current layer
*/
var batchInputShape;
if (args.batchInputShape != null) {
batchInputShape = args.batchInputShape;
} else if (args.inputShape != null) {
var batchSize = null;
if (args.batchSize != null) {
batchSize = args.batchSize;
}
batchInputShape = [batchSize].concat(args.inputShape);
}
_this.batchInputShape = batchInputShape; // Set dtype.
var dtype = args.dtype;
if (dtype == null) {
dtype = args.inputDType;
}
if (dtype == null) {
dtype = 'float32';
}
_this.dtype = dtype;
}
if (args.weights != null) {
_this.initialWeights = args.weights;
} else {
_this.initialWeights = null;
} // The value of `_refCount` is initialized to null. When the layer is used
// in a symbolic way for the first time, it will be set to 1.
_this._refCount = null;
_this.fastWeightInitDuringBuild = false;
return _this;
}
/**
* Converts a layer and its index to a unique (immutable type) name.
* This function is used internally with `this.containerNodes`.
* @param layer The layer.
* @param nodeIndex The layer's position (e.g. via enumerate) in a list of
* nodes.
*
* @returns The unique name.
*/
Layer.nodeKey = function nodeKey(layer, nodeIndex) {
return layer.name + '_ib-' + nodeIndex.toString();
}
/**
* Returns this.inboundNode at index nodeIndex.
*
* Porting note: This is a replacement for _get_node_attribute_at_index()
* @param nodeIndex
 * @param attrName The name of the attribute requested for this node.
*/
;
var _proto2 = Layer.prototype;
_proto2.getNodeAtIndex = function getNodeAtIndex(nodeIndex, attrName) {
if (this.inboundNodes.length === 0) {
throw new RuntimeError('The layer has never been called ' + ("and thus has no defined " + attrName + "."));
}
if (this.inboundNodes.length <= nodeIndex) {
throw new ValueError("Asked to get " + attrName + " at node " + nodeIndex + ", " + ("but the layer has only " + this.inboundNodes.length + " inbound nodes."));
}
return this.inboundNodes[nodeIndex];
}
/**
* Retrieves the input tensor(s) of a layer at a given node.
*
* @param nodeIndex Integer, index of the node from which to retrieve the
* attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
* was called.
*
* @return A tensor (or list of tensors if the layer has multiple inputs).
*/
;
_proto2.getInputAt = function getInputAt(nodeIndex) {
return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'input').inputTensors);
}
/**
* Retrieves the output tensor(s) of a layer at a given node.
*
* @param nodeIndex Integer, index of the node from which to retrieve the
* attribute. E.g. `nodeIndex=0` will correspond to the first time the layer
* was called.
*
* @return A tensor (or list of tensors if the layer has multiple outputs).
*/
;
_proto2.getOutputAt = function getOutputAt(nodeIndex) {
return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'output').outputTensors);
} // Properties
/**
* Retrieves the input tensor(s) of a layer.
*
* Only applicable if the layer has exactly one inbound node,
* i.e. if it is connected to one incoming layer.
*
* @return Input tensor or list of input tensors.
*
* @exception AttributeError if the layer is connected to more than one
 * incoming layer.
*/
;
/**
* Retrieves the Layer's current loss values.
*
* Used for regularizers during training.
*/
_proto2.calculateLosses = function calculateLosses() {
    // Porting Note: This is an augmentation to Layer.loss in PyKeras.
    // In PyKeras, Layer.loss returns symbolic tensors. Here, concrete
// Tensor (specifically Scalar) values are returned. This is due to the
// imperative backend.
return this.losses.map(function (lossFn) {
return lossFn();
});
};
/**
* Reset the states of the layer.
*
* This method of the base Layer class is essentially a no-op.
* Subclasses that are stateful (e.g., stateful RNNs) should override this
* method.
*/
_proto2.resetStates = function resetStates() {
if (!this.stateful) {
throw new Error('Cannot call the resetStates() method of a non-stateful Layer ' + 'object.');
}
}
/**
* Checks compatibility between the layer and provided inputs.
*
* This checks that the tensor(s) `input`
* verify the input assumptions of the layer
* (if any). If not, exceptions are raised.
*
* @param inputs Input tensor or list of input tensors.
*
* @exception ValueError in case of mismatch between
* the provided inputs and the expectations of the layer.
*/
;
_proto2.assertInputCompatibility = function assertInputCompatibility(inputs) {
inputs = toList(inputs);
if (this.inputSpec == null || this.inputSpec.length === 0) {
return;
}
var inputSpec = toList(this.inputSpec);
if (inputs.length !== inputSpec.length) {
throw new ValueError("Layer " + this.name + " expects " + inputSpec.length + " inputs, " + ("but it received " + inputs.length + " input tensors. ") + ("Input received: " + inputs));
}
for (var inputIndex = 0; inputIndex < inputs.length; inputIndex++) {
var x = inputs[inputIndex];
var spec = inputSpec[inputIndex];
if (spec == null) {
continue;
} // Check ndim.
var ndim = x.rank;
if (spec.ndim != null) {
if (ndim !== spec.ndim) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + this.name + ": " + ("expected ndim=" + spec.ndim + ", found ndim=" + ndim));
}
}
if (spec.maxNDim != null) {
if (ndim > spec.maxNDim) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + this.name + (": expected max_ndim=" + spec.maxNDim + ", found ndim=" + ndim));
}
}
if (spec.minNDim != null) {
if (ndim < spec.minNDim) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + this.name + (": expected min_ndim=" + spec.minNDim + ", found ndim=" + ndim + "."));
}
} // Check dtype.
if (spec.dtype != null) {
if (x.dtype !== spec.dtype) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + this.name + " " + (": expected dtype=" + spec.dtype + ", found dtype=" + x.dtype + "."));
}
} // Check specific shape axes.
if (spec.axes) {
var xShape = x.shape;
for (var key in spec.axes) {
var axis = Number(key);
var value = spec.axes[key]; // Perform Python-style slicing in case axis < 0;
// TODO(cais): Use https://github.com/alvivi/typescript-underscore to
// ensure type safety through Underscore calls.
var xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis];
if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + (this.name + ": expected axis " + axis + " of input shape to ") + ("have value " + value + " but got shape " + xShape + "."));
}
}
} // Check shape.
if (spec.shape != null) {
for (var i = 0; i < spec.shape.length; ++i) {
var specDim = spec.shape[i];
var dim = x.shape[i];
if (specDim != null && dim != null) {
if (specDim !== dim) {
throw new ValueError("Input " + inputIndex + " is incompatible with layer " + (this.name + ": expected shape=" + spec.shape + ", ") + ("found shape=" + x.shape + "."));
}
}
}
}
}
}
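/*
 * Sketch of the compatibility check above: once a dense layer has been built
 * for a last axis of size 4, an input whose last axis is 5 fails the
 * `spec.axes` check and a ValueError is thrown.
 *
 * ```js
 * const layer = tf.layers.dense({units: 2});
 * layer.apply(tf.ones([1, 4]));     // builds for inputs of shape [..., 4]
 * // layer.apply(tf.ones([1, 5]));  // throws: expected axis -1 of input shape to have value 4
 * ```
 */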
/**
* This is where the layer's logic lives.
*
* @param inputs Input tensor, or list/tuple of input tensors.
* @param kwargs Additional keyword arguments.
*
* @return A tensor or list/tuple of tensors.
*/
;
_proto2.call = function call(inputs, kwargs) {
return inputs;
};
_proto2.invokeCallHook = function invokeCallHook(inputs, kwargs) {
if (this._callHook != null) {
this._callHook(inputs, kwargs);
}
}
/**
* Set call hook.
* This is currently used for testing only.
* @param callHook
*/
;
_proto2.setCallHook = function setCallHook(callHook) {
this._callHook = callHook;
}
/**
* Clear call hook.
* This is currently used for testing only.
*/
;
_proto2.clearCallHook = function clearCallHook() {
this._callHook = null;
}
/**
 * Builds or executes a `Layer`'s logic.
*
* When called with `tf.Tensor`(s), execute the `Layer`s computation and
* return Tensor(s). For example:
*
* ```js
* const denseLayer = tf.layers.dense({
* units: 1,
* kernelInitializer: 'zeros',
* useBias: false
* });
*
* // Invoke the layer's apply() method with a `tf.Tensor` (with concrete
* // numeric values).
* const input = tf.ones([2, 2]);
* const output = denseLayer.apply(input);
*
* // The output's value is expected to be [[0], [0]], due to the fact that
* // the dense layer has a kernel initialized to all-zeros and does not have
* // a bias.
* output.print();
* ```
*
* When called with `tf.SymbolicTensor`(s), this will prepare the layer for
* future execution. This entails internal book-keeping on shapes of
* expected Tensors, wiring layers together, and initializing weights.
*
 * Calling `apply` with `tf.SymbolicTensor`s is typically used during the
* building of non-`tf.Sequential` models. For example:
*
* ```js
* const flattenLayer = tf.layers.flatten();
* const denseLayer = tf.layers.dense({units: 1});
*
* // Use tf.layers.input() to obtain a SymbolicTensor as input to apply().
* const input = tf.input({shape: [2, 2]});
* const output1 = flattenLayer.apply(input);
*
* // output1.shape is [null, 4]. The first dimension is the undetermined
* // batch size. The second dimension comes from flattening the [2, 2]
* // shape.
* console.log(JSON.stringify(output1.shape));
*
* // The output SymbolicTensor of the flatten layer can be used to call
* // the apply() of the dense layer:
* const output2 = denseLayer.apply(output1);
*
* // output2.shape is [null, 1]. The first dimension is the undetermined
* // batch size. The second dimension matches the number of units of the
* // dense layer.
* console.log(JSON.stringify(output2.shape));
*
 * // The input and output can be used to construct a model that consists
* // of the flatten and dense layers.
* const model = tf.model({inputs: input, outputs: output2});
* ```
*
* @param inputs a `tf.Tensor` or `tf.SymbolicTensor` or an Array of them.
* @param kwargs Additional keyword arguments to be passed to `call()`.
*
* @return Output of the layer's `call` method.
*
* @exception ValueError error in case the layer is missing shape information
* for its `build` call.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
// Porting Note: This is a replacement for __call__() in Python.
;
_proto2.apply = function apply(inputs, kwargs) {
var _this2 = this;
kwargs = kwargs || {};
this.assertNotDisposed(); // Ensure inputs are all the same type.
var inputsList = toList(inputs);
var allAreSymbolic = true;
for (var _iterator3 = _createForOfIteratorHelperLoose(inputsList), _step3; !(_step3 = _iterator3()).done;) {
var input = _step3.value;
if (!(input instanceof SymbolicTensor)) {
allAreSymbolic = false;
break;
}
}
var noneAreSymbolic = true;
for (var _iterator4 = _createForOfIteratorHelperLoose(inputsList), _step4; !(_step4 = _iterator4()).done;) {
var _input = _step4.value;
if (_input instanceof SymbolicTensor) {
noneAreSymbolic = false;
break;
}
}
if (allAreSymbolic === noneAreSymbolic) {
throw new ValueError('Arguments to apply() must be all ' + 'SymbolicTensors or all Tensors');
} // TODO(michaelterry): nameScope() may not be necessary.
return nameScope(this.name, function () {
      // Handle layer building (weight creation, input spec locking).
if (!_this2.built) {
/*
Throw exceptions in case the input is not compatible
with the inputSpec specified in the layer constructor.
*/
_this2.assertInputCompatibility(inputs); // Collect input shapes to build layer.
var inputShapes = [];
for (var _iterator5 = _createForOfIteratorHelperLoose(toList(inputs)), _step5; !(_step5 = _iterator5()).done;) {
var xElem = _step5.value;
inputShapes.push(xElem.shape);
}
_this2.build(singletonOrArray(inputShapes));
_this2.built = true; // Load weights that were specified at layer instantiation.
if (_this2.initialWeights) {
_this2.setWeights(_this2.initialWeights);
}
if (_this2._refCount === null && noneAreSymbolic) {
// The first use of this layer is a non-symbolic call, set ref count
// to 1 so the Layer can be properly disposed if its dispose() method
// is called.
_this2._refCount = 1;
}
}
/*
Throw exceptions in case the input is not compatible
with the inputSpec set at build time.
*/
_this2.assertInputCompatibility(inputs); // Handle mask propagation.
// TODO(michaelterry): Mask propagation not currently implemented.
// Actually call the layer, collecting output(s), mask(s), and shape(s).
if (noneAreSymbolic) {
var output = _this2.call(inputs, kwargs); // TODO(michaelterry): Compute the outputMask
// If the layer returns tensors from its inputs, unmodified,
// we copy them to avoid loss of tensor metadata.
var outputList = toList(output);
var outputListCopy = []; // TODO(michaelterry): This copying may not be necessary given our eager
// backend.
for (var _iterator6 = _createForOfIteratorHelperLoose(outputList), _step6; !(_step6 = _iterator6()).done;) {
var x = _step6.value;
if (inputsList.indexOf(x) !== -1) {
x = x.clone();
}
outputListCopy.push(x);
}
output = singletonOrArray(outputListCopy);
if (_this2.activityRegularizer != null) {
throw new NotImplementedError('Layer invocation in the presence of activity ' + 'regularizer(s) is not supported yet.');
} // TODO(michaelterry): Call addInboundNode()?
return output;
} else {
var inputShape = collectInputShape(inputs);
var outputShape = _this2.computeOutputShape(inputShape);
var _output;
var outputDType = guessOutputDType(inputs);
_this2.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] : inputShape);
if (outputShape != null && outputShape.length > 0 && Array.isArray(outputShape[0])) {
// We have multiple output shapes. Create multiple output tensors.
_output = outputShape.map(function (shape, index) {
return new SymbolicTensor(outputDType, shape, _this2, toList(inputs), kwargs, _this2.name, index);
});
} else {
_output = new SymbolicTensor(outputDType, outputShape, _this2, toList(inputs), kwargs, _this2.name);
}
/*
Add an inbound node to the layer, so that it keeps track
of the call and of all new variables created during the call.
This also updates the layer history of the output tensor(s).
If the input tensor(s) had no previous history,
this does nothing.
*/
_this2.addInboundNode(inputs, _output, null, null, inputShape, outputShape, kwargs);
_this2._refCount++;
if (_this2.activityRegularizer != null) {
throw new NotImplementedError('Layer invocation in the presence of activity ' + 'regularizer(s) is not supported yet.');
}
return _output;
}
});
}
/**
* Check compatibility between input shape and this layer's batchInputShape.
*
* Print warning if any incompatibility is found.
*
* @param inputShape Input shape to be checked.
*/
;
_proto2.warnOnIncompatibleInputShape = function warnOnIncompatibleInputShape(inputShape) {
if (this.batchInputShape == null) {
return;
} else if (inputShape.length !== this.batchInputShape.length) {
console.warn("The rank of the input tensor provided (shape: " + (JSON.stringify(inputShape) + ") does not match that of the ") + ("batchInputShape (" + JSON.stringify(this.batchInputShape) + ") ") + ("of the layer " + this.name));
} else {
var dimMismatch = false;
this.batchInputShape.forEach(function (dimension, i) {
if (dimension != null && inputShape[i] != null && inputShape[i] !== dimension) {
dimMismatch = true;
}
});
if (dimMismatch) {
console.warn("The shape of the input tensor " + ("(" + JSON.stringify(inputShape) + ") does not ") + ("match the expectation of layer " + this.name + ": ") + ("" + JSON.stringify(this.batchInputShape)));
}
}
}
/**
* Retrieves the output shape(s) of a layer.
*
* Only applicable if the layer has only one inbound node, or if all inbound
* nodes have the same output shape.
*
* @returns Output shape or shapes.
* @throws AttributeError: if the layer is connected to more than one incoming
 * node.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
;
/**
* Counts the total number of numbers (e.g., float32, int32) in the
* weights.
*
* @returns An integer count.
* @throws RuntimeError: If the layer is not built yet (in which case its
* weights are not defined yet.)
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
_proto2.countParams = function countParams() {
if (!this.built) {
throw new RuntimeError("You tried to call countParams() on " + this.name + ", " + "but the layer is not built yet. Build it first by calling " + "build(batchInputShape).");
}
return countParamsInWeights(this.weights);
}
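/*
 * Numeric sketch for countParams(): a built dense layer with inputShape [4]
 * and 3 units holds a [4, 3] kernel and a [3] bias, so it reports
 * 4 * 3 + 3 = 15 parameters.
 *
 * ```js
 * const layer = tf.layers.dense({units: 3, inputShape: [4]});
 * layer.apply(tf.input({shape: [4]}));  // builds the weights
 * console.log(layer.countParams());     // 15
 * ```
 */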
/**
* Creates the layer weights.
*
* Must be implemented on all layers that have weights.
*
* Called when apply() is called to construct the weights.
*
* @param inputShape A `Shape` or array of `Shape` (unused).
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
;
_proto2.build = function build(inputShape) {
this.built = true;
}
/**
* Returns the current values of the weights of the layer.
*
* @param trainableOnly Whether to get the values of only trainable weights.
* @returns Weight values as an `Array` of `tf.Tensor`s.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
;
_proto2.getWeights = function getWeights(trainableOnly) {
if (trainableOnly === void 0) {
trainableOnly = false;
}
return batchGetValue(trainableOnly ? this.trainableWeights : this.weights);
}
/**
* Sets the weights of the layer, from Tensors.
*
 * @param weights a list of Tensors. The number of tensors and their shapes
 *   must match the number and shapes of the layer's weights (i.e. it should
 *   match the output of `getWeights`).
*
* @exception ValueError If the provided weights list does not match the
* layer's specifications.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
;
_proto2.setWeights = function setWeights(weights) {
var _this3 = this;
tidy(function () {
var params = _this3.weights;
if (params.length !== weights.length) {
// TODO(cais): Restore the following and use `providedWeights`, instead
// of `weights` in the error message, once the deeplearn.js bug is
// fixed: https://github.com/PAIR-code/deeplearnjs/issues/498 const
// providedWeights = JSON.stringify(weights).substr(0, 50);
throw new ValueError("You called setWeights(weights) on layer \"" + _this3.name + "\" " + ("with a weight list of length " + weights.length + ", ") + ("but the layer was expecting " + params.length + " weights. ") + ("Provided weights: " + weights + "..."));
}
if (params.length === 0) {
return;
}
var weightValueTuples = [];
var paramValues = batchGetValue(params);
for (var i = 0; i < paramValues.length; ++i) {
var pv = paramValues[i];
var p = params[i];
var w = weights[i];
if (!arraysEqual(pv.shape, w.shape)) {
throw new ValueError("Layer weight shape " + pv.shape + " " + ("not compatible with provided weight shape " + w.shape));
}
weightValueTuples.push([p, w]);
}
batchSetValue(weightValueTuples);
});
}
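/*
 * Sketch of a getWeights()/setWeights() round trip on a built layer: the list
 * returned by getWeights() has exactly the shapes that setWeights() expects.
 *
 * ```js
 * const layer = tf.layers.dense({units: 2, inputShape: [3]});
 * layer.apply(tf.input({shape: [3]}));                  // builds kernel [3, 2] and bias [2]
 * const weights = layer.getWeights();                   // [kernelTensor, biasTensor]
 * layer.setWeights([tf.zeros([3, 2]), tf.zeros([2])]);  // same shapes, so this succeeds
 * ```
 */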
/**
* Adds a weight variable to the layer.
*
* @param name Name of the new weight variable.
* @param shape The shape of the weight.
* @param dtype The dtype of the weight.
* @param initializer An initializer instance.
* @param regularizer A regularizer instance.
* @param trainable Whether the weight should be trained via backprop or not
* (assuming that the layer itself is also trainable).
 * @param constraint An optional constraint to apply to the weight.
* @return The created weight variable.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
;
_proto2.addWeight = function addWeight(name, shape, dtype, initializer, regularizer, trainable, constraint) {
// Reject duplicate weight names.
if (this._addedWeightNames.indexOf(name) !== -1) {
throw new ValueError("Duplicate weight name " + name + " for layer " + this.name);
}
this._addedWeightNames.push(name);
if (dtype == null) {
dtype = 'float32';
}
if (this.fastWeightInitDuringBuild) {
initializer = getInitializer('zeros');
}
var initValue = initializer.apply(shape, dtype);
var weight = new LayerVariable(initValue, dtype, name, trainable, constraint);
initValue.dispose(); // Request backend not to dispose the weights of the model on scope() exit.
if (regularizer != null) {
this.addLoss(function () {
return regularizer.apply(weight.read());
});
}
if (trainable == null) {
trainable = true;
}
if (trainable) {
this._trainableWeights.push(weight);
} else {
this._nonTrainableWeights.push(weight);
}
return weight;
}
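/*
 * Sketch of addWeight() used from a custom layer's build() method (a minimal
 * illustrative subclass; the class name and weight name are made up):
 *
 * ```js
 * class ScaleLayer extends tf.layers.Layer {
 *   constructor() { super({}); }
 *   build(inputShape) {
 *     // One trainable scalar weight, initialized to 1.
 *     this.alpha = this.addWeight('alpha', [], 'float32', tf.initializers.ones());
 *   }
 *   call(inputs) {
 *     const x = Array.isArray(inputs) ? inputs[0] : inputs;
 *     return tf.tidy(() => tf.mul(x, this.alpha.read()));
 *   }
 *   getClassName() { return 'ScaleLayer'; }
 * }
 * const out = new ScaleLayer().apply(tf.ones([2, 2]));  // values unchanged: scaled by 1
 * ```
 */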
/**
* Set the fast-weight-initialization flag.
*
* In cases where the initialized weight values will be immediately
* overwritten by loaded weight values during model loading, setting
* the flag to `true` saves unnecessary calls to potentially expensive
* initializers and speeds up the loading process.
*
* @param value Target value of the flag.
*/
;
_proto2.setFastWeightInitDuringBuild = function setFastWeightInitDuringBuild(value) {
this.fastWeightInitDuringBuild = value;
}
/**
* Add losses to the layer.
*
 * The loss may potentially be conditional on some input tensors,
* for instance activity losses are conditional on the layer's inputs.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
;
_proto2.addLoss = function addLoss(losses) {
if (losses == null || Array.isArray(losses) && losses.length === 0) {
return;
} // Update this.losses
losses = toList(losses);
if (this._losses !== undefined && this._losses !== null) {
var _this$losses;
(_this$losses = this.losses).push.apply(_this$losses, losses);
}
}
/**
* Computes the output shape of the layer.
*
 * Assumes that the layer will be built to match the input shape provided.
*
* @param inputShape A shape (tuple of integers) or a list of shape tuples
* (one per output tensor of the layer). Shape tuples can include null for
* free dimensions, instead of an integer.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
;
_proto2.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
}
/**
* Computes an output mask tensor.
*
* @param inputs Tensor or list of tensors.
* @param mask Tensor or list of tensors.
*
* @return null or a tensor (or list of tensors, one per output tensor of the
* layer).
*/
;
_proto2.computeMask = function computeMask(inputs, mask) {
var _this4 = this;
if (!this.supportsMasking) {
if (mask != null) {
if (Array.isArray(mask)) {
mask.forEach(function (maskElement) {
if (maskElement != null) {
throw new TypeError("Layer " + _this4.name + " does not support masking, " + 'but was passed an inputMask.');
}
});
} else {
throw new TypeError("Layer " + this.name + " does not support masking, " + 'but was passed an inputMask.');
}
} // masking not explicitly supported: return null as mask
return null;
    } // if masking is explicitly supported, by default
// carry over the input mask
return mask;
}
/**
* Internal method to create an inbound node for the layer.
*
* @param inputTensors List of input tensors.
* @param outputTensors List of output tensors.
* @param inputMasks List of input masks (a mask can be a tensor, or null).
* @param outputMasks List of output masks (a mask can be a tensor, or null).
* @param inputShapes List of input shape tuples.
* @param outputShapes List of output shape tuples.
* @param kwargs Dictionary of keyword arguments that were passed to the
* `call` method of the layer at the call that created the node.
*/
;
_proto2.addInboundNode = function addInboundNode(inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes, kwargs) {
if (kwargs === void 0) {
kwargs = null;
}
var inputTensorList = toList(inputTensors);
outputTensors = toList(outputTensors);
inputMasks = toList(inputMasks);
outputMasks = toList(outputMasks);
inputShapes = normalizeShapeList(inputShapes);
outputShapes = normalizeShapeList(outputShapes); // Collect input tensor(s) coordinates.
var inboundLayers = [];
var nodeIndices = [];
var tensorIndices = [];
for (var _iterator7 = _createForOfIteratorHelperLoose(inputTensorList), _step7; !(_step7 = _iterator7()).done;) {
var x = _step7.value;
/*
* TODO(michaelterry): Keras adds this value to tensors; it's not
* clear whether we'll use this or not.
*/
inboundLayers.push(x.sourceLayer);
nodeIndices.push(x.nodeIndex);
tensorIndices.push(x.tensorIndex);
} // Create node, add it to inbound nodes.
// (This call has side effects.)
// tslint:disable-next-line:no-unused-expression
new Node({
outboundLayer: this,
inboundLayers: inboundLayers,
nodeIndices: nodeIndices,
tensorIndices: tensorIndices,
inputTensors: inputTensorList,
outputTensors: outputTensors,
inputMasks: inputMasks,
outputMasks: outputMasks,
inputShapes: inputShapes,
outputShapes: outputShapes
}, kwargs); // Update tensor history
for (var i = 0; i < outputTensors.length; i++) {
      // TODO(michaelterry): _uses_learning_phase not tracked.
outputTensors[i].sourceLayer = this;
outputTensors[i].nodeIndex = this.inboundNodes.length - 1;
outputTensors[i].tensorIndex = i;
}
}
/**
* Returns the config of the layer.
*
* A layer config is a TS dictionary (serializable)
* containing the configuration of a layer.
* The same layer can be reinstantiated later
* (without its trained weights) from this configuration.
*
* The config of a layer does not include connectivity
* information, nor the layer class name. These are handled
* by 'Container' (one layer of abstraction above).
*
 * Porting Note: The TS dictionary follows TS naming standards for
* keys, and uses tfjs-layers type-safe Enums. Serialization methods
* should use a helper function to convert to the pythonic storage
* standard. (see serialization_utils.convertTsToPythonic)
*
* @returns TS dictionary of configuration.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
;
_proto2.getConfig = function getConfig() {
var config = {
name: this.name,
trainable: this.trainable
};
if (this.batchInputShape != null) {
config['batchInputShape'] = this.batchInputShape;
}
if (this.dtype != null) {
config['dtype'] = this.dtype;
}
return config;
}
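/*
 * Sketch of what getConfig() returns for the base-class fields (subclasses
 * add their own options on top of these; the exact auto-generated name will
 * vary):
 *
 * ```js
 * const cfg = tf.layers.dense({units: 2, inputShape: [3]}).getConfig();
 * // cfg.name            -> e.g. 'dense_Dense1'
 * // cfg.trainable       -> true
 * // cfg.batchInputShape -> [null, 3]
 * // plus Dense-specific keys such as units and activation
 * ```
 */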
/**
* Dispose the weight variables that this Layer instance holds.
*
* @returns {number} Number of disposed variables.
*/
;
_proto2.disposeWeights = function disposeWeights() {
this.weights.forEach(function (weight) {
return weight.dispose();
});
return this.weights.length;
};
_proto2.assertNotDisposed = function assertNotDisposed() {
if (this._refCount === 0) {
throw new Error("Layer '" + this.name + "' is already disposed.");
}
}
/**
* Attempt to dispose layer's weights.
*
 * This method decreases the reference count of the Layer object by 1.
 *
 * A Layer is reference-counted. Its reference count is incremented by 1
 * the first time its `apply()` method is called and each time it becomes
 * part of a new `Node` (through calling the `apply()` method on a
 * `tf.SymbolicTensor`).
*
* If the reference count of a Layer becomes 0, all the weights will be
* disposed and the underlying memory (e.g., the textures allocated in WebGL)
* will be freed.
*
* Note: If the reference count is greater than 0 after the decrement, the
* weights of the Layer will *not* be disposed.
*
* After a Layer is disposed, it cannot be used in calls such as `apply()`,
* `getWeights()` or `setWeights()` anymore.
*
* @returns A DisposeResult Object with the following fields:
* - refCountAfterDispose: The reference count of the Container after this
* `dispose()` call.
* - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
* during this `dispose()` call.
* @throws {Error} If the layer is not built yet, or if the layer has already
* been disposed.
*
* @doc {heading: 'Models', 'subheading': 'Classes'}
*/
;
_proto2.dispose = function dispose() {
if (!this.built) {
throw new Error("Cannot dispose Layer " + this.name + " because it has not been " + "built yet.");
}
if (this._refCount === null) {
throw new Error("Cannot dispose Layer " + this.name + " because it has not been used " + "yet.");
}
this.assertNotDisposed();
var numDisposedVariables = 0;
if (--this._refCount === 0) {
numDisposedVariables = this.disposeWeights();
}
return {
refCountAfterDispose: this._refCount,
numDisposedVariables: numDisposedVariables
};
};
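/*
 * Sketch of the reference counting in dispose(): a layer that has been used
 * once has a refCount of 1, so a single dispose() call frees its weights.
 *
 * ```js
 * const layer = tf.layers.dense({units: 1});
 * layer.apply(tf.ones([1, 4]));              // concrete call: refCount becomes 1
 * const result = layer.dispose();
 * console.log(result.refCountAfterDispose);  // 0
 * console.log(result.numDisposedVariables);  // 2 (kernel and bias)
 * ```
 */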
_createClass(Layer, [{
key: "input",
get: function get() {
if (this.inboundNodes.length > 1) {
throw new AttributeError("Layer " + this.name + ' has multiple inbound nodes, ' + 'hence the notion of "layer input" ' + 'is ill-defined. ' + 'Use `getInputAt(nodeIndex)` instead.');
} else if (this.inboundNodes.length === 0) {
throw new AttributeError("Layer " + this.name + ' is not connected, no input to return.');
}
return singletonOrArray(this.getNodeAtIndex(0, 'input').inputTensors);
}
/**
* Retrieves the output tensor(s) of a layer.
*
* Only applicable if the layer has exactly one inbound node,
* i.e. if it is connected to one incoming layer.
*
* @return Output tensor or list of output tensors.
*
* @exception AttributeError if the layer is connected to more than one
 * incoming layer.
*/
}, {
key: "output",
get: function get() {
if (this.inboundNodes.length === 0) {
throw new AttributeError("Layer " + this.name + ' has no inbound nodes.');
}
if (this.inboundNodes.length > 1) {
throw new AttributeError("Layer " + this.name + ' has multiple inbound nodes, ' + 'hence the notion of "layer output" ' + 'is ill-defined. ' + 'Use `getOutputAt(nodeIndex)` instead.');
}
return singletonOrArray(this.getNodeAtIndex(0, 'output').outputTensors);
}
}, {
key: "losses",
get: function get() {
return this._losses;
}
}, {
key: "updates",
get: function get() {
return this._updates;
}
}, {
key: "built",
get: function get() {
return this._built;
},
set: function set(built) {
this._built = built;
}
}, {
key: "trainable",
get: function get() {
return this.trainable_;
},
set: function set(trainable) {
this._trainableWeights.forEach(function (w) {
return w.trainable = trainable;
});
this.trainable_ = trainable;
}
}, {
key: "trainableWeights",
get: function get() {
if (this.trainable_) {
return this._trainableWeights.filter(function (w) {
return w.trainable;
});
} else {
return [];
}
},
set: function set(weights) {
this._trainableWeights = weights;
}
}, {
key: "nonTrainableWeights",
get: function get() {
if (this.trainable) {
return this._trainableWeights.filter(function (w) {
return !w.trainable;
}).concat(this._nonTrainableWeights);
} else {
return this._trainableWeights.concat(this._nonTrainableWeights);
}
},
set: function set(weights) {
this._nonTrainableWeights = weights;
}
/**
* The concatenation of the lists trainableWeights and nonTrainableWeights
* (in this order).
*/
}, {
key: "weights",
get: function get() {
return this.trainableWeights.concat(this.nonTrainableWeights);
}
}, {
key: "stateful",
get: function get() {
return this._stateful;
}
}, {
key: "outputShape",
get: function get() {
if (this.inboundNodes == null || this.inboundNodes.length === 0) {
throw new AttributeError("The layer " + this.name + " has never been called and thus has no " + "defined output shape.");
}
var allOutputShapes = [];
for (var _iterator8 = _createForOfIteratorHelperLoose(this.inboundNodes), _step8; !(_step8 = _iterator8()).done;) {
var node = _step8.value;
var shapeString = JSON.stringify(node.outputShapes);
if (allOutputShapes.indexOf(shapeString) === -1) {
allOutputShapes.push(shapeString);
}
}
if (allOutputShapes.length === 1) {
var outputShapes = this.inboundNodes[0].outputShapes;
if (Array.isArray(outputShapes) && Array.isArray(outputShapes[0]) && outputShapes.length === 1) {
return outputShapes[0];
} else {
return outputShapes;
}
} else {
throw new AttributeError("The layer " + this.name + " has multiple inbound nodes with different " + "output shapes. Hence the notion of \"output shape\" is ill-defined " + "for the layer."); // TODO(cais): Implement getOutputShapeAt().
}
}
}]);
return Layer;
}(Serializable);
/**
* Collects the input shape(s) of a list of `tf.Tensor`s or
* `tf.SymbolicTensor`s.
*
* TODO(michaelterry): Update PyKeras docs (backport).
*
* @param inputTensors List of input tensors (or single input tensor).
*
* @return List of shape tuples (or single tuple), one tuple per input.
*/
function collectInputShape(inputTensors) {
inputTensors = toList(inputTensors);
var shapes = [];
for (var _iterator9 = _createForOfIteratorHelperLoose(inputTensors), _step9; !(_step9 = _iterator9()).done;) {
var x = _step9.value;
shapes.push(x.shape);
}
return singletonOrArray(shapes);
}
/**
* Guesses output dtype based on inputs.
*
* At present, just returns 'float32' for any input.
*
* @param inputTensors List of input tensors (or single input tensor).
*
* @return The guessed DType. At present, always returns 'float32'.
*/
function guessOutputDType(inputTensors) {
return 'float32';
}
/**
* Returns the list of input tensors necessary to compute `tensor`.
*
* Output will always be a list of tensors (potentially with 1 element).
*
* @param tensor The tensor to start from.
* @param layer Origin layer of the tensor.
* @param nodeIndex Origin node index of the tensor.
*
* @return Array of input tensors.
*/
function getSourceInputs(tensor, layer, nodeIndex) {
if (layer == null || nodeIndex != null && nodeIndex > 0) {
layer = tensor.sourceLayer;
nodeIndex = tensor.nodeIndex;
}
if (layer.inboundNodes.length === 0) {
return [tensor];
} else {
var node = layer.inboundNodes[nodeIndex];
if (node.inboundLayers.length === 0) {
return node.inputTensors;
} else {
var sourceTensors = [];
for (var i = 0; i < node.inboundLayers.length; i++) {
var x = node.inputTensors[i];
var _layer = node.inboundLayers[i];
var _nodeIndex = node.nodeIndices[i];
var previousSources = getSourceInputs(x, _layer, _nodeIndex); // Avoid input redundancy.
for (var _iterator10 = _createForOfIteratorHelperLoose(previousSources), _step10; !(_step10 = _iterator10()).done;) {
var _x = _step10.value;
if (sourceTensors.indexOf(_x) === -1) {
sourceTensors.push(_x);
}
}
}
return sourceTensors;
}
}
}
var InputLayer = /*#__PURE__*/function (_Layer) {
_inheritsLoose(InputLayer, _Layer);
function InputLayer(args) {
var _this;
_this = _Layer.call(this, {
dtype: args.dtype,
name: args.name != null ? args.name : getUid('input').toString()
}) || this; // Normalize config.batchSize and config.sparse
if (args.batchSize == null) {
args.batchSize = null;
}
if (args.sparse == null) {
args.sparse = false;
}
_this.trainable = false;
_this.built = true;
_this.sparse = args.sparse;
if (args.inputShape != null && args.batchInputShape != null) {
throw new ValueError('Only provide the inputShape OR ' + 'batchInputShape argument to inputLayer, not both at the same time.');
}
var batchInputShape = args.batchInputShape;
if (batchInputShape == null) {
if (args.inputShape == null) {
throw new ValueError('An InputLayer should be passed either a ' + '`batchInputShape` or an `inputShape`.');
} else {
batchInputShape = [args.batchSize].concat(args.inputShape);
}
} else {
// TODO(michaelterry): Backport to PyKeras
if (args.batchSize != null) {
throw new ValueError('Cannot specify batchSize if batchInputShape is ' + 'specified when creating an InputLayer.');
}
}
var dtype = args.dtype || 'float32';
_this.batchInputShape = batchInputShape;
_this.dtype = dtype; // TODO(michaelterry): Backport this to PyKeras?
_this.inputSpec = [{
shape: batchInputShape
}];
var inputTensor = new SymbolicTensor(_this.dtype, _this.batchInputShape, _assertThisInitialized(_this), [], {}, _this.name);
inputTensor.nodeIndex = 0;
inputTensor.tensorIndex = 0; // Create an input node to add to this.outboundNode.
// (This call has side effects.)
// tslint:disable-next-line:no-unused-expression
new Node({
outboundLayer: _assertThisInitialized(_this),
inboundLayers: [],
nodeIndices: [],
tensorIndices: [],
inputTensors: [inputTensor],
outputTensors: [inputTensor],
inputMasks: [null],
outputMasks: [null],
inputShapes: [batchInputShape],
outputShapes: [batchInputShape]
});
return _this;
}
var _proto = InputLayer.prototype;
_proto.apply = function apply(inputs, kwargs) {
throw new ValueError('Cannot pass any input to an ' + ("InputLayer's apply() method. InputLayer name: " + this.name));
};
_proto.dispose = function dispose() {
// dispose() for InputLayer is overridden as no-op.
return {
refCountAfterDispose: this._refCount,
numDisposedVariables: 0
};
};
_proto.getConfig = function getConfig() {
return {
batchInputShape: this.batchInputShape,
dtype: this.dtype,
sparse: this.sparse,
name: this.name
};
};
return InputLayer;
}(Layer);
/** @nocollapse */
InputLayer.className = 'InputLayer';
registerClass(InputLayer);
function Input(config) {
if (config.batchShape == null && config.shape == null) {
throw new Error('Please provide to Input either a `shape`' + ' or a `batchShape` argument. Note that ' + '`shape` does not include the batch ' + 'dimension.');
}
if (config.batchShape != null && config.shape != null) {
// TODO(michaelterry): Backport to PyKeras.
throw new ValueError('Please provide either a `shape` or `batchShape` ' + 'argument to Input, but not both.');
}
var batchShape = config.batchShape;
if (config.shape != null && batchShape == null) {
batchShape = [null].concat(config.shape);
}
var dtype = config.dtype;
if (dtype == null) {
dtype = 'float32';
}
var inputLayer = new InputLayer({
batchInputShape: batchShape,
name: config.name,
dtype: dtype,
sparse: config.sparse
});
var outputs = inputLayer.inboundNodes[0].outputTensors;
return outputs[0];
}
/**
* Turn any Scalar values in a Logs object into actual number values.
*
* @param logs The `Logs` object to be resolved in place.
*/
function resolveScalarsInLogs(_x) {
return _resolveScalarsInLogs.apply(this, arguments);
}
/**
* Dispose all Tensors in an UnresolvedLogs object.
*
* @param logs An `UnresolvedLogs` object potentially containing `tf.Tensor`s in
* places where the values can be `tf.Tensor` or `number`.
*/
function _resolveScalarsInLogs() {
_resolveScalarsInLogs = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(logs) {
var promises, keys, scalarsToDispose, key, value, valueScalar, values, i;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!(logs == null)) {
_context.next = 2;
break;
}
return _context.abrupt("return");
case 2:
promises = [];
keys = [];
scalarsToDispose = [];
for (key in logs) {
value = logs[key];
if (typeof value !== 'number') {
valueScalar = value;
promises.push(valueScalar.data());
keys.push(key);
scalarsToDispose.push(valueScalar);
}
}
if (!(promises.length > 0)) {
_context.next = 12;
break;
}
_context.next = 9;
return Promise.all(promises);
case 9:
values = _context.sent;
for (i = 0; i < values.length; ++i) {
logs[keys[i]] = values[i][0];
} // Dispose the original scalar tensors.
dispose(scalarsToDispose);
case 12:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _resolveScalarsInLogs.apply(this, arguments);
}
function disposeTensorsInLogs(logs) {
if (logs == null) {
return;
}
for (var key in logs) {
var value = logs[key];
if (typeof value !== 'number') {
value.dispose();
}
}
}
/** Verbosity logging level when fitting a model. */
var ModelLoggingVerbosity;
(function (ModelLoggingVerbosity) {
ModelLoggingVerbosity[ModelLoggingVerbosity["SILENT"] = 0] = "SILENT";
ModelLoggingVerbosity[ModelLoggingVerbosity["VERBOSE"] = 1] = "VERBOSE";
})(ModelLoggingVerbosity || (ModelLoggingVerbosity = {}));
/** How often to yield to the main thread when training (in ms). */
var DEFAULT_YIELD_EVERY_MS = 125;
/**
* Abstract base class used to build new callbacks.
*
* The `logs` dictionary that callback methods take as argument will contain
* keys for quantities relevant to the current batch or epoch.
*
* Currently, the `.fit()` method of the `Sequential` model class
* will include the following quantities in the `logs` that
* it passes to its callbacks:
*
* onEpochEnd: Logs include `acc` and `loss`, and optionally include `valLoss`
* (if validation is enabled in `fit`), and `valAcc` (if validation and
* accuracy monitoring are enabled).
* onBatchBegin: Logs include `size`, the number of samples in the current
* batch.
* onBatchEnd: Logs include `loss`, and optionally `acc` (if accuracy monitoring
* is enabled).
*/
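/*
 * A minimal usage sketch of the `logs` described above, assuming the `tf`
 * global exported by this bundle and hypothetical `model`, `xs`, `ys`:
 *
 *   await model.fit(xs, ys, {
 *     epochs: 3,
 *     callbacks: {
 *       // By the time onEpochEnd fires, Scalar log values have been
 *       // resolved to plain numbers.
 *       onEpochEnd: async (epoch, logs) => console.log(epoch, logs.loss)
 *     }
 *   });
 */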
var BaseCallback = /*#__PURE__*/function () {
function BaseCallback() {
// TODO(michaelterry): This type is a best guess.
this.validationData = null;
}
var _proto = BaseCallback.prototype;
_proto.setParams = function setParams(params) {
this.params = params;
};
_proto.onEpochBegin = /*#__PURE__*/function () {
var _onEpochBegin = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(epoch, logs) {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
case "end":
return _context.stop();
}
}
}, _callee);
}));
function onEpochBegin(_x, _x2) {
return _onEpochBegin.apply(this, arguments);
}
return onEpochBegin;
}();
_proto.onEpochEnd = /*#__PURE__*/function () {
var _onEpochEnd = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(epoch, logs) {
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
case "end":
return _context2.stop();
}
}
}, _callee2);
}));
function onEpochEnd(_x3, _x4) {
return _onEpochEnd.apply(this, arguments);
}
return onEpochEnd;
}();
_proto.onBatchBegin = /*#__PURE__*/function () {
var _onBatchBegin = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(batch, logs) {
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
case "end":
return _context3.stop();
}
}
}, _callee3);
}));
function onBatchBegin(_x5, _x6) {
return _onBatchBegin.apply(this, arguments);
}
return onBatchBegin;
}();
_proto.onBatchEnd = /*#__PURE__*/function () {
var _onBatchEnd = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(batch, logs) {
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
case "end":
return _context4.stop();
}
}
}, _callee4);
}));
function onBatchEnd(_x7, _x8) {
return _onBatchEnd.apply(this, arguments);
}
return onBatchEnd;
}();
_proto.onTrainBegin = /*#__PURE__*/function () {
var _onTrainBegin = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee5(logs) {
return regeneratorRuntime.wrap(function _callee5$(_context5) {
while (1) {
switch (_context5.prev = _context5.next) {
case 0:
case "end":
return _context5.stop();
}
}
}, _callee5);
}));
function onTrainBegin(_x9) {
return _onTrainBegin.apply(this, arguments);
}
return onTrainBegin;
}();
_proto.onTrainEnd = /*#__PURE__*/function () {
var _onTrainEnd = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee6(logs) {
return regeneratorRuntime.wrap(function _callee6$(_context6) {
while (1) {
switch (_context6.prev = _context6.next) {
case 0:
case "end":
return _context6.stop();
}
}
}, _callee6);
}));
function onTrainEnd(_x10) {
return _onTrainEnd.apply(this, arguments);
}
return onTrainEnd;
}() // LayersModel needs to call Callback.setModel(), but cannot actually depend
// on Callback because that creates a cyclic dependency. Providing this no-op
// method on BaseCallback breaks the cycle: this way LayersModel can depend on
// BaseCallback but not on Callback. The argument is typed as `Container`
// (the superclass of LayersModel) to avoid recapitulating the cycle. Callback
// overrides this method and enforces that the argument is really a
// LayersModel.
;
_proto.setModel = function setModel(model) {// Do nothing. Use Callback instead of BaseCallback to track the model.
};
return BaseCallback;
}();
/**
* Container abstracting a list of callbacks.
*/
var CallbackList = /*#__PURE__*/function () {
// TODO(cais): When the need arises, uncomment the following lines and
// implement the queue for time values.
// private deltaTBatch: number;
// private deltaTsBatchBegin: Array<number>;
// private deltaTsBatchEnd: Array<number>;
/**
* Constructor of CallbackList.
* @param callbacks Array of `Callback` instances.
* @param queueLength Queue length for keeping running statistics over
* callback execution time.
*/
function CallbackList(callbacks, queueLength) {
if (queueLength === void 0) {
queueLength = 10;
}
// TODO(cais): Make use of queueLength when implementing the queue for time
// values.
if (callbacks == null) {
callbacks = [];
}
this.callbacks = callbacks;
this.queueLength = queueLength;
}
var _proto2 = CallbackList.prototype;
_proto2.append = function append(callback) {
this.callbacks.push(callback);
};
_proto2.setParams = function setParams(params) {
for (var _iterator = _createForOfIteratorHelperLoose(this.callbacks), _step; !(_step = _iterator()).done;) {
var callback = _step.value;
callback.setParams(params);
}
};
_proto2.setModel = function setModel(model) {
for (var _iterator2 = _createForOfIteratorHelperLoose(this.callbacks), _step2; !(_step2 = _iterator2()).done;) {
var callback = _step2.value;
callback.setModel(model);
}
}
/**
* Called at the start of an epoch.
* @param epoch Index of epoch.
* @param logs Dictionary of logs.
*/
;
_proto2.onEpochBegin =
/*#__PURE__*/
function () {
var _onEpochBegin2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee7(epoch, logs) {
var _iterator3, _step3, callback;
return regeneratorRuntime.wrap(function _callee7$(_context7) {
while (1) {
switch (_context7.prev = _context7.next) {
case 0:
if (logs == null) {
logs = {};
}
_iterator3 = _createForOfIteratorHelperLoose(this.callbacks);
case 2:
if ((_step3 = _iterator3()).done) {
_context7.next = 8;
break;
}
callback = _step3.value;
_context7.next = 6;
return callback.onEpochBegin(epoch, logs);
case 6:
_context7.next = 2;
break;
case 8:
case "end":
return _context7.stop();
}
}
}, _callee7, this);
}));
function onEpochBegin(_x11, _x12) {
return _onEpochBegin2.apply(this, arguments);
}
return onEpochBegin;
}()
/**
* Called at the end of an epoch.
* @param epoch Index of epoch.
* @param logs Dictionary of logs.
*/
;
_proto2.onEpochEnd =
/*#__PURE__*/
function () {
var _onEpochEnd2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee8(epoch, logs) {
var _iterator4, _step4, callback;
return regeneratorRuntime.wrap(function _callee8$(_context8) {
while (1) {
switch (_context8.prev = _context8.next) {
case 0:
if (logs == null) {
logs = {};
}
_iterator4 = _createForOfIteratorHelperLoose(this.callbacks);
case 2:
if ((_step4 = _iterator4()).done) {
_context8.next = 8;
break;
}
callback = _step4.value;
_context8.next = 6;
return callback.onEpochEnd(epoch, logs);
case 6:
_context8.next = 2;
break;
case 8:
case "end":
return _context8.stop();
}
}
}, _callee8, this);
}));
function onEpochEnd(_x13, _x14) {
return _onEpochEnd2.apply(this, arguments);
}
return onEpochEnd;
}()
/**
* Called right before processing a batch.
* @param batch Index of batch within the current epoch.
* @param logs Dictionary of logs.
*/
;
_proto2.onBatchBegin =
/*#__PURE__*/
function () {
var _onBatchBegin2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee9(batch, logs) {
var _iterator5, _step5, callback;
return regeneratorRuntime.wrap(function _callee9$(_context9) {
while (1) {
switch (_context9.prev = _context9.next) {
case 0:
if (logs == null) {
logs = {};
}
_iterator5 = _createForOfIteratorHelperLoose(this.callbacks);
case 2:
if ((_step5 = _iterator5()).done) {
_context9.next = 8;
break;
}
callback = _step5.value;
_context9.next = 6;
return callback.onBatchBegin(batch, logs);
case 6:
_context9.next = 2;
break;
case 8:
case "end":
return _context9.stop();
}
}
}, _callee9, this);
}));
function onBatchBegin(_x15, _x16) {
return _onBatchBegin2.apply(this, arguments);
}
return onBatchBegin;
}()
/**
* Called at the end of a batch.
* @param batch Index of batch within the current epoch.
* @param logs Dictionary of logs.
*/
;
_proto2.onBatchEnd =
/*#__PURE__*/
function () {
var _onBatchEnd2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee10(batch, logs) {
var _iterator6, _step6, callback;
return regeneratorRuntime.wrap(function _callee10$(_context10) {
while (1) {
switch (_context10.prev = _context10.next) {
case 0:
if (logs == null) {
logs = {};
}
_iterator6 = _createForOfIteratorHelperLoose(this.callbacks);
case 2:
if ((_step6 = _iterator6()).done) {
_context10.next = 8;
break;
}
callback = _step6.value;
_context10.next = 6;
return callback.onBatchEnd(batch, logs);
case 6:
_context10.next = 2;
break;
case 8:
case "end":
return _context10.stop();
}
}
}, _callee10, this);
}));
function onBatchEnd(_x17, _x18) {
return _onBatchEnd2.apply(this, arguments);
}
return onBatchEnd;
}()
/**
* Called at the beginning of training.
* @param logs Dictionary of logs.
*/
;
_proto2.onTrainBegin =
/*#__PURE__*/
function () {
var _onTrainBegin2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee11(logs) {
var _iterator7, _step7, callback;
return regeneratorRuntime.wrap(function _callee11$(_context11) {
while (1) {
switch (_context11.prev = _context11.next) {
case 0:
if (logs == null) {
logs = {};
}
_iterator7 = _createForOfIteratorHelperLoose(this.callbacks);
case 2:
if ((_step7 = _iterator7()).done) {
_context11.next = 8;
break;
}
callback = _step7.value;
_context11.next = 6;
return callback.onTrainBegin(logs);
case 6:
_context11.next = 2;
break;
case 8:
case "end":
return _context11.stop();
}
}
}, _callee11, this);
}));
function onTrainBegin(_x19) {
return _onTrainBegin2.apply(this, arguments);
}
return onTrainBegin;
}()
/**
* Called at the end of training.
* @param logs Dictionary of logs.
*/
;
_proto2.onTrainEnd =
/*#__PURE__*/
function () {
var _onTrainEnd2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee12(logs) {
var _iterator8, _step8, callback;
return regeneratorRuntime.wrap(function _callee12$(_context12) {
while (1) {
switch (_context12.prev = _context12.next) {
case 0:
if (logs == null) {
logs = {};
}
_iterator8 = _createForOfIteratorHelperLoose(this.callbacks);
case 2:
if ((_step8 = _iterator8()).done) {
_context12.next = 8;
break;
}
callback = _step8.value;
_context12.next = 6;
return callback.onTrainEnd(logs);
case 6:
_context12.next = 2;
break;
case 8:
case "end":
return _context12.stop();
}
}
}, _callee12, this);
}));
function onTrainEnd(_x20) {
return _onTrainEnd2.apply(this, arguments);
}
return onTrainEnd;
}();
return CallbackList;
}();
/**
* Callback that accumulates epoch averages of metrics.
*
* This callback is automatically applied to every LayersModel.
*/
var BaseLogger = /*#__PURE__*/function (_BaseCallback) {
_inheritsLoose(BaseLogger, _BaseCallback);
function BaseLogger() {
return _BaseCallback.call(this) || this;
}
var _proto3 = BaseLogger.prototype;
_proto3.onEpochBegin = /*#__PURE__*/function () {
var _onEpochBegin3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee13(epoch) {
return regeneratorRuntime.wrap(function _callee13$(_context13) {
while (1) {
switch (_context13.prev = _context13.next) {
case 0:
this.seen = 0;
this.totals = {};
case 2:
case "end":
return _context13.stop();
}
}
}, _callee13, this);
}));
function onEpochBegin(_x21) {
return _onEpochBegin3.apply(this, arguments);
}
return onEpochBegin;
}();
_proto3.onBatchEnd = /*#__PURE__*/function () {
var _onBatchEnd3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee14(batch, logs) {
var _this = this;
var batchSize, _loop, key;
return regeneratorRuntime.wrap(function _callee14$(_context14) {
while (1) {
switch (_context14.prev = _context14.next) {
case 0:
if (logs == null) {
logs = {};
}
batchSize = logs['size'] == null ? 0 : logs['size'];
this.seen += batchSize;
_loop = function _loop(key) {
var value = logs[key];
if (typeof value === 'number') {
if (!_this.totals.hasOwnProperty(key)) {
_this.totals[key] = 0;
}
_this.totals[key] = _this.totals[key] + value * batchSize;
} else {
var oldTotalsToDispose;
if (key in _this.totals) {
oldTotalsToDispose = _this.totals[key];
} else {
_this.totals[key] = 0;
}
var total = tidy(function () {
return add$1(_this.totals[key], mul(value, batchSize));
});
_this.totals[key] = total;
if (oldTotalsToDispose != null) {
oldTotalsToDispose.dispose();
}
}
};
for (key in logs) {
_loop(key);
}
case 5:
case "end":
return _context14.stop();
}
}
}, _callee14, this);
}));
function onBatchEnd(_x22, _x23) {
return _onBatchEnd3.apply(this, arguments);
}
return onBatchEnd;
}();
_proto3.onEpochEnd = /*#__PURE__*/function () {
var _onEpochEnd3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee15(epoch, logs) {
var _this2 = this;
var _loop2, _iterator9, _step9, _ret;
return regeneratorRuntime.wrap(function _callee15$(_context15) {
while (1) {
switch (_context15.prev = _context15.next) {
case 0:
if (!(logs != null)) {
_context15.next = 9;
break;
}
_loop2 = function _loop2() {
var key = _step9.value;
if (_this2.totals[key] == null) {
return "continue";
}
if (typeof _this2.totals[key] === 'number') {
logs[key] = _this2.totals[key] / _this2.seen;
} else {
tidy(function () {
var log = mul(div(1, _this2.seen), _this2.totals[key]);
logs[key] = log;
_this2.totals[key].dispose();
keep(logs[key]);
});
}
};
_iterator9 = _createForOfIteratorHelperLoose(this.params['metrics']);
case 3:
if ((_step9 = _iterator9()).done) {
_context15.next = 9;
break;
}
_ret = _loop2();
if (!(_ret === "continue")) {
_context15.next = 7;
break;
}
return _context15.abrupt("continue", 7);
case 7:
_context15.next = 3;
break;
case 9:
case "end":
return _context15.stop();
}
}
}, _callee15, this);
}));
function onEpochEnd(_x24, _x25) {
return _onEpochEnd3.apply(this, arguments);
}
return onEpochEnd;
}();
return BaseLogger;
}(BaseCallback);
/**
* Callback that records events into a `History` object. This callback is
* automatically applied to every TF.js Layers model. The `History` object
* gets returned by the `fit` method of models.
*/
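/*
 * A minimal sketch of reading the returned `History`, assuming the `tf`
 * global from this bundle and hypothetical `model`, `xs`, `ys`:
 *
 *   const history = await model.fit(xs, ys, {epochs: 2});
 *   console.log(history.epoch);         // [0, 1]
 *   console.log(history.history.loss);  // one number per epoch
 */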
var History = /*#__PURE__*/function (_BaseCallback2) {
_inheritsLoose(History, _BaseCallback2);
function History() {
return _BaseCallback2.apply(this, arguments) || this;
}
var _proto4 = History.prototype;
_proto4.onTrainBegin = /*#__PURE__*/function () {
var _onTrainBegin3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee16(logs) {
return regeneratorRuntime.wrap(function _callee16$(_context16) {
while (1) {
switch (_context16.prev = _context16.next) {
case 0:
this.epoch = [];
this.history = {};
case 2:
case "end":
return _context16.stop();
}
}
}, _callee16, this);
}));
function onTrainBegin(_x26) {
return _onTrainBegin3.apply(this, arguments);
}
return onTrainBegin;
}();
_proto4.onEpochEnd = /*#__PURE__*/function () {
var _onEpochEnd4 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee17(epoch, logs) {
var key;
return regeneratorRuntime.wrap(function _callee17$(_context17) {
while (1) {
switch (_context17.prev = _context17.next) {
case 0:
if (logs == null) {
logs = {};
}
this.epoch.push(epoch);
for (key in logs) {
if (this.history[key] == null) {
this.history[key] = [];
}
this.history[key].push(logs[key]);
}
case 3:
case "end":
return _context17.stop();
}
}
}, _callee17, this);
}));
function onEpochEnd(_x27, _x28) {
return _onEpochEnd4.apply(this, arguments);
}
return onEpochEnd;
}()
/**
* Await the values of all losses and metrics.
*/
;
_proto4.syncData =
/*#__PURE__*/
function () {
var _syncData = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee18() {
var promises, keys, indices, key, valueArray, i, valueScalar, values, n, tensorToDispose;
return regeneratorRuntime.wrap(function _callee18$(_context18) {
while (1) {
switch (_context18.prev = _context18.next) {
case 0:
promises = [];
keys = [];
indices = [];
for (key in this.history) {
valueArray = this.history[key];
for (i = 0; i < valueArray.length; ++i) {
if (typeof valueArray[i] !== 'number') {
valueScalar = valueArray[i];
promises.push(valueScalar.data());
keys.push(key);
indices.push(i);
}
}
}
_context18.next = 6;
return Promise.all(promises);
case 6:
values = _context18.sent;
for (n = 0; n < values.length; ++n) {
tensorToDispose = this.history[keys[n]][indices[n]];
tensorToDispose.dispose();
this.history[keys[n]][indices[n]] = values[n][0];
}
case 8:
case "end":
return _context18.stop();
}
}
}, _callee18, this);
}));
function syncData() {
return _syncData.apply(this, arguments);
}
return syncData;
}();
return History;
}(BaseCallback);
/**
* Custom callback for training.
*/
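/*
 * A minimal sketch, assuming `tf.CustomCallback` is the public name of the
 * class below; `model`, `xs`, `ys` are hypothetical:
 *
 *   const printLoss = new tf.CustomCallback({
 *     onBatchEnd: async (batch, logs) => console.log('batch', batch, logs.loss),
 *     onEpochEnd: async (epoch, logs) => console.log('epoch', epoch, logs)
 *   });
 *   await model.fit(xs, ys, {epochs: 2, callbacks: [printLoss]});
 */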
var CustomCallback = /*#__PURE__*/function (_BaseCallback3) {
_inheritsLoose(CustomCallback, _BaseCallback3);
function CustomCallback(args, yieldEvery) {
var _this3;
_this3 = _BaseCallback3.call(this) || this;
_this3.currentEpoch = 0;
_this3.yieldEvery = yieldEvery || 'auto';
if (_this3.yieldEvery === 'auto') {
_this3.yieldEvery = DEFAULT_YIELD_EVERY_MS;
}
if (_this3.yieldEvery === 'never' && args.onYield != null) {
throw new Error('yieldEvery is `never` but you provided an `onYield` callback. ' + 'Either change `yieldEvery` or remove the callback');
}
if (isNumber(_this3.yieldEvery)) {
// Decorate `maybeWait` so it will be called at most once every
// `yieldEvery` ms.
_this3.maybeWait = debounce(_this3.maybeWait.bind(_assertThisInitialized(_this3)), _this3.yieldEvery);
}
_this3.trainBegin = args.onTrainBegin;
_this3.trainEnd = args.onTrainEnd;
_this3.epochBegin = args.onEpochBegin;
_this3.epochEnd = args.onEpochEnd;
_this3.batchBegin = args.onBatchBegin;
_this3.batchEnd = args.onBatchEnd;
_this3.yield = args.onYield;
return _this3;
}
var _proto5 = CustomCallback.prototype;
_proto5.maybeWait = /*#__PURE__*/function () {
var _maybeWait = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee19(epoch, batch, logs) {
var ps;
return regeneratorRuntime.wrap(function _callee19$(_context19) {
while (1) {
switch (_context19.prev = _context19.next) {
case 0:
ps = [];
if (!(this.yield != null)) {
_context19.next = 5;
break;
}
_context19.next = 4;
return resolveScalarsInLogs(logs);
case 4:
ps.push(this.yield(epoch, batch, logs));
case 5:
ps.push(nextFrame());
_context19.next = 8;
return Promise.all(ps);
case 8:
case "end":
return _context19.stop();
}
}
}, _callee19, this);
}));
function maybeWait(_x29, _x30, _x31) {
return _maybeWait.apply(this, arguments);
}
return maybeWait;
}();
_proto5.onEpochBegin = /*#__PURE__*/function () {
var _onEpochBegin4 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee20(epoch, logs) {
return regeneratorRuntime.wrap(function _callee20$(_context20) {
while (1) {
switch (_context20.prev = _context20.next) {
case 0:
this.currentEpoch = epoch;
if (!(this.epochBegin != null)) {
_context20.next = 6;
break;
}
_context20.next = 4;
return resolveScalarsInLogs(logs);
case 4:
_context20.next = 6;
return this.epochBegin(epoch, logs);
case 6:
case "end":
return _context20.stop();
}
}
}, _callee20, this);
}));
function onEpochBegin(_x32, _x33) {
return _onEpochBegin4.apply(this, arguments);
}
return onEpochBegin;
}();
_proto5.onEpochEnd = /*#__PURE__*/function () {
var _onEpochEnd5 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee21(epoch, logs) {
var ps;
return regeneratorRuntime.wrap(function _callee21$(_context21) {
while (1) {
switch (_context21.prev = _context21.next) {
case 0:
ps = [];
if (!(this.epochEnd != null)) {
_context21.next = 5;
break;
}
_context21.next = 4;
return resolveScalarsInLogs(logs);
case 4:
ps.push(this.epochEnd(epoch, logs));
case 5:
if (this.yieldEvery === 'epoch') {
ps.push(nextFrame());
}
_context21.next = 8;
return Promise.all(ps);
case 8:
case "end":
return _context21.stop();
}
}
}, _callee21, this);
}));
function onEpochEnd(_x34, _x35) {
return _onEpochEnd5.apply(this, arguments);
}
return onEpochEnd;
}();
_proto5.onBatchBegin = /*#__PURE__*/function () {
var _onBatchBegin3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee22(batch, logs) {
return regeneratorRuntime.wrap(function _callee22$(_context22) {
while (1) {
switch (_context22.prev = _context22.next) {
case 0:
if (!(this.batchBegin != null)) {
_context22.next = 5;
break;
}
_context22.next = 3;
return resolveScalarsInLogs(logs);
case 3:
_context22.next = 5;
return this.batchBegin(batch, logs);
case 5:
case "end":
return _context22.stop();
}
}
}, _callee22, this);
}));
function onBatchBegin(_x36, _x37) {
return _onBatchBegin3.apply(this, arguments);
}
return onBatchBegin;
}();
_proto5.onBatchEnd = /*#__PURE__*/function () {
var _onBatchEnd4 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee23(batch, logs) {
var ps;
return regeneratorRuntime.wrap(function _callee23$(_context23) {
while (1) {
switch (_context23.prev = _context23.next) {
case 0:
ps = [];
if (!(this.batchEnd != null)) {
_context23.next = 5;
break;
}
_context23.next = 4;
return resolveScalarsInLogs(logs);
case 4:
ps.push(this.batchEnd(batch, logs));
case 5:
if (this.yieldEvery === 'batch') {
ps.push(nextFrame());
} else if (isNumber(this.yieldEvery)) {
ps.push(this.maybeWait(this.currentEpoch, batch, logs));
}
_context23.next = 8;
return Promise.all(ps);
case 8:
case "end":
return _context23.stop();
}
}
}, _callee23, this);
}));
function onBatchEnd(_x38, _x39) {
return _onBatchEnd4.apply(this, arguments);
}
return onBatchEnd;
}();
_proto5.onTrainBegin = /*#__PURE__*/function () {
var _onTrainBegin4 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee24(logs) {
return regeneratorRuntime.wrap(function _callee24$(_context24) {
while (1) {
switch (_context24.prev = _context24.next) {
case 0:
if (!(this.trainBegin != null)) {
_context24.next = 5;
break;
}
_context24.next = 3;
return resolveScalarsInLogs(logs);
case 3:
_context24.next = 5;
return this.trainBegin(logs);
case 5:
case "end":
return _context24.stop();
}
}
}, _callee24, this);
}));
function onTrainBegin(_x40) {
return _onTrainBegin4.apply(this, arguments);
}
return onTrainBegin;
}();
_proto5.onTrainEnd = /*#__PURE__*/function () {
var _onTrainEnd3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee25(logs) {
return regeneratorRuntime.wrap(function _callee25$(_context25) {
while (1) {
switch (_context25.prev = _context25.next) {
case 0:
if (!(this.trainEnd != null)) {
_context25.next = 5;
break;
}
_context25.next = 3;
return resolveScalarsInLogs(logs);
case 3:
_context25.next = 5;
return this.trainEnd(logs);
case 5:
case "end":
return _context25.stop();
}
}
}, _callee25, this);
}));
function onTrainEnd(_x41) {
return _onTrainEnd3.apply(this, arguments);
}
return onTrainEnd;
}();
return CustomCallback;
}(BaseCallback);
/**
* Standardize callbacks or configurations of them to an Array of callbacks.
*/
function standardizeCallbacks(callbacks, yieldEvery) {
if (callbacks == null) {
callbacks = {};
}
if (callbacks instanceof BaseCallback) {
return [callbacks];
}
if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) {
return callbacks;
} // Convert custom callback configs to custom callback objects.
var callbackConfigs = toList(callbacks);
return callbackConfigs.map(function (callbackConfig) {
return new CustomCallback(callbackConfig, yieldEvery);
});
}
/**
* A global registry for callback constructors to be used during
* LayersModel.fit().
*/
var CallbackConstructorRegistry = /*#__PURE__*/function () {
/**
* Blocks public access to constructor.
*/
function CallbackConstructorRegistry() {}
/**
* Register a tf.LayersModel.fit() callback constructor.
*
* The registered callback constructor will be used to instantiate
* callbacks for every tf.LayersModel.fit() call afterwards.
*
* @param verbosityLevel Level of verbosity at which the `callbackConstructor`
 * is to be registered.
* @param callbackConstructor A no-arg constructor for `tf.Callback`.
* @throws Error, if the same callbackConstructor has been registered before,
* either at the same or a different `verbosityLevel`.
*/
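/*
 * A minimal registration sketch, assuming `tf.Callback` and
 * `tf.registerCallbackConstructor` are the public counterparts of the classes
 * below; `LossPrinter` is a hypothetical callback:
 *
 *   class LossPrinter extends tf.Callback {
 *     async onEpochEnd(epoch, logs) { console.log(epoch, logs.loss); }
 *   }
 *   // Instantiated for every subsequent fit() call at verbosity >= 1.
 *   tf.registerCallbackConstructor(1, LossPrinter);
 */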
CallbackConstructorRegistry.registerCallbackConstructor = function registerCallbackConstructor(verbosityLevel, callbackConstructor) {
assert(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), function () {
return "Verbosity level is expected to be an integer >= 0, " + ("but got " + verbosityLevel);
});
CallbackConstructorRegistry.checkForDuplicate(callbackConstructor);
if (CallbackConstructorRegistry.constructors[verbosityLevel] == null) {
CallbackConstructorRegistry.constructors[verbosityLevel] = [];
}
CallbackConstructorRegistry.constructors[verbosityLevel].push(callbackConstructor);
};
CallbackConstructorRegistry.checkForDuplicate = function checkForDuplicate(callbackConstructor) {
for (var levelName in CallbackConstructorRegistry.constructors) {
var constructors = CallbackConstructorRegistry.constructors[+levelName];
constructors.forEach(function (ctor) {
if (ctor === callbackConstructor) {
throw new ValueError('Duplicate callback constructor.');
}
});
}
}
/**
* Clear all registered callback constructors.
*/
;
CallbackConstructorRegistry.clear = function clear() {
CallbackConstructorRegistry.constructors = {};
}
/**
* Create callbacks using the registered callback constructors.
*
* Given `verbosityLevel`, all constructors registered at that level or above
* will be called and the instantiated callbacks will be used.
*
* @param verbosityLevel: Level of verbosity.
*/
;
CallbackConstructorRegistry.createCallbacks = function createCallbacks(verbosityLevel) {
var constructors = [];
for (var levelName in CallbackConstructorRegistry.constructors) {
var level = +levelName;
if (verbosityLevel >= level) {
constructors.push.apply(constructors, CallbackConstructorRegistry.constructors[level]);
}
}
return constructors.map(function (ctor) {
return new ctor();
});
};
return CallbackConstructorRegistry;
}();
CallbackConstructorRegistry.constructors = {};
function configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) {
var history = new History();
var actualCallbacks = [new BaseLogger()].concat(CallbackConstructorRegistry.createCallbacks(verbose));
if (callbacks != null) {
actualCallbacks.push.apply(actualCallbacks, callbacks);
}
actualCallbacks.push(history);
var callbackList = new CallbackList(actualCallbacks); // TODO(cais): Figure out when this LayersModel instance can have a
// dynamically
// set property called 'callback_model' as in PyKeras.
callbackList.setParams({
epochs: epochs,
initialEpoch: initialEpoch,
samples: numTrainSamples,
steps: stepsPerEpoch,
batchSize: batchSize,
verbose: verbose,
doValidation: doValidation,
metrics: callbackMetrics
});
return {
callbackList: callbackList,
history: history
};
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Instantiate a layer from a config dictionary.
* @param config dict of the form {class_name: str, config: dict}
* @param customObjects dict mapping class names (or function names)
* of custom (non-Keras) objects to class/functions
* @param fastWeightInit Optional flag to use fast weight initialization
* during deserialization. This is applicable to cases in which
* the initialization will be immediately overwritten by loaded weight
* values. Default: `false`.
* @returns Layer instance (may be LayersModel, Sequential, Layer...)
*/
function deserialize$1(config, customObjects, fastWeightInit) {
if (customObjects === void 0) {
customObjects = {};
}
if (fastWeightInit === void 0) {
fastWeightInit = false;
}
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'layer', fastWeightInit);
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Normalizes a tensor wrt the L2 norm alongside the specified axis.
* @param x
* @param axis Axis along which to perform normalization.
*/
function l2Normalize(x, axis) {
return tidy(function () {
if (x.dtype !== 'float32') {
x = cast(x, 'float32');
}
var squareSum = sum$1(square$1(x), axis, true);
var epsilonTensor = fill(squareSum.shape, epsilon());
var norm = sqrt$3(maximum(squareSum, epsilonTensor));
return div(x, norm);
});
}
function meanSquaredError$1(yTrue, yPred) {
return tidy(function () {
return mean(square$1(sub(yPred, yTrue)), -1);
});
}
function meanAbsoluteError(yTrue, yPred) {
return tidy(function () {
return mean(abs$8(sub(yPred, yTrue)), -1);
});
}
function meanAbsolutePercentageError(yTrue, yPred) {
return tidy(function () {
var diff = sub(yTrue, yPred);
var clippedTrue = clipByValue(abs$8(yTrue), epsilon(), Number.MAX_VALUE);
var absResult = abs$8(div(diff, clippedTrue));
return mul(100, mean(absResult, -1));
});
}
function meanSquaredLogarithmicError(yTrue, yPred) {
return tidy(function () {
var clippedPred = clipByValue(yPred, epsilon(), Number.MAX_VALUE);
var firstLog = log$a(add$1(1, clippedPred));
var clippedTrue = clipByValue(yTrue, epsilon(), Number.MAX_VALUE);
var secondLog = log$a(add$1(1, clippedTrue));
return mean(square$1(sub(firstLog, secondLog)), -1);
});
}
function squaredHinge(yTrue, yPred) {
return tidy(function () {
var maxResult = maximum(0, sub(1, mul(yTrue, yPred)));
return mean(square$1(maxResult), -1);
});
}
function hinge(yTrue, yPred) {
return tidy(function () {
var maxResult = maximum(0, sub(1, mul(yTrue, yPred)));
return mean(maxResult, -1);
});
}
function categoricalHinge(yTrue, yPred) {
return tidy(function () {
var pos = sum$1(mul(yTrue, yPred), -1);
var neg = max$5(mul(sub(1, yTrue), yPred), -1);
return maximum(0, add$1(1, sub(neg, pos)));
});
}
/**
* Logarithm of the hyperbolic cosine of the prediction error.
*
* `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
* to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
* like the mean squared error, but will not be so strongly affected by the
* occasional wildly incorrect prediction.
*/
function logcosh(yTrue, yPred) {
return tidy(function () {
var log2 = Math.log(2);
var predictionDiff = sub(yPred, yTrue);
var logcoshResult = sub(add$1(predictionDiff, softplus(mul(-2, predictionDiff))), log2);
return mean(logcoshResult, -1);
});
}
function categoricalCrossentropy(target, output, fromLogits) {
if (fromLogits === void 0) {
fromLogits = false;
}
return tidy(function () {
if (fromLogits) {
output = softmax(output);
} else {
// scale preds so that the class probabilities of each sample sum to 1.
var outputSum = sum$1(output, output.shape.length - 1, true);
output = div(output, outputSum);
}
output = clipByValue(output, epsilon(), 1 - epsilon());
return neg(sum$1(mul(cast(target, 'float32'), log$a(output)), output.shape.length - 1));
});
}
/**
* Categorical crossentropy with integer targets.
*
* @param target An integer tensor.
* @param output A tensor resulting from a softmax (unless `fromLogits` is
* `true`, in which case `output` is expected to be the logits).
* @param fromLogits Boolean, whether `output` is the result of a softmax, or is
* a tensor of logits.
*/
function sparseCategoricalCrossentropy(target, output, fromLogits) {
if (fromLogits === void 0) {
fromLogits = false;
}
return tidy(function () {
var flatTarget = cast(floor$a(flatten$1(target)), 'int32');
output = clipByValue(output, epsilon(), 1 - epsilon());
var outputShape = output.shape;
var oneHotTarget = reshape(oneHot(flatTarget, outputShape[outputShape.length - 1]), outputShape);
return categoricalCrossentropy(oneHotTarget, output, fromLogits);
});
}
/**
* From TensorFlow's implementation in nn_impl.py:
*
* For brevity, let `x = logits`, `z = labels`. The logistic loss is
* z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
* = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
* = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
 * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
* = (1 - z) * x + log(1 + exp(-x))
* = x - x * z + log(1 + exp(-x))
* For x < 0, to avoid overflow in exp(-x), we reformulate the above
* x - x * z + log(1 + exp(-x))
* = log(exp(x)) - x * z + log(1 + exp(-x))
* = - x * z + log(1 + exp(x))
* Hence, to ensure stability and avoid overflow, the implementation uses this
* equivalent formulation
* max(x, 0) - x * z + log(1 + exp(-abs(x)))
*
* @param labels The labels.
* @param logits The logits.
*/
function sigmoidCrossEntropyWithLogits(labels, logits) {
if (!arraysEqual(labels.shape, logits.shape)) {
throw new ValueError("logits and labels must have the same shape, but got shapes " + (JSON.stringify(labels.shape) + " and " + JSON.stringify(logits.shape)));
}
return tidy(function () {
// The logistic loss formula from above is
// x - x * z + log(1 + exp(-x))
// For x < 0, a more numerically stable formula is
// -x * z + log(1 + exp(x))
// Note that these two expressions can be combined into the following:
// max(x, 0) - x * z + log(1 + exp(-abs(x)))
var reluLogits = relu(logits);
var negAbsLogits = neg(abs$8(logits));
return add$1(sub(reluLogits, mul(logits, labels)), log1p(exp$3(negAbsLogits)));
});
}
function binaryCrossentropy(yTrue, yPred) {
return tidy(function () {
var y;
y = clipByValue(yPred, epsilon(), 1 - epsilon());
y = log$a(div(y, sub(1, y)));
return mean(sigmoidCrossEntropyWithLogits(yTrue, y), -1);
});
}
function kullbackLeiblerDivergence(yTrue, yPred) {
return tidy(function () {
var clippedTrue = clipByValue(yTrue, epsilon(), 1);
var clippedPred = clipByValue(yPred, epsilon(), 1);
return sum$1(mul(yTrue, log$a(div(clippedTrue, clippedPred))), -1);
});
}
function poisson(yTrue, yPred) {
return tidy(function () {
var logPred = log$a(add$1(epsilon(), yPred));
return mean(sub(yPred, mul(yTrue, logPred)), -1);
});
}
function cosineProximity(yTrue, yPred) {
return tidy(function () {
var trueNormalized = l2Normalize(yTrue, -1);
var predNormalized = l2Normalize(yPred, -1);
var trueXPred = mul(trueNormalized, predNormalized);
return neg(sum$1(trueXPred, -1));
});
}
var mse = meanSquaredError$1;
var MSE = meanSquaredError$1;
var mae = meanAbsoluteError;
var MAE = meanAbsoluteError;
var mape = meanAbsolutePercentageError;
var MAPE = meanAbsolutePercentageError;
var msle = meanSquaredLogarithmicError;
var MSLE = meanSquaredLogarithmicError;
var kld = kullbackLeiblerDivergence;
var KLD = kullbackLeiblerDivergence;
var cosine = cosineProximity; // TODO(michaelterry): Add deserialize() function.
var lossesMap = {
meanSquaredError: meanSquaredError$1,
meanAbsoluteError: meanAbsoluteError,
meanAbsolutePercentageError: meanAbsolutePercentageError,
meanSquaredLogarithmicError: meanSquaredLogarithmicError,
squaredHinge: squaredHinge,
hinge: hinge,
categoricalHinge: categoricalHinge,
logcosh: logcosh,
categoricalCrossentropy: categoricalCrossentropy,
sparseCategoricalCrossentropy: sparseCategoricalCrossentropy,
binaryCrossentropy: binaryCrossentropy,
kullbackLeiblerDivergence: kullbackLeiblerDivergence,
poisson: poisson,
cosineProximity: cosineProximity
}; // Porting note: This diverges from the PyKeras implementation and may need to
// change based on (de)serialization requirements.
function get$3(identifierOrFn) {
if (typeof identifierOrFn === 'string') {
if (identifierOrFn in lossesMap) {
return lossesMap[identifierOrFn];
}
var errMsg = "Unknown loss " + identifierOrFn;
if (identifierOrFn.toLowerCase().includes('softmaxcrossentropy')) {
errMsg = "Unknown loss " + identifierOrFn + ". " + 'Use "categoricalCrossentropy" as the string name for ' + 'tf.losses.softmaxCrossEntropy';
}
throw new ValueError(errMsg);
} else {
return identifierOrFn;
}
}
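/*
 * A minimal sketch of how a loss is specified, assuming the `tf` global from
 * this bundle and a hypothetical `model`. Either a string key of `lossesMap`
 * above or a custom `(yTrue, yPred) => Tensor` function is accepted:
 *
 *   model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
 *   // or, with a custom loss function (mean absolute error by hand):
 *   model.compile({
 *     optimizer: 'sgd',
 *     loss: (yTrue, yPred) => tf.mean(tf.abs(tf.sub(yTrue, yPred)), -1)
 *   });
 */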
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
function binaryAccuracy(yTrue, yPred) {
return tidy(function () {
var threshold = mul(.5, onesLike(yPred));
var yPredThresholded = cast$1(greater(yPred, threshold), yTrue.dtype);
return mean(equal(yTrue, yPredThresholded), -1);
});
}
function categoricalAccuracy(yTrue, yPred) {
return tidy(function () {
return cast$1(equal(argMax(yTrue, -1), argMax(yPred, -1)), 'float32');
});
}
function truePositives(yTrue, yPred) {
return tidy(function () {
return cast(sum$1(logicalAnd(equal(yTrue, 1), equal(yPred, 1))), 'float32');
});
}
function falseNegatives(yTrue, yPred) {
return tidy(function () {
return cast(sum$1(logicalAnd(equal(yTrue, 1), equal(yPred, 0))), 'float32');
});
}
function falsePositives(yTrue, yPred) {
return tidy(function () {
return cast(sum$1(logicalAnd(equal(yTrue, 0), equal(yPred, 1))), 'float32');
});
}
function precision(yTrue, yPred) {
return tidy(function () {
var tp = truePositives(yTrue, yPred);
var fp = falsePositives(yTrue, yPred);
var denominator = add$1(tp, fp);
return cast(where(greater(denominator, 0), div(tp, denominator), 0), 'float32');
});
}
function recall(yTrue, yPred) {
return tidy(function () {
var tp = truePositives(yTrue, yPred);
var fn = falseNegatives(yTrue, yPred);
var denominator = add$1(tp, fn);
return cast(where(greater(denominator, 0), div(tp, denominator), 0), 'float32');
});
}
function binaryCrossentropy$1(yTrue, yPred) {
return binaryCrossentropy(yTrue, yPred);
}
function sparseCategoricalAccuracy(yTrue, yPred) {
if (yTrue.rank === yPred.rank) {
yTrue = squeeze(yTrue, [yTrue.rank - 1]);
}
yPred = argMax(yPred, -1);
if (yPred.dtype !== yTrue.dtype) {
yPred = cast(yPred, yTrue.dtype);
}
return cast(equal(yTrue, yPred), 'float32');
}
function topKCategoricalAccuracy(yTrue, yPred) {
throw new NotImplementedError();
}
function sparseTopKCategoricalAccuracy(yTrue, yPred) {
throw new NotImplementedError();
} // Aliases.
var mse$1 = meanSquaredError$1;
var MSE$1 = meanSquaredError$1;
var mae$1 = meanAbsoluteError;
var MAE$1 = meanAbsoluteError;
var mape$1 = meanAbsolutePercentageError;
var MAPE$1 = meanAbsolutePercentageError;
var categoricalCrossentropy$1 = categoricalCrossentropy;
var cosine$1 = cosineProximity;
var sparseCategoricalCrossentropy$1 = sparseCategoricalCrossentropy; // TODO(cais, nielsene): Add serialize().
var metricsMap = {
binaryAccuracy: binaryAccuracy,
categoricalAccuracy: categoricalAccuracy,
precision: precision,
categoricalCrossentropy: categoricalCrossentropy$1,
sparseCategoricalCrossentropy: sparseCategoricalCrossentropy$1,
mse: mse$1,
MSE: MSE$1,
mae: mae$1,
MAE: MAE$1,
mape: mape$1,
MAPE: MAPE$1,
cosine: cosine$1
};
function get$4(identifier) {
if (typeof identifier === 'string' && identifier in metricsMap) {
return metricsMap[identifier];
} else if (typeof identifier !== 'string' && identifier != null) {
return identifier;
} else {
throw new ValueError("Unknown metric " + identifier);
}
}
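/*
 * A minimal sketch of calling a metric directly, assuming `tf.metrics` is the
 * public namespace for the functions above:
 *
 *   const yTrue = tf.tensor1d([1, 1, 0, 0]);
 *   const yPred = tf.tensor1d([1, 0, 1, 0]);
 *   tf.metrics.precision(yTrue, yPred).print();  // 0.5 for this example
 */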
/**
* Get the shortcut function name.
*
* If the fn name is a string,
* directly return the string name.
* If the function is included in metricsMap or lossesMap,
 *   return the key of the map.
 *   - If the function corresponds to multiple keys,
 *     return the first key found as the function name.
* - If the function exists in both lossesMap and metricsMap,
* search lossesMap first.
* If the function is not included in metricsMap or lossesMap,
* return the function name.
*
* @param fn loss function, metric function, or short cut name.
* @returns Loss or Metric name in string.
*/
function getLossOrMetricName(fn) {
assert$1(fn !== null, "Unknown LossOrMetricFn " + fn);
if (typeof fn === 'string') {
return fn;
} else {
var fnName;
for (var _i = 0, _Object$keys = Object.keys(lossesMap); _i < _Object$keys.length; _i++) {
var key = _Object$keys[_i];
if (lossesMap[key] === fn) {
fnName = key;
break;
}
}
if (fnName !== undefined) {
return fnName;
}
for (var _i2 = 0, _Object$keys2 = Object.keys(metricsMap); _i2 < _Object$keys2.length; _i2++) {
var _key = _Object$keys2[_i2];
if (metricsMap[_key] === fn) {
fnName = _key;
break;
}
}
if (fnName !== undefined) {
return fnName;
}
return fn.name;
}
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// Porting note: This diverges from the PyKeras implementation and may need to
// change based on (de)serialization requirements.
function getOptimizer(identifier) {
var optimizerMap = {
'Adagrad': function Adagrad() {
return train.adagrad(0.01);
},
'Adadelta': function Adadelta() {
return train.adadelta(1, 0.95, epsilon());
},
'Adam': function Adam() {
return train.adam(0.001, 0.9, 0.999, epsilon());
},
'Adamax': function Adamax() {
return train.adamax(0.002, 0.9, 0.999, epsilon(), 0);
},
'RMSProp': function RMSProp() {
return train.rmsprop(0.001, 0.9, 0, epsilon());
},
'SGD': function SGD() {
return train.sgd(0.01);
}
};
optimizerMap['adagrad'] = optimizerMap['Adagrad'];
optimizerMap['adadelta'] = optimizerMap['Adadelta'];
optimizerMap['adam'] = optimizerMap['Adam'];
optimizerMap['adamax'] = optimizerMap['Adamax'];
optimizerMap['rmsprop'] = optimizerMap['RMSProp'];
optimizerMap['sgd'] = optimizerMap['SGD'];
if (identifier in optimizerMap) {
return optimizerMap[identifier]();
}
throw new ValueError("Unknown Optimizer " + identifier);
}
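/*
 * A minimal sketch, assuming the `tf` global from this bundle and a
 * hypothetical `model`. A string identifier resolves to an optimizer with the
 * defaults listed above; pass an instance to control the hyperparameters:
 *
 *   model.compile({optimizer: 'rmsprop', loss: 'meanSquaredError'});
 *   // roughly equivalent to:
 *   model.compile({
 *     optimizer: tf.train.rmsprop(0.001, 0.9),
 *     loss: 'meanSquaredError'
 *   });
 */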
/**
* @license
* Copyright 2019 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/** Utility functions related to user-defined metadata. */
// Maximum recommended serialized size for user-defined metadata.
// Beyond this limit, a warning message will be printed during model loading and
// saving.
var MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024;
/**
* Check validity of user-defined metadata.
*
* @param userDefinedMetadata
* @param modelName Name of the model that the user-defined metadata belongs to.
* Used during construction of error messages.
* @param checkSize Whether to check the size of the metadata is under
 * recommended limit. Default: `false`. If `true`, will try to stringify the
 * JSON object and print a console warning if the serialized size is above the
* limit.
* @throws Error if `userDefinedMetadata` is not a plain JSON object.
*/
function checkUserDefinedMetadata(userDefinedMetadata, modelName, checkSize) {
if (checkSize === void 0) {
checkSize = false;
}
if (userDefinedMetadata == null || typeof userDefinedMetadata !== 'object' || Object.getPrototypeOf(userDefinedMetadata) !== Object.prototype || !plainObjectCheck(userDefinedMetadata)) {
throw new Error('User-defined metadata is expected to be a JSON object, but is not.');
}
if (checkSize) {
var out = JSON.stringify(userDefinedMetadata);
if (out.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) {
console.warn("User-defined metadata of model \"" + modelName + "\" is too large in " + ("size (length=" + out.length + " when serialized). It is not ") + "recommended to store such large objects in user-defined metadata. " + "Please make sure its serialized length is <= " + (MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH + "."));
}
}
}
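/*
 * A minimal sketch, assuming `setUserDefinedMetadata` and
 * `getUserDefinedMetadata` on `tf.LayersModel`; `model` is hypothetical. Only
 * plain JSON values are accepted, and very large payloads trigger the warning
 * above:
 *
 *   model.setUserDefinedMetadata({datasetVersion: 'v2', labels: ['cat', 'dog']});
 *   console.log(model.getUserDefinedMetadata());
 */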
/**
* Check if an input is plain JSON object or any valid subfield of it.
*
* @param x The input to be checked.
* @return Returns `true` if and only if `x` is a plain JSON object,
* a JSON-valid primitive including string, number, boolean and null,
* or an array of the said types.
*/
// tslint:disable-next-line:no-any
function plainObjectCheck(x) {
if (x === null) {
// Note: typeof `null` is 'object', and `null` is valid in JSON.
return true;
} else if (typeof x === 'object') {
if (Object.getPrototypeOf(x) === Object.prototype) {
// `x` is a JavaScript object and its prototype is Object.
var keys = Object.keys(x);
for (var _i = 0, _keys = keys; _i < _keys.length; _i++) {
var key = _keys[_i];
if (typeof key !== 'string') {
// JSON keys must be strings.
return false;
}
if (!plainObjectCheck(x[key])) {
// Recursive call.
return false;
}
}
return true;
} else {
// `x` is a JavaScript object but its prototype is not Object.
if (Array.isArray(x)) {
// `x` is a JavaScript array.
for (var _iterator = _createForOfIteratorHelperLoose(x), _step; !(_step = _iterator()).done;) {
var item = _step.value;
if (!plainObjectCheck(item)) {
// Recursive call.
return false;
}
}
return true;
} else {
// `x` is a JavaScript object and its prototype is not Object,
// and it's not an Array. I.e., it's a complex object such as
// `Error` and `Date`.
return false;
}
}
} else {
// `x` is not a JavaScript object or `null`.
var xType = typeof x;
return xType === 'string' || xType === 'number' || xType === 'boolean';
}
}
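/*
 * Illustrative inputs for the check above (a sketch; this helper is not part
 * of the public API):
 *
 *   plainObjectCheck({a: 1, b: ['x', null, true]});  // true
 *   plainObjectCheck([1, 2, 3]);                     // true (JSON-valid array)
 *   plainObjectCheck(new Date());                    // false (prototype is not Object)
 *   plainObjectCheck(() => 0);                       // false (not JSON-valid)
 */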
/**
* Print the summary of a LayersModel object.
*
* @param model tf.LayersModel instance.
 * @param lineLength Total length of printed lines. Set this to adapt the
 *   display to different terminal or console sizes.
* @param positions Relative or absolute positions of log elements in each
* line. Each number corresponds to right-most (i.e., ending) position of a
* column.
* If not provided, defaults to `[0.45, 0.85, 1]` for sequential-like
 * models and `[0.33, 0.55, 0.67, 1]` for non-sequential-like models.
* @param printFn Print function to use.
* It will be called on each line of the summary. You can provide a custom
* function in order to capture the string summary. Defaults to `console.log`.
*/
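/*
 * A minimal sketch of the public counterpart `model.summary()`, assuming the
 * `tf` global from this bundle and a hypothetical `model`:
 *
 *   model.summary();  // prints via console.log
 *   const lines = [];
 *   model.summary(80, undefined, msg => lines.push(msg));  // capture the text
 */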
function printSummary(model, lineLength, positions, // tslint:disable-next-line:no-any
printFn) {
if (printFn === void 0) {
printFn = console.log;
}
var sequentialLike = isModelSequentialLike(model);
// Header names for different log elements.
var toDisplay = ['Layer (type)', 'Output shape', 'Param #'];
if (sequentialLike) {
lineLength = lineLength || 65;
positions = positions || [0.45, 0.85, 1];
} else {
lineLength = lineLength || 98;
positions = positions || [0.33, 0.55, 0.67, 1]; // Header names for different log elements.
}
if (positions[positions.length - 1] <= 1) {
// `positions` is relative. Convert it to absolute positioning.
positions = positions.map(function (p) {
return Math.floor(lineLength * p);
});
}
var relevantNodes;
if (!sequentialLike) {
toDisplay.push('Receives inputs');
relevantNodes = [];
for (var depth in model.nodesByDepth) {
var _relevantNodes;
(_relevantNodes = relevantNodes).push.apply(_relevantNodes, model.nodesByDepth[depth]);
}
}
printFn('_'.repeat(lineLength));
printRow(toDisplay, positions, printFn);
printFn('='.repeat(lineLength));
var layers = model.layers;
for (var i = 0; i < layers.length; ++i) {
if (sequentialLike) {
printLayerSummary(layers[i], positions, printFn);
} else {
printLayerSummaryWithConnections(layers[i], positions, relevantNodes, printFn);
}
printFn((i === layers.length - 1 ? '=' : '_').repeat(lineLength));
} // tslint:disable-next-line:no-any
model.checkTrainableWeightsConsistency();
var trainableCount = countTrainableParams(model);
var nonTrainableCount = countParamsInWeights(model.nonTrainableWeights);
printFn("Total params: " + (trainableCount + nonTrainableCount));
printFn("Trainable params: " + trainableCount);
printFn("Non-trainable params: " + nonTrainableCount);
printFn('_'.repeat(lineLength));
}
function countTrainableParams(model) {
var trainableCount; // tslint:disable:no-any
if (model.collectedTrainableWeights != null) {
trainableCount = countParamsInWeights(model.collectedTrainableWeights);
} else {
trainableCount = countParamsInWeights(model.trainableWeights);
} // tslint:enable:no-any
return trainableCount;
}
function isModelSequentialLike(model) {
var sequentialLike = true;
var nodesByDepth = [];
var nodes = [];
for (var depth in model.nodesByDepth) {
nodesByDepth.push(model.nodesByDepth[depth]);
}
for (var _i = 0, _nodesByDepth = nodesByDepth; _i < _nodesByDepth.length; _i++) {
var depthNodes = _nodesByDepth[_i];
if (depthNodes.length > 1 || depthNodes.length === 1 && depthNodes[0].inboundLayers.length > 1) {
sequentialLike = false;
break;
}
nodes.push.apply(nodes, depthNodes);
}
if (sequentialLike) {
// Search for shared layers.
for (var _iterator = _createForOfIteratorHelperLoose(model.layers), _step; !(_step = _iterator()).done;) {
var layer = _step.value;
var flag = false;
for (var _iterator2 = _createForOfIteratorHelperLoose(layer.inboundNodes), _step2; !(_step2 = _iterator2()).done;) {
var node = _step2.value;
if (nodes.indexOf(node) !== -1) {
if (flag) {
sequentialLike = false;
break;
} else {
flag = true;
}
}
}
if (!sequentialLike) {
break;
}
}
}
return sequentialLike;
}
function printRow(fields, positions, // tslint:disable-next-line:no-any
printFn) {
if (printFn === void 0) {
printFn = console.log;
}
var line = '';
for (var i = 0; i < fields.length; ++i) {
if (i > 0) {
line = line.slice(0, line.length - 1) + ' ';
}
line += fields[i];
line = line.slice(0, positions[i]);
line += ' '.repeat(positions[i] - line.length);
}
printFn(line);
}
/**
* Prints a summary for a single Layer, without connectivity information.
*
* @param layer: Layer instance to print.
*/
function printLayerSummary(layer, positions, // tslint:disable-next-line:no-any
printFn) {
var outputShape;
try {
outputShape = JSON.stringify(layer.outputShape);
} catch (err) {
outputShape = 'multiple';
}
var name = layer.name;
var className = layer.getClassName();
var fields = [name + " (" + className + ")", outputShape, layer.countParams().toString()];
printRow(fields, positions, printFn);
}
/**
* Prints a summary for a single Layer, with connectivity information.
*/
function printLayerSummaryWithConnections(layer, positions, relevantNodes, // tslint:disable-next-line:no-any
printFn) {
var outputShape;
try {
outputShape = JSON.stringify(layer.outputShape);
} catch (err) {
outputShape = 'multiple';
}
var connections = [];
for (var _iterator3 = _createForOfIteratorHelperLoose(layer.inboundNodes), _step3; !(_step3 = _iterator3()).done;) {
var node = _step3.value;
if (relevantNodes != null && relevantNodes.length > 0 && relevantNodes.indexOf(node) === -1) {
continue;
}
for (var _i2 = 0; _i2 < node.inboundLayers.length; ++_i2) {
var inboundLayer = node.inboundLayers[_i2].name;
var inboundLayerIndex = node.nodeIndices[_i2];
var inboundTensorIndex = node.tensorIndices[_i2];
connections.push(inboundLayer + "[" + inboundLayerIndex + "][" + inboundTensorIndex + "]");
}
}
var name = layer.name;
var className = layer.getClassName();
var firstConnection = connections.length === 0 ? '' : connections[0];
var fields = [name + " (" + className + ")", outputShape, layer.countParams().toString(), firstConnection];
printRow(fields, positions, printFn);
for (var i = 1; i < connections.length; ++i) {
printRow(['', '', '', connections[i]], positions, printFn);
}
}
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Test whether a value in an array is the name of a LayersModel or Layer.
* @param key The key name that the value is found under. Note that the key
* may not be at the level immediately above the value, if the value is in a
* nested array.
* @param index Index of the value in the Array that it is found in.
* @param value The value object.
* @returns A boolean indicating whether value is a name.
*/
function isArrayItemInputOrOutputName(key, index, value) {
return (key === 'inboundNodes' || key === 'outputLayers' || key === 'inputLayers') && index === 0 && typeof value === 'string';
}
/**
* Convert a Pythonic config object to TypeScript config object.
* @param pythonicConfig The config object to convert.
* @param key Optional key name of the object being converted.
* @returns Result of the conversion.
*/
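/*
 * Illustrative conversion for the function below (a sketch; this helper is
 * not part of the public API). snake_case keys become camelCase, while `name`
 * values are preserved verbatim:
 *
 *   convertPythonicToTs({class_name: 'Dense', config: {units: 8, use_bias: false, name: 'dense_1'}});
 *   // -> {className: 'Dense', config: {units: 8, useBias: false, name: 'dense_1'}}
 */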
function convertPythonicToTs(pythonicConfig, key) {
if (pythonicConfig === null) {
return null;
} else if (typeof pythonicConfig === 'string') {
return toCamelCase(pythonicConfig);
} else if (typeof pythonicConfig === 'number' || typeof pythonicConfig === 'boolean') {
return pythonicConfig;
} else if (pythonicConfig instanceof Array) {
var tsArray = [];
var arrayLength = pythonicConfig.length;
for (var i = 0; i < arrayLength; ++i) {
var item = pythonicConfig[i];
if (isArrayItemInputOrOutputName(key, i, item)) {
tsArray.push(item);
} else {
tsArray.push(convertPythonicToTs(item, key));
}
}
return tsArray;
} else {
var tsDict = {};
for (var _i = 0, _Object$keys = Object.keys(pythonicConfig); _i < _Object$keys.length; _i++) {
var pythonicKey = _Object$keys[_i];
var pythonicValue = pythonicConfig[pythonicKey];
if (pythonicKey === 'name' && typeof pythonicValue === 'string') {
// Special case the 'name' key with a string value. Name values, such as
// the names of LayersModel and Layer instances, should not undergo the
// camel-case conversion.
tsDict[pythonicKey] = pythonicValue;
} else {
var tsKey = toCamelCase(pythonicKey);
tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey);
}
}
return tsDict;
}
}
/**
* Convert a TypeScript config object to Python config object.
* @param tsConfig The config object to convert.
* @param key Optional key name of the object being converted.
* @returns Result of the conversion.
*/
function convertTsToPythonic(tsConfig, key) {
if (tsConfig === null || tsConfig === undefined) {
return null;
} else if (typeof tsConfig === 'string') {
return toSnakeCase(tsConfig);
} else if (typeof tsConfig === 'number' || typeof tsConfig === 'boolean') {
return tsConfig;
} else if (tsConfig instanceof Array) {
var pyArray = [];
var arrayLength = tsConfig.length;
for (var i = 0; i < arrayLength; ++i) {
var item = tsConfig[i];
if (isArrayItemInputOrOutputName(key, i, item)) {
pyArray.push(item);
} else {
pyArray.push(convertTsToPythonic(item, key));
}
}
return pyArray;
} else {
var pyDict = {};
for (var _i2 = 0, _Object$keys2 = Object.keys(tsConfig); _i2 < _Object$keys2.length; _i2++) {
var tsKey = _Object$keys2[_i2];
var tsValue = tsConfig[tsKey];
var pyKey = toSnakeCase(tsKey);
if ((tsKey === 'name' || tsKey === 'className') && typeof tsValue === 'string') {
// Special case the 'name' key with a string value. Name values, such as
// the names of LayersModel and Layer instances, should not undergo the
// snake-case conversion.
pyDict[pyKey] = tsValue;
} else {
pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey);
}
}
return pyDict;
}
}
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
var version$2 = '3.9.0';
/**
 * Helper function to check the dtype compatibility of a feed value.
*/
function assertFeedCompatibility(key, val) {
// Check dtype compatibility.
if (key.dtype == null || key.dtype === val.dtype) {
// a. If types match, return val tensor as is.
return val;
}
try {
// b. Attempt to convert to expected type.
return cast(val, key.dtype);
} catch (err) {
// c. If conversion fails, return helpful error.
throw new ValueError("The dtype of the feed (" + val.dtype + ") can not be cast to the dtype " + ("of the key '" + key.name + "' (" + key.dtype + ")."));
}
}
/**
* FeedDict: A mapping from unique SymbolicTensors to feed values for them.
 * A feed value is a concrete value represented as a `Tensor`.
*/
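/*
 * A minimal construction sketch for the class below (internal, not part of
 * the public API); `inputSymbolic` is a hypothetical `SymbolicTensor`, e.g.
 * `model.inputs[0]`:
 *
 *   const feedDict = new FeedDict([{key: inputSymbolic, value: tf.zeros([1, 4])}]);
 *   feedDict.hasKey(inputSymbolic);  // true
 */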
var FeedDict = /*#__PURE__*/function () {
/**
* Constructor, optionally does copy-construction.
* @param feeds An Array of `Feed`s, or another `FeedDict`, in which case
* copy-construction will be performed.
*/
function FeedDict(feeds) {
this.id2Value = {};
this.id2Mask = {};
this.name2Id = {};
if (feeds instanceof FeedDict) {
for (var id in feeds.id2Value) {
this.id2Value[id] = feeds.id2Value[id];
if (id in feeds.id2Mask) {
this.id2Mask[id] = feeds.id2Mask[id];
}
}
} else {
if (feeds == null) {
return;
}
for (var _iterator = _createForOfIteratorHelperLoose(feeds), _step; !(_step = _iterator()).done;) {
var feed = _step.value;
this.add(feed.key, feed.value);
}
}
}
/**
* Add a key-value pair to the FeedDict.
*
* @param key The key of the feed.
* @param value The value of the tensor feed.
* @param mask The value of the mask feed (optional).
* @returns This `FeedDict`.
* @throws ValueError: If the key `SymbolicTensor` already exists in the
* `FeedDict`.
*/
var _proto = FeedDict.prototype;
_proto.add = function add(key, value, mask) {
if (this.id2Value[key.id] == null) {
this.id2Value[key.id] = assertFeedCompatibility(key, value);
this.name2Id[key.name] = key.id;
if (mask != null) {
this.id2Mask[key.id] = mask;
}
} else {
throw new ValueError("Duplicate key: name=" + key.name + ", id=" + key.id);
}
return this;
}
/**
* Add a Feed to the FeedDict.
* @param feed The new `Feed` to add.
* @returns This `FeedDict`.
*/
;
_proto.addFeed = function addFeed(feed) {
this.add(feed.key, feed.value);
}
/**
* Probe whether a key already exists in the FeedDict.
* @param key
*/
;
_proto.hasKey = function hasKey(key) {
return this.id2Value[key.id] != null;
}
/**
* Get all the SymbolicTensor available in this FeedDict.
*/
;
_proto.names = function names() {
return Object.keys(this.name2Id);
}
/**
* Get the feed value for given key.
* @param key The SymbolicTensor, or its name (as a string), of which the
* value is sought.
* @returns If `key` exists, the corresponding feed value.
* @throws ValueError: If `key` does not exist in this `FeedDict`.
*/
;
_proto.getValue = function getValue(key) {
if (key instanceof SymbolicTensor) {
if (this.id2Value[key.id] == null) {
throw new ValueError("Nonexistent key: " + key.name);
} else {
return this.id2Value[key.id];
}
} else {
var id = this.name2Id[key];
if (id == null) {
throw new ValueError("Feed dict has no SymbolicTensor name: " + key);
}
return this.id2Value[id];
}
}
/**
* Get the feed mask for given key.
* @param key The SymbolicTensor, or its name (as a string), of which the
* value is sought.
* @returns If `key` exists, the corresponding feed mask.
* @throws ValueError: If `key` does not exist in this `FeedDict`.
*/
;
_proto.getMask = function getMask(key) {
if (key instanceof SymbolicTensor) {
if (this.id2Value[key.id] == null) {
throw new ValueError("Nonexistent key: " + key.name);
} else {
return this.id2Mask[key.id];
}
} else {
var id = this.name2Id[key];
if (id == null) {
throw new ValueError("Feed dict has no SymbolicTensor name: " + key);
}
return this.id2Mask[id];
}
}
/** Dispose all mask Tensors held by this object. */
;
_proto.disposeMasks = function disposeMasks() {
if (this.id2Mask != null) {
dispose(this.id2Mask);
}
};
return FeedDict;
}(); // Cache for topologically sorted SymbolicTensors for given execution
// targets (i.e., fetches).
var cachedSorted = {}; // Cache for recipient count maps for given execution targets (i.e., fetches).
var cachedRecipientCounts = {};
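/*
 * Usage sketch for FeedDict (comment only; FeedDict is internal to tfjs-layers,
 * and `model`/`xs` below are hypothetical stand-ins):
 *
 *   const fd = new FeedDict();
 *   fd.add(model.inputs[0], xs);          // xs: a concrete tf.Tensor
 *   fd.hasKey(model.inputs[0]);           // true
 *   fd.getValue(model.inputs[0].name);    // returns xs (lookup by name also works)
 */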
/**
* Execute a SymbolicTensor by using concrete feed values.
*
* A `SymbolicTensor` object is a node in a computation graph of TF.js
* Layers. The object is backed by a source layer and input
* `SymbolicTensor`s to the source layer. This method evaluates
* the `call()` method of the source layer, using concrete values of the
* inputs obtained from either
* * `feedDict`, if the input key exists in `feedDict`, or else,
* * a recursive call to `execute()` itself.
*
 * @param fetches: The `SymbolicTensor` (or an Array of them) to execute.
 * @param feedDict: The feed values, used as the base condition of the
 *   recursive execution.
* @param kwargs: Optional keyword arguments.
* @param probe: A probe object (of interface `ExecutionProbe`) used for
* testing memory footprint of `execute` calls.
* @returns Result of the execution.
* @throws ValueError: If any `SymbolicTensor`s from `InputLayer`s
* encountered during the execution lacks a feed value in `feedDict`.
*/
function execute(fetches, feedDict, kwargs, probe) {
var training = kwargs == null ? false : kwargs['training'];
var arrayFetches = Array.isArray(fetches);
var fetchArray = arrayFetches ? fetches : [fetches];
var outputNames = fetchArray.map(function (t) {
return t.name;
});
var finalOutputs = [];
var feedNames = feedDict.names();
for (var _iterator2 = _createForOfIteratorHelperLoose(outputNames), _step2; !(_step2 = _iterator2()).done;) {
var outputName = _step2.value;
if (feedNames.indexOf(outputName) !== -1) {
finalOutputs.push(feedDict.getValue(outputName));
} else {
finalOutputs.push(null);
}
}
if (probe != null) {
// For optional probing of memory footprint during execution.
probe.maxNumTensors = -Infinity;
probe.minNumTensors = Infinity;
} // Check cache.
var fetchAndFeedKey = outputNames.join(',') + '|' + feedDict.names().join(',');
var sorted;
var recipientCounts;
if (cachedSorted[fetchAndFeedKey] == null) {
// Cache doesn't contain the desired combination of fetches. Compute
// topological sort for the combination for the first time.
var out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
sorted = out.sorted;
recipientCounts = out.recipientCounts; // Store results in cache for future use.
cachedSorted[fetchAndFeedKey] = sorted;
cachedRecipientCounts[fetchAndFeedKey] = recipientCounts;
}
sorted = cachedSorted[fetchAndFeedKey];
recipientCounts = {};
if (!training) {
Object.assign(recipientCounts, cachedRecipientCounts[fetchAndFeedKey]);
}
var internalFeedDict = new FeedDict(feedDict); // Start iterative execution on the topologically-sorted SymbolicTensors.
for (var i = 0; i < sorted.length; ++i) {
if (probe != null) {
// For optional probing of memory usage during execution.
var numTensors = memory().numTensors;
if (numTensors > probe.maxNumTensors) {
probe.maxNumTensors = numTensors;
}
if (numTensors < probe.minNumTensors) {
probe.minNumTensors = numTensors;
}
}
var symbolic = sorted[i];
var srcLayer = symbolic.sourceLayer;
if (srcLayer instanceof InputLayer) {
continue;
}
var inputValues = [];
var inputMasks = [];
var tensorsToDispose = [];
var maskExists = false;
for (var _iterator3 = _createForOfIteratorHelperLoose(symbolic.inputs), _step3; !(_step3 = _iterator3()).done;) {
var input = _step3.value;
var value = internalFeedDict.getValue(input);
var mask = internalFeedDict.getMask(input);
inputValues.push(value);
inputMasks.push(mask);
if (mask != null) {
maskExists = true;
}
if (!training) {
recipientCounts[input.name]--;
if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) && outputNames.indexOf(input.name) === -1 && !value.isDisposed && input.sourceLayer.stateful !== true) {
tensorsToDispose.push(value);
}
}
}
if (maskExists) {
kwargs = kwargs || {};
kwargs['mask'] = inputMasks[0];
}
var outputTensors = toList(srcLayer.apply(inputValues, kwargs));
var outputMask = null;
if (srcLayer.supportsMasking) {
outputMask = srcLayer.computeMask(inputValues, inputMasks);
}
var layerOutputs = getNodeOutputs(symbolic);
var outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
for (var _i = 0; _i < outputSymbolicTensors.length; ++_i) {
if (!internalFeedDict.hasKey(outputSymbolicTensors[_i])) {
internalFeedDict.add(outputSymbolicTensors[_i], outputTensors[_i], Array.isArray(outputMask) ? outputMask[0] : outputMask);
}
var index = outputNames.indexOf(outputSymbolicTensors[_i].name);
if (index !== -1) {
finalOutputs[index] = outputTensors[_i];
}
}
if (!training) {
// Clean up Tensors that are no longer needed.
dispose(tensorsToDispose);
}
} // NOTE(cais): Unlike intermediate tensors, we don't discard mask
// tensors as we go, because these tensors are sometimes passed over a
// series of multiple layers, i.e., not obeying the immediate input
// relations in the graph. If this becomes a memory-usage concern,
// we can improve this in the future.
internalFeedDict.disposeMasks();
return arrayFetches ? finalOutputs : finalOutputs[0];
}
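/*
 * Usage sketch (comment only; `execute` and `FeedDict` are internal helpers of
 * tfjs-layers, and `model`/`xs` below are hypothetical stand-ins):
 *
 *   const fd = new FeedDict([{key: model.inputs[0], value: xs}]);
 *   const y  = execute(model.outputs[0], fd);   // single fetch  -> a single Tensor
 *   const ys = execute(model.outputs, fd);      // array of fetches -> Tensor[]
 */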
/**
* Sort the `SymbolicTensor`s topologically, for an array of fetches.
*
* This function calls getTopologicalSortAndRecipientCountsForOneFetch and
* merges their results.
*
 * @param fetches The array of fetches requested. Must be a non-empty array.
* @param feedDict The dictionary of fed values.
* @returns sorted: Topologically-sorted array of SymbolicTensors.
* recipientCounts: Recipient counts for all SymbolicTensors in `sorted`.
*/
function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
assert(fetches != null && fetches.length > 0, function () {
return "Expected at least one fetch, got none";
});
var finalSorted = [];
var finalRecipientMap = {};
if (fetches.length === 1) {
// Special-casing 1 fetch for efficiency.
var out = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
finalSorted = out.sorted;
finalRecipientMap = out.recipientMap;
} else {
var visited = new Set();
for (var _iterator4 = _createForOfIteratorHelperLoose(fetches), _step4; !(_step4 = _iterator4()).done;) {
var fetch = _step4.value;
var _getTopologicalSortAn = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict),
sorted = _getTopologicalSortAn.sorted,
recipientMap = _getTopologicalSortAn.recipientMap; // Merge sorted SymbolicTensor Arrays.
for (var _iterator5 = _createForOfIteratorHelperLoose(sorted), _step5; !(_step5 = _iterator5()).done;) {
var symbolicTensor = _step5.value;
if (!visited.has(symbolicTensor.name)) {
finalSorted.push(symbolicTensor);
visited.add(symbolicTensor.name);
}
} // Merge recipient maps.
var _loop = function _loop(name) {
if (finalRecipientMap[name] == null) {
finalRecipientMap[name] = new Set();
}
recipientMap[name].forEach(function (recipient) {
return finalRecipientMap[name].add(recipient);
});
};
for (var name in recipientMap) {
_loop(name);
}
}
}
return {
sorted: finalSorted,
recipientCounts: recipientMap2Counts(finalRecipientMap)
};
}
function recipientMap2Counts(recipientMap) {
var recipientCounts = {};
for (var name in recipientMap) {
recipientCounts[name] = recipientMap[name].size;
}
return recipientCounts;
}
/**
* Sort the `SymbolicTensor`s topologically, for a single fetch.
*
* This helper function processes the upstream SymbolicTensors of a single
* fetch.
*
* @param fetch The single fetch requested.
* @param feedDict The dictionary of fed values.
* @returns sorted: Topologically-sorted array of SymbolicTensors.
* recipientMap: Recipient names for all SymbolicTensors in `sorted`.
*/
function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
var visited = new Set();
var sorted = [];
var recipientMap = {}; // Put keys of the feedDict into visited first, so they don't have to be
// walked. This is needed in case where there are feeds for intermediate
// SymbolicTensors of the graph.
for (var _iterator6 = _createForOfIteratorHelperLoose(feedDict.names()), _step6; !(_step6 = _iterator6()).done;) {
var key = _step6.value;
visited.add(key);
}
var stack = [];
var marks = []; // Initial population of stack and marks.
stack.push(fetch);
while (stack.length > 0) {
var top = stack[stack.length - 1];
if (visited.has(top.name)) {
stack.pop();
continue;
}
var topIsMarked = marks[marks.length - 1] === stack.length - 1;
if (top.inputs.length === 0 || topIsMarked) {
// Input SymbolicTensor or all children have been visited.
stack.pop();
sorted.push(top);
visited.add(top.name);
if (topIsMarked) {
marks.pop();
}
} else {
// A non-input SymbolicTensor whose upstream SymbolicTensors haven't
// been visited yet. Push them onto the stack.
marks.push(stack.length - 1);
for (var _iterator7 = _createForOfIteratorHelperLoose(top.inputs), _step7; !(_step7 = _iterator7()).done;) {
var input = _step7.value;
// Increment the recipient count. Note that this needs to happen
// regardless of whether the SymbolicTensor has been visited before.
if (recipientMap[input.name] == null) {
recipientMap[input.name] = new Set();
}
recipientMap[input.name].add(top.name);
if (visited.has(input.name)) {
continue; // Avoid repeated visits to the same SymbolicTensor.
}
stack.push(input);
}
}
}
return {
sorted: sorted,
recipientMap: recipientMap
};
}
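/*
 * Shape of the result (comment-only sketch; `model` and `fd` are hypothetical):
 * for a simple chain input -> dense, with the dense output as the fetch,
 * `sorted` lists the upstream SymbolicTensors first and the fetch last, and
 * each entry of `recipientMap` records which downstream tensors consume it.
 *
 *   const {sorted, recipientMap} =
 *       getTopologicalSortAndRecipientCountsForOneFetch(model.outputs[0], fd);
 *   // recipientMap[name].size === number of downstream consumers of `name`
 */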
/**
* Get the symbolic output tensors of the node to which a given fetch belongs.
* @param fetch The fetched symbolic tensor.
* @returns The Array of symbolic tensors output by the node to which `fetch`
* belongs.
*/
function getNodeOutputs(fetch) {
var layerOutputs;
if (fetch.sourceLayer.inboundNodes.length === 1) {
layerOutputs = fetch.sourceLayer.output;
} else {
var nodeIndex = null;
for (var i = 0; i < fetch.sourceLayer.inboundNodes.length; ++i) {
for (var _iterator8 = _createForOfIteratorHelperLoose(fetch.sourceLayer.inboundNodes[i].outputTensors), _step8; !(_step8 = _iterator8()).done;) {
var outputTensor = _step8.value;
if (outputTensor.id === fetch.id) {
nodeIndex = i;
break;
}
}
}
layerOutputs = fetch.sourceLayer.getOutputAt(nodeIndex);
}
return layerOutputs;
}
/**
* A Container is a directed acyclic graph of layers.
*
* It is the topological form of a "model". A LayersModel
* is simply a Container with added training routines.
*
*/
var Container = /*#__PURE__*/function (_Layer) {
_inheritsLoose(Container, _Layer);
function Container(args) {
var _this;
// No args passed to super's constructor.
_this = _Layer.call(this, {}) || this;
_this.containerNodes = new Set();
_this.name = args.name;
if (_this.name == null) {
var prefix = _this.getClassName().toLowerCase();
_this.name = getUid(prefix);
}
_this.supportsMasking = false;
_this.trainable_ = true; // TODO(michaelterry): Initialize perInputLosses/Updates here.
// Container-specific properties.
if (Array.isArray(args.inputs)) {
_this.inputs = args.inputs.slice();
} else {
_this.inputs = [args.inputs];
}
if (Array.isArray(args.outputs)) {
_this.outputs = args.outputs.slice();
} else {
_this.outputs = [args.outputs];
} // Check for redundancy in inputs.
if (unique$1(_this.inputs).length !== _this.inputs.length) {
throw new ValueError('The list of inputs passed to the model is ' + 'redundant. All inputs should only appear once. Found: ' + ("" + _this.inputs.map(function (x) {
return x.name;
})));
} // Check for redundancy in outputs.
if (unique$1(_this.outputs).length !== _this.outputs.length) {
console.warn('The list of outputs passed to the model is redundant. ' + 'All outputs should only appear once. Found: ' + ("" + _this.outputs.map(function (x) {
return x.name;
})));
}
/*
List of initial layers (1 to 1 mapping with this.inputs, hence the same
layer might appear twice)
*/
_this.inputLayers = [];
_this.inputLayersNodeIndices = [];
_this.inputLayersTensorIndices = [];
/*
List of layers (1 to 1 mapping with this.outputs, hence the same layer
might appear twice)
*/
_this.outputLayers = [];
_this.outputLayersNodeIndices = [];
_this.outputLayersTensorIndices = [];
/*
All layers in order of horizontal graph traversal. Entries are unique.
Includes input and output layers.
*/
_this.layers = [];
/*
References to container layers that were constructed internally. We need
these to properly dispose of tensors from nested containers.
*/
_this.internalContainerRefs = []; // TODO(michaelterry): Determine if caching still needed with eager
// backend.
/*
This is for performance optimization when calling the Container on new
inputs. Every time the Container is called on a set of input tensors,
we compute the output tensors, output masks and output shapes in one pass,
then cache them here. When one of these outputs is queried later,
we retrieve it from there instead of recomputing it.
*/
// this.outputTensorCache = {};
// this.outputShapeCache = {};
// Build this.outputLayers:
for (var _iterator = _createForOfIteratorHelperLoose(_this.outputs), _step; !(_step = _iterator()).done;) {
var x = _step.value;
var _layer2 = x.sourceLayer;
var nodeIndex = x.nodeIndex;
var tensorIndex = x.tensorIndex;
_this.outputLayers.push(_layer2);
_this.outputLayersNodeIndices.push(nodeIndex);
_this.outputLayersTensorIndices.push(tensorIndex);
} // TODO(michaelterry): Add output mask cache code.
// Build this.inputLayers:
for (var _iterator2 = _createForOfIteratorHelperLoose(_this.inputs), _step2; !(_step2 = _iterator2()).done;) {
var _x = _step2.value;
var _layer3 = _x.sourceLayer;
var _nodeIndex2 = _x.nodeIndex;
var _tensorIndex2 = _x.tensorIndex;
/*
It's supposed to be an input layer, so only one node
and one tensor output.
*/
assert$1(_nodeIndex2 === 0, 'input layer has >1 nodes');
assert$1(_tensorIndex2 === 0, 'input layer has >1 tensors');
_this.inputLayers.push(_layer3);
_this.inputLayersNodeIndices.push(_nodeIndex2);
_this.inputLayersTensorIndices.push(_tensorIndex2);
} // Build this.inputNames and this.outputNames.
_this.inputNames = [];
_this.outputNames = [];
_this.feedInputShapes = [];
_this.feedInputNames = [];
_this.feedOutputNames = [];
for (var i = 0; i < _this.inputLayers.length; i++) {
var layer = _this.inputLayers[i]; // Check that layer is an InputLayer.
if (!(layer instanceof InputLayer)) {
throw new TypeError('Input layers to a LayersModel must be InputLayer objects. ' + ("Received inputs: " + args.inputs + ". ") + ("Input " + i + " (0-based) originates ") + ("from layer type " + layer.getClassName() + "."));
}
_this.inputNames.push(layer.name);
_this.feedInputShapes.push(layer.batchInputShape);
_this.feedInputNames.push(layer.name);
}
for (var _iterator3 = _createForOfIteratorHelperLoose(_this.outputLayers), _step3; !(_step3 = _iterator3()).done;) {
var _layer4 = _step3.value;
_this.outputNames.push(_layer4.name);
}
_this.internalInputShapes = _this.inputs.map(function (x) {
return x.shape;
});
_this.internalOutputShapes = _this.outputs.map(function (x) {
return x.shape;
});
/*
Container_nodes: set of nodes included in the graph (not all nodes
included in the layers are relevant to the current graph).
*/
// ids of all nodes relevant to the Container:
var nodesDepths = {}; // To recover nodes from their ID.
var nodeIDToNode = {};
var layersDepths = {}; // To recover layers from their ID.
var layerIDToLayer = {};
var layerIndices = {};
var nodesInDecreasingDepth = [];
/**
* Builds a map of the graph of layers.
*
* This recursively updates the map `layerIndices`,
* the list `nodesInDecreasingDepth` and the set `containerNodes`.
*
* @param tensor Some tensor in a graph.
* @param finishedNodes Set of nodes whose subgraphs have been traversed
* completely. Useful to prevent duplicated work.
* @param nodesInProgress Set of nodes that are currently active on the
* recursion stack. Useful to detect cycles.
 * @param layer Layer that `tensor` comes from. If not provided, it
 *   will be obtained from tensor.sourceLayer.
 * @param nodeIndex Node index that `tensor` comes from.
 * @param tensorIndex Tensor index that `tensor` comes from.
*
* @exception RuntimeError if a cycle is detected.
*/
var buildMapOfGraph = function buildMapOfGraph(tensor, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex) {
if (layer == null || nodeIndex == null || tensorIndex == null) {
layer = tensor.sourceLayer;
nodeIndex = tensor.nodeIndex;
tensorIndex = tensor.tensorIndex;
}
var node = layer.inboundNodes[nodeIndex]; // Prevent cycles.
if (nodesInProgress.indexOf(node) !== -1) {
throw new RuntimeError("The tensor " + tensor.name + " at layer \"" + layer.name + "\" " + 'is part of a cycle.');
} // Don't repeat work for shared subgraphs
if (finishedNodes.indexOf(node) !== -1) {
return;
} // Update containerNodes.
_this.containerNodes.add(Container.nodeKey(layer, nodeIndex)); // Store the traversal order for layer sorting.
if (!(layer.id in layerIndices)) {
layerIndices[layer.id] = Object.keys(layerIndices).length;
}
if (nodesInProgress.indexOf(node) === -1) {
nodesInProgress.push(node);
} // Propagate to all previous tensors connected to this node.
var numInboundLayers = node.inboundLayers.length;
for (var _i = 0; _i < numInboundLayers; _i++) {
var x = node.inputTensors[_i];
var _layer = node.inboundLayers[_i];
var _nodeIndex = node.nodeIndices[_i];
var _tensorIndex = node.tensorIndices[_i];
buildMapOfGraph(x, finishedNodes, nodesInProgress, _layer, _nodeIndex, _tensorIndex);
}
finishedNodes.push(node);
while (nodesInProgress.indexOf(node) >= 0) {
nodesInProgress.splice(nodesInProgress.indexOf(node), 1);
}
nodesInDecreasingDepth.push(node);
};
var finishedNodes = [];
var nodesInProgress = [];
for (var _iterator4 = _createForOfIteratorHelperLoose(_this.outputs), _step4; !(_step4 = _iterator4()).done;) {
var _x2 = _step4.value;
buildMapOfGraph(_x2, finishedNodes, nodesInProgress);
}
var reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse();
for (var _iterator5 = _createForOfIteratorHelperLoose(reversedNodesInDecreasingDepth), _step5; !(_step5 = _iterator5()).done;) {
var node = _step5.value;
nodeIDToNode[node.id] = node; // If the depth is not set, the node has no outbound nodes (depth 0).
if (!(node.id in nodesDepths)) {
nodesDepths[node.id] = 0;
}
var _depth2 = nodesDepths[node.id]; // Update the depth of the corresponding layer
var previousDepth = layersDepths[node.outboundLayer.id] == null ? 0 : layersDepths[node.outboundLayer.id];
/*
If we've seen this layer before at a higher depth, we should use that
depth instead of the node depth. This is necessary for shared layers
that have inputs at different depth levels in the graph.
*/
_depth2 = Math.max(_depth2, previousDepth);
layersDepths[node.outboundLayer.id] = _depth2;
layerIDToLayer[node.outboundLayer.id] = node.outboundLayer;
nodesDepths[node.id] = _depth2; // Update the depth of inbound nodes.
for (var _i2 = 0; _i2 < node.inboundLayers.length; _i2++) {
var inboundLayer = node.inboundLayers[_i2];
var _nodeIndex3 = node.nodeIndices[_i2];
var inboundNode = inboundLayer.inboundNodes[_nodeIndex3];
var _previousDepth = nodesDepths[inboundNode.id] == null ? 0 : nodesDepths[inboundNode.id];
nodesDepths[inboundNode.id] = Math.max(_depth2 + 1, _previousDepth);
nodeIDToNode[inboundNode.id] = inboundNode;
}
} // Build a dict {depth: list of nodes with this depth}
var nodesByDepth = {};
for (var nodeID in nodesDepths) {
var depth = nodesDepths[nodeID];
if (!(depth in nodesByDepth)) {
nodesByDepth[depth] = [];
}
nodesByDepth[depth].push(nodeIDToNode[nodeID]);
} // Build a dict {depth: list of layers with this depth}
var layersByDepth = {};
for (var layerID in layersDepths) {
var _depth = layersDepths[layerID];
if (!(_depth in layersByDepth)) {
layersByDepth[_depth] = [];
}
layersByDepth[_depth].push(layerIDToLayer[layerID]);
} // Get sorted list of layer depths.
var depthKeys = Object.keys(layersByDepth).map(function (x) {
return parseInt(x, 10);
}).sort(reverseNumberCompare); // Set this.layers and this.layersByDepth.
_this.layers = [];
for (var _iterator6 = _createForOfIteratorHelperLoose(depthKeys), _step6; !(_step6 = _iterator6()).done;) {
var _depth3 = _step6.value;
var layersForDepth = layersByDepth[_depth3]; // Container.layers needs to have a deterministic order:
// here we order them by traversal order.
layersForDepth.sort(function (a, b) {
var aIndex = layerIndices[a.id];
var bIndex = layerIndices[b.id];
if (aIndex < bIndex) {
return -1;
}
if (aIndex > bIndex) {
return 1;
}
return 0;
});
for (var _iterator9 = _createForOfIteratorHelperLoose(layersForDepth), _step9; !(_step9 = _iterator9()).done;) {
var _layer5 = _step9.value;
if (_layer5 instanceof Container) {
_this.internalContainerRefs.push(_layer5);
}
_this.layers.push(_layer5);
}
}
_this.layersByDepth = layersByDepth; // Get sorted list of node depths;
depthKeys = Object.keys(nodesByDepth).map(function (x) {
return parseInt(x, 10);
}).sort(reverseNumberCompare); // Check that all tensors required are computable.
// computable_tensors: all tensors in the graph
// that can be computed from the inputs provided.
var computableTensors = _this.inputs.slice(); // To provide a better error msg.
var layersWithCompleteInput = [];
for (var _iterator7 = _createForOfIteratorHelperLoose(depthKeys), _step7; !(_step7 = _iterator7()).done;) {
var _depth4 = _step7.value;
for (var _iterator10 = _createForOfIteratorHelperLoose(nodesByDepth[_depth4]), _step10; !(_step10 = _iterator10()).done;) {
var _node = _step10.value;
var _layer6 = _node.outboundLayer;
if (_layer6 != null) {
for (var _iterator11 = _createForOfIteratorHelperLoose(_node.inputTensors), _step11; !(_step11 = _iterator11()).done;) {
var _x3 = _step11.value;
if (computableTensors.indexOf(_x3) === -1) {
throw new RuntimeError("Graph disconnected: cannot obtain value for tensor " + _x3 + (" at layer \"" + _layer6.name + "\". ") + 'The following previous layers were accessed without ' + ("issue: " + layersWithCompleteInput));
}
}
for (var _iterator12 = _createForOfIteratorHelperLoose(_node.outputTensors), _step12; !(_step12 = _iterator12()).done;) {
var _x4 = _step12.value;
computableTensors.push(_x4);
}
layersWithCompleteInput.push(_layer6.name);
}
}
} // Set this.containerNodes and this.nodesByDepth.
_this.nodesByDepth = nodesByDepth; // Ensure name unicity, which will be crucial for serialization
// (since serialized nodes refer to layers by their name).
var allNames = _this.layers.map(function (x) {
return x.name;
});
var _loop = function _loop() {
var name = _step8.value;
var numOccurrences = allNames.filter(function (x) {
return x === name;
}).length;
if (numOccurrences !== 1) {
throw new RuntimeError("The name \"" + name + "\" is used " + numOccurrences + " times " + 'in the model. All layer names should be unique. Layer names: ' + JSON.stringify(allNames));
}
};
for (var _iterator8 = _createForOfIteratorHelperLoose(allNames), _step8; !(_step8 = _iterator8()).done;) {
_loop();
} // Layer parameters.
// The new container starts with a single inbound node
// for its inputs, and no outbound nodes.
// Will be appended to by future calls to apply().
_this.outboundNodes = []; // Will be appended to below, and by future calls to apply().
_this.inboundNodes = []; // Create the node linking internal inputs to internal outputs.
// (This call has side effects.)
// tslint:disable-next-line:no-unused-expression
new Node({
outboundLayer: _assertThisInitialized(_this),
inboundLayers: [],
nodeIndices: [],
tensorIndices: [],
inputTensors: _this.inputs,
outputTensors: _this.outputs,
inputMasks: _this.inputs.map(function (x) {
return null;
}),
outputMasks: _this.outputs.map(function (x) {
return null;
}),
inputShapes: _this.inputs.map(function (x) {
return x.shape;
}),
outputShapes: _this.outputs.map(function (x) {
return x.shape;
})
});
_this.built = true;
_this._refCount = 1; // The ref count of a container always start at 1.
return _this;
}
var _proto = Container.prototype;
_proto.assertNotDisposed = function assertNotDisposed() {
if (this._refCount === 0) {
throw new Error("Container '" + this.name + "' is already disposed.");
}
}
/**
* Attempt to dispose a LayersModel's weights.
*
 * This method decreases the reference count of the LayersModel object by 1.
*
* A LayersModel is reference-counted. Its reference count is incremented by 1
* when it is first constructed and when it is used as a Layer of another
* LayersModel.
*
* If the reference count of a LayersModel becomes 0, the `dispose` method of
* all its constituent `Layer`s will be called.
*
* Note: If the reference count is greater than 0 after the decrement, the
* `dispose` method of its constituent `Layer`s will *not* be called.
*
* After a LayersModel is disposed, it cannot be used in calls such as
 * `predict`, `evaluate` or `fit` anymore.
*
* @returns A DisposeResult Object with the following fields:
* - refCountAfterDispose: The reference count of the LayersModel after this
* `dispose()` call.
* - numDisposedVariables: Number of `tf.Variable`s (i.e., weights) disposed
* during this `dispose()` call.
* @throws {Error} If the layer is not built yet, or if the LayersModel has
* already been disposed.
*/
;
_proto.dispose = function dispose() {
this.assertNotDisposed();
var result = {
refCountAfterDispose: null,
numDisposedVariables: 0
};
if (--this._refCount === 0) {
for (var _iterator13 = _createForOfIteratorHelperLoose(this.layers), _step13; !(_step13 = _iterator13()).done;) {
var layer = _step13.value;
result.numDisposedVariables += layer.dispose().numDisposedVariables;
} // Call dispose on each internally created container layer again to ensure
// their refCounts hit zero and their tensors are subsequently deleted.
for (var _iterator14 = _createForOfIteratorHelperLoose(this.internalContainerRefs), _step14; !(_step14 = _iterator14()).done;) {
var container = _step14.value;
result.numDisposedVariables += container.dispose().numDisposedVariables;
}
}
result.refCountAfterDispose = this._refCount;
return result;
};
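/*
 * Usage sketch (public API; the variable counts assume this exact hypothetical
 * single-dense-layer model):
 *
 *   const model = tf.sequential({
 *     layers: [tf.layers.dense({units: 1, inputShape: [4]})]
 *   });
 *   const result = model.dispose();
 *   // result.refCountAfterDispose === 0
 *   // result.numDisposedVariables === 2   (the dense layer's kernel and bias)
 */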
/**
* Loads all layer weights from a JSON object.
*
* Porting Note: HDF5 weight files cannot be directly loaded in JavaScript /
* TypeScript. The utility script at `scripts/pykeras.py` offers means
* to convert them into JSON strings compatible with this method.
* Porting Note: TensorFlow.js Layers supports only loading by name currently.
*
* @param weights A JSON mapping weight names to weight values as nested
* arrays of numbers, or a `NamedTensorMap`, i.e., a JSON mapping weight
* names to `tf.Tensor` objects.
* @param strict Require that the provided weights exactly match those
* required by the container. Default: `true`. Passing `false` means that
* extra weights and missing weights will be silently ignored.
*/
_proto.loadWeights = function loadWeights(weights, strict) {
if (strict === void 0) {
strict = true;
}
var nameToWeight = {};
var totalWeightsCount = 0;
for (var _iterator15 = _createForOfIteratorHelperLoose(this.layers), _step15; !(_step15 = _iterator15()).done;) {
var layer = _step15.value;
for (var _iterator16 = _createForOfIteratorHelperLoose(layer.weights), _step16; !(_step16 = _iterator16()).done;) {
var weight = _step16.value;
if (nameToWeight[weight.originalName] != null) {
throw new ValueError("Duplicate weight name: " + weight.originalName);
}
nameToWeight[weight.originalName] = weight;
totalWeightsCount++;
}
}
var weightValueTuples = [];
for (var name in weights) {
// TF 2.2.0 added cell name to the weight name in the format of
// layer_name/cell_name/weight_name, we need to remove
// the inner cell name.
var validatedName = name;
if (nameToWeight[name] == null) {
var tokens = name.split('/');
var shortenNameArray = tokens.slice(0, -2).concat([tokens[tokens.length - 1]]);
validatedName = shortenNameArray.join('/');
}
if (nameToWeight[validatedName] != null) {
weightValueTuples.push([nameToWeight[validatedName], weights[name]]);
} else if (strict) {
throw new ValueError("Provided weight data has no target variable: " + name);
}
delete nameToWeight[validatedName];
}
if (strict) {
// Check that all weights are set.
var unsetNames = [];
for (var _name in nameToWeight) {
unsetNames.push(_name);
}
if (unsetNames.length > 0) {
throw new ValueError(unsetNames.length + " of " + totalWeightsCount + " weights are not set: " + ("" + unsetNames));
}
}
batchSetValue(weightValueTuples);
}
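/*
 * Usage sketch (comment only; the weight names below are hypothetical and must
 * match each weight's `originalName`, e.g. 'dense_Dense1/kernel'):
 *
 *   model.loadWeights({
 *     'dense_Dense1/kernel': kernelTensor,
 *     'dense_Dense1/bias': biasTensor
 *   });
 *   model.loadWeights(partialNamedTensorMap, false);  // non-strict: extras and
 *                                                     // missing entries ignored
 */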
/**
* Util shared between different serialization methods.
* @returns LayersModel config with Keras version information added.
*/
;
_proto.updatedConfig = function updatedConfig() {
var theConfig = this.getConfig();
var modelConfig = {};
modelConfig['className'] = this.getClassName();
modelConfig['config'] = theConfig;
modelConfig['kerasVersion'] = "tfjs-layers " + version$2; // TODO(nielsene): Replace something like K.backend() once
// possible.
modelConfig['backend'] = 'TensorFlow.js';
return modelConfig;
}
/**
* Returns a JSON string containing the network configuration.
*
* To load a network from a JSON save file, use
* models.modelFromJSON(jsonString);
* @param extraJsonArgs Unused in tfjs-layers, maintained for PyKeras
* @param returnString Whether the return value should be stringified
* (default: `true`).
* @returns a JSON string if `returnString` (default), or a JSON object if
* `!returnString`.
*/
// tslint:disable-next-line:no-any
;
_proto.toJSON = function toJSON(unused, returnString) {
if (returnString === void 0) {
returnString = true;
}
var modelConfig = convertTsToPythonic(this.updatedConfig());
return returnString ? JSON.stringify(modelConfig) : modelConfig;
}
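/*
 * Usage sketch (public API):
 *
 *   const jsonStr   = model.toJSON();              // JSON string (default)
 *   const configObj = model.toJSON(null, false);   // plain object, not stringified
 */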
/**
* Call the model on new inputs.
*
* In this case `call` just reapplies all ops in the graph to the new inputs
* (e.g. build a new computational graph from the provided inputs).
*
* @param inputs A tensor or list of tensors.
* @param mask A mask or list of masks. A mask can be either a tensor or null
* (no mask).
*
* @return A tensor if there is a single output, or a list of tensors if there
* are more than one outputs.
*/
;
_proto.call = function call(inputs, kwargs) {
var _this2 = this;
return tidy(function () {
inputs = toList(inputs);
var feedDict = new FeedDict();
for (var i = 0; i < _this2.inputs.length; ++i) {
feedDict.add(_this2.inputs[i], inputs[i]);
}
return execute(_this2.outputs, feedDict, kwargs);
});
}
/**
* Computes an output mask tensor.
*
* @param inputs Tensor or list of tensors.
* @param mask Tensor or list of tensors.
*
* @return null or a tensor (or list of tensors, one per output tensor of the
* layer).
*/
;
_proto.computeMask = function computeMask(inputs, mask) {
var _this3 = this;
return tidy(function () {
inputs = toList(inputs);
var masks;
if (mask == null) {
masks = pyListRepeat(null, inputs.length);
} else {
masks = toList(mask);
} // TODO(michaelterry): Add support for mask caching.
return _this3.runInternalGraph(inputs, masks)[1];
});
}
/**
* Computes the output shape of the layer.
*
* Assumes that the layer will be built to match that input shape provided.
*
* @param inputShape A shape (tuple of integers) or a list of shape tuples
* (one per output tensor of the layer). Shape tuples can include null for
* free dimensions, instead of an integer.
*/
;
_proto.computeOutputShape = function computeOutputShape(inputShape) {
var inputShapes = normalizeShapeList(inputShape);
if (inputShapes.length !== this.inputLayers.length) {
throw new ValueError("Invalid inputShape argument " + inputShape + ": " + ("model has " + this.inputLayers.length + " tensor inputs."));
} // TODO(michaelterry): Add caching
var layersToOutputShapes = {};
for (var i = 0; i < inputShapes.length; i++) {
var layer = this.inputLayers[i];
var _inputShape = inputShapes[i]; // It's an input layer: computeOutputShape is identity,
// and there is only one node and one tensor output.
var shapeKey = layer.name + '_0_0';
layersToOutputShapes[shapeKey] = _inputShape;
}
var depthKeys = Object.keys(this.nodesByDepth).map(function (x) {
return parseInt(x, 10);
}).sort(reverseNumberCompare); // Iterate over nodes, by depth level.
if (depthKeys.length > 1) {
for (var _iterator17 = _createForOfIteratorHelperLoose(depthKeys), _step17; !(_step17 = _iterator17()).done;) {
var depth = _step17.value;
var nodes = this.nodesByDepth[depth];
for (var _iterator18 = _createForOfIteratorHelperLoose(nodes), _step18; !(_step18 = _iterator18()).done;) {
var node = _step18.value;
// This is always a single layer, never a list.
var _layer7 = node.outboundLayer;
if (this.inputLayers.map(function (x) {
return x.id;
}).indexOf(_layer7.id) !== -1) {
// We've already covered the input layers a few lines above.
continue;
} // Potentially redundant list, same size of node.inputTensors.
var _inputShapes = [];
for (var j = 0; j < node.inboundLayers.length; j++) {
var inboundLayer = node.inboundLayers[j];
var _nodeIndex4 = node.nodeIndices[j];
var tensorIndex = node.tensorIndices[j];
var _shapeKey = inboundLayer.name + "_" + _nodeIndex4 + "_" + tensorIndex;
var _inputShape2 = layersToOutputShapes[_shapeKey];
_inputShapes.push(_inputShape2);
}
var outputShape = _layer7.computeOutputShape(singletonOrArray(_inputShapes));
var _outputShapes = normalizeShapeList(outputShape);
var nodeIndex = _layer7.inboundNodes.indexOf(node);
for (var _j = 0; _j < _outputShapes.length; _j++) {
var _shapeKey2 = _layer7.name + "_" + nodeIndex + "_" + _j;
layersToOutputShapes[_shapeKey2] = _outputShapes[_j];
}
}
}
} // Read final output shapes from layersToOutputShapes.
var outputShapes = [];
var outputShapeKeys = [];
for (var _i3 = 0; _i3 < this.outputLayers.length; _i3++) {
var _layer8 = this.outputLayers[_i3];
var _nodeIndex5 = this.outputLayersNodeIndices[_i3];
var _tensorIndex3 = this.outputLayersTensorIndices[_i3];
var _shapeKey3 = _layer8.name + "_" + _nodeIndex5 + "_" + _tensorIndex3;
outputShapeKeys.push(_shapeKey3);
}
for (var _i4 = 0; _i4 < outputShapeKeys.length; _i4++) {
var key = outputShapeKeys[_i4];
assert$1(key in layersToOutputShapes);
outputShapes.push(layersToOutputShapes[key]);
} // TODO(michaelterry): Update cache
return singletonOrArray(outputShapes);
}
/**
* Computes output tensors for new inputs.
*
* Note:
* - Expects `inputs` to be a list (potentially with 1 element).
*
* @param inputs List of tensors
* @param masks List of masks (tensors or null).
* @return Three lists: outputTensors, outputMasks, outputShapes
*/
;
_proto.runInternalGraph = function runInternalGraph(inputs, masks) {
if (masks == null) {
masks = pyListRepeat(null, inputs.length);
} // Dictionary mapping reference tensors to tuples
// (computed tensor, compute mask)
// we assume a 1:1 mapping from tensor to mask
// TODO: raise exception when a `.computeMask()` call
// does not return a list the same size as `call`
var tensorMap = {};
for (var i = 0; i < this.inputs.length; ++i) {
var x = this.inputs[i];
var y = inputs[i];
var mask = masks[i];
tensorMap[x.id] = [y, mask];
}
var depthKeys = Object.keys(this.nodesByDepth).map(function (x) {
return parseInt(x, 10);
}).sort(reverseNumberCompare);
for (var _iterator19 = _createForOfIteratorHelperLoose(depthKeys), _step19; !(_step19 = _iterator19()).done;) {
var depth = _step19.value;
var nodes = this.nodesByDepth[depth];
for (var _iterator21 = _createForOfIteratorHelperLoose(nodes), _step21; !(_step21 = _iterator21()).done;) {
var node = _step21.value;
// This is always a single layer, never a list.
var layer = node.outboundLayer;
var referenceInputTensors = node.inputTensors;
var referenceOutputTensors = node.outputTensors; // If all previous input tensors are available in tensorMap,
// then call node.inboundLayer on them.
// List of tuples [input, mask]:
var computedData = new Array();
for (var _iterator22 = _createForOfIteratorHelperLoose(referenceInputTensors), _step22; !(_step22 = _iterator22()).done;) {
var _x6 = _step22.value;
if (_x6.id in tensorMap) {
computedData.push(tensorMap[_x6.id]);
}
}
if (computedData.length === referenceInputTensors.length) {
// TODO(michaelterry): Add K.name_scope here, if we need it.
var kwargs = {};
var computedTensors = void 0;
var computedMasks = void 0;
var _outputTensors = void 0;
var _outputMasks = void 0; // call layer
if (node.callArgs != null) {
kwargs = node.callArgs;
}
if (computedData.length === 1) {
var _computedData$ = computedData[0],
computedTensor = _computedData$[0],
computedMask = _computedData$[1];
if (kwargs['mask'] == null) {
kwargs['mask'] = computedMask;
}
_outputTensors = toList(layer.call(computedTensor, kwargs));
_outputMasks = toList(layer.computeMask(computedTensor, computedMask));
computedTensors = [computedTensor];
computedMasks = [computedMask];
} else {
computedTensors = computedData.map(function (x) {
return x[0];
});
computedMasks = computedData.map(function (x) {
return x[1];
});
if (kwargs['mask'] == null) {
kwargs['mask'] = computedMasks;
}
_outputTensors = toList(layer.call(computedTensors, kwargs));
_outputMasks = toList(layer.computeMask(computedTensors, computedMasks));
}
if (layer.activityRegularizer) {
throw new NotImplementedError('LayersModel invocation with concrete Tensor value(s) in the ' + 'presence of activity regularizer(s) is not supported yet.');
} // TODO(michaelterry): Add model updates and losses
// Update tensor map.
for (var _i5 = 0; _i5 < referenceOutputTensors.length; ++_i5) {
var _x5 = referenceOutputTensors[_i5];
var _y = _outputTensors[_i5];
var _mask = _outputMasks[_i5];
tensorMap[_x5.id] = [_y, _mask];
}
}
}
}
var outputTensors = [];
var outputMasks = [];
var outputShapes = [];
for (var _iterator20 = _createForOfIteratorHelperLoose(this.outputs), _step20; !(_step20 = _iterator20()).done;) {
var _x7 = _step20.value;
assert$1(_x7.id in tensorMap, "Could not compute output " + _x7.name + " : " + _x7.id);
var _tensorMap$_x7$id = tensorMap[_x7.id],
tensor = _tensorMap$_x7$id[0],
_mask2 = _tensorMap$_x7$id[1];
outputShapes.push(tensor.shape);
outputTensors.push(tensor);
outputMasks.push(_mask2);
} // TODO(michaelterry): Add support for caches.
return [outputTensors, outputMasks, outputShapes];
}
/**
* Builds a map of internal node keys to node ordering.
 * Used in serialization, as node orderings may change when unused nodes are
 * dropped. Porting Note: This helper method was pulled out of getConfig to
* improve readability.
* @param layers An array of Layers in the model.
* @returns Map of Node Keys to index order within the layer.
*/
;
_proto.buildNodeConversionMap = function buildNodeConversionMap(layers) {
var nodeConversionMap = {};
var keptNodes;
for (var _iterator23 = _createForOfIteratorHelperLoose(this.layers), _step23; !(_step23 = _iterator23()).done;) {
var layer = _step23.value;
keptNodes = layer instanceof Container ? 1 : 0;
for (var originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
var nodeKey = Container.nodeKey(layer, originalNodeIndex);
if (this.containerNodes.has(nodeKey)) {
// i.e. we mark it to be saved
nodeConversionMap[nodeKey] = keptNodes;
keptNodes += 1;
}
}
}
return nodeConversionMap;
}
/**
* Retrieves a layer based on either its name (unique) or index.
*
* Indices are based on order of horizontal graph traversal (bottom-up).
*
* If both `name` and `index` are specified, `index` takes precedence.
*
* @param name Name of layer.
* @param index Index of layer.
* @returns A Layer instance.
* @throws ValueError: In case of invalid layer name or index.
*
* @doc {
* heading: 'Layers',
* subheading: 'Classes',
* namespace: 'layers',
* subclasses: ['LayersModel']
* }
*/
;
_proto.getLayer = function getLayer(name, index) {
if (index != null) {
if (this.layers.length <= index) {
throw new ValueError("Was asked to retrieve layer at index " + index + ", but model only " + ("has " + this.layers.length + " layer(s)."));
} else {
return this.layers[index];
}
} else {
if (name == null) {
throw new ValueError('Provide either a layer name or layer index');
}
}
for (var _iterator24 = _createForOfIteratorHelperLoose(this.layers), _step24; !(_step24 = _iterator24()).done;) {
var layer = _step24.value;
if (layer.name === name) {
return layer;
}
}
throw new ValueError("No such layer: " + name);
}
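/*
 * Usage sketch (public API; the layer name below is hypothetical):
 *
 *   model.getLayer('dense_Dense1');   // look up a layer by its unique name
 *   model.getLayer(null, 0);          // look up by index; index takes precedence
 */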
/**
* Retrieves the Container's current loss values.
*
* Used for regularizers during training.
*/
;
_proto.calculateLosses = function calculateLosses() {
var _this4 = this;
// Porting Note: This is an augmentation to Container.loss in PyKeras.
// In PyKeras, Container.loss returns symbolic tensors. Here, concrete
// Tensor (specifically Scalar) values are returned. This is due to the
// imperative backend.
return tidy(function () {
var losses = [];
for (var _iterator25 = _createForOfIteratorHelperLoose(_this4.layers), _step25; !(_step25 = _iterator25()).done;) {
var layer = _step25.value;
for (var nodeIndex = 0; nodeIndex < layer.inboundNodes.length; ++nodeIndex) {
var nodeKey = Container.nodeKey(layer, nodeIndex);
if (_this4.containerNodes.has(nodeKey)) {
losses.push.apply(losses, layer.calculateLosses());
}
}
} // TODO(cais): Add any unconditional model-level losses?
return losses;
});
};
_proto.getConfig = function getConfig() {
var config = {
name: this.name
}; // Build a map from layer unique name (self._node_key)
// to the index of the nodes that are saved in the config.
// Only nodes in container_nodes are saved.
var nodeConversionMap = this.buildNodeConversionMap(this.layers); // Serialize and save the layers in layerConfigs
var layerConfigs = [];
for (var _iterator26 = _createForOfIteratorHelperLoose(this.layers), _step26; !(_step26 = _iterator26()).done;) {
var _layer10 = _step26.value;
var layerClassName = _layer10.getClassName();
var layerConfig = _layer10.getConfig();
var filteredInboundNodes = [];
for (var originalNodeIndex = 0; originalNodeIndex < _layer10.inboundNodes.length; originalNodeIndex++) {
var node = _layer10.inboundNodes[originalNodeIndex];
var _nodeKey2 = Container.nodeKey(_layer10, originalNodeIndex);
var kwargs = {};
if (this.containerNodes.has(_nodeKey2)) {
// The node is relevant to the model:
// add to filteredInboundNodes.
if (node.callArgs) {
try {
JSON.stringify(node.callArgs);
kwargs = node.callArgs;
} catch (err) {
console.warn("Layer " + _layer10.name + " was passed " + "non-serializable keyword arguments: " + (node.callArgs + ". They will not be included ") + "in the serialized model (and thus will be " + "missing at deserialization time).");
kwargs = {};
}
}
if (node.inboundLayers.length > 0) {
var nodeData = [];
for (var _i7 = 0; _i7 < node.inboundLayers.length; _i7++) {
var inboundLayer = node.inboundLayers[_i7];
var _nodeIndex7 = node.nodeIndices[_i7];
var _tensorIndex5 = node.tensorIndices[_i7];
var _nodeKey3 = Container.nodeKey(inboundLayer, _nodeIndex7);
var _newNodeIndex2 = nodeConversionMap[_nodeKey3];
if (_newNodeIndex2 == null) {
_newNodeIndex2 = 0;
}
nodeData.push([inboundLayer.name, _newNodeIndex2, _tensorIndex5, kwargs]);
}
filteredInboundNodes.push(nodeData);
}
}
}
var dict = {};
dict['name'] = _layer10.name;
dict['className'] = layerClassName;
dict['config'] = layerConfig;
dict['inboundNodes'] = filteredInboundNodes;
layerConfigs.push(dict);
}
config['layers'] = layerConfigs; // Gather info about inputs and outputs
var modelInputs = [];
for (var i = 0; i < this.inputLayers.length; i++) {
var layer = this.inputLayers[i];
var nodeIndex = this.inputLayersNodeIndices[i];
var nodeKey = Container.nodeKey(layer, nodeIndex);
if (!this.containerNodes.has(nodeKey)) {
continue;
}
var newNodeIndex = nodeConversionMap[nodeKey];
if (newNodeIndex === null || newNodeIndex === undefined) {
newNodeIndex = 0;
}
var tensorIndex = this.inputLayersTensorIndices[i];
modelInputs.push([layer.name, newNodeIndex, tensorIndex]);
}
config['inputLayers'] = modelInputs;
var modelOutputs = [];
for (var _i6 = 0; _i6 < this.outputLayers.length; _i6++) {
var _layer9 = this.outputLayers[_i6];
var _nodeIndex6 = this.outputLayersNodeIndices[_i6];
var _nodeKey = Container.nodeKey(_layer9, _nodeIndex6);
if (!this.containerNodes.has(_nodeKey)) {
continue;
}
var _newNodeIndex = nodeConversionMap[_nodeKey];
if (_newNodeIndex === null || _newNodeIndex === undefined) {
_newNodeIndex = 0;
}
var _tensorIndex4 = this.outputLayersTensorIndices[_i6];
modelOutputs.push([_layer9.name, _newNodeIndex, _tensorIndex4]);
}
config['outputLayers'] = modelOutputs;
return config;
}
/**
* Instantiates a LayersModel from its config (output of `get_config()`).
* @param cls the class to create
* @param config LayersModel config dictionary.
* @param customObjects An optional dictionary of custom objects.
* @param fastWeightInit Optional flag to use fast weight initialization
* during deserialization. This is applicable to cases in which
* the initialization will be immediately overwritten by loaded weight
* values. Default: `false`.
* @returns A LayersModel instance.
* @throws ValueError: In case of improperly formatted config dict.
*/
/** @nocollapse */
;
Container.fromConfig = function fromConfig(cls, config, customObjects, fastWeightInit) {
if (customObjects === void 0) {
customObjects = {};
}
if (fastWeightInit === void 0) {
fastWeightInit = false;
}
// Layer instances created during
// the graph reconstruction process
var createdLayers = {}; // Dictionary mapping layer instances to
// node data that specifies a layer call.
// It acts as a queue that maintains any unprocessed
// layer call until it becomes possible to process it
// (i.e. until the input tensors to the call all exist).
var unprocessedNodes = {};
function addUnprocessedNode(layer, nodeData) {
if (!(layer.name in unprocessedNodes)) {
unprocessedNodes[layer.name] = [nodeData];
} else {
unprocessedNodes[layer.name].push(nodeData);
}
}
function processNode(layer, nodeData) {
var inputTensors = [];
var kwargs;
for (var _iterator27 = _createForOfIteratorHelperLoose(nodeData), _step27; !(_step27 = _iterator27()).done;) {
var inputData = _step27.value;
var inboundLayerName = inputData[0];
var inboundNodeIndex = inputData[1];
var inboundTensorIndex = inputData[2];
kwargs = inputData[3] == null ? {} : inputData[3];
if (!(inboundLayerName in createdLayers)) {
addUnprocessedNode(layer, nodeData);
return;
}
var inboundLayer = createdLayers[inboundLayerName];
if (inboundLayer.inboundNodes.length <= inboundNodeIndex) {
addUnprocessedNode(layer, nodeData);
return;
}
var inboundNode = inboundLayer.inboundNodes[inboundNodeIndex];
inputTensors.push(inboundNode.outputTensors[inboundTensorIndex]);
} // Call layer on its inputs, thus creating the node
// and building the layer if needed.
// Note: This has Eager vs Graph Implications.
if (inputTensors.length > 0) {
layer.apply(singletonOrArray(inputTensors), kwargs); // was ** kwargs
}
}
/**
* Deserialize a layer, then call it on appropriate inputs.
* @param layerData: layer config dict.
* @throws ValueError: In case of improperly formatted `layer_data`
* dict.
*/
function processLayer(layerData) {
var layerName = layerData['name']; // Instantiate layer.
var layer = deserialize$1(layerData, config['customObjects'] != null ? config['customObjects'] : {});
layer.setFastWeightInitDuringBuild(fastWeightInit);
createdLayers[layerName] = layer; // Gather layer inputs.
var inboundNodesData = layerData['inboundNodes'];
inboundNodesData.forEach(function (nodeData) {
if (!(nodeData instanceof Array)) {
throw new ValueError("Corrupted configuration, expected array for nodeData: " + nodeData);
} // We don't process nodes (i.e. make layer calls)
// on the fly because the inbound node may not yet exist,
// in the case of layers shared at different topological depths
// (e.g. a model such as A(B(A(B(x)))))
addUnprocessedNode(layer, nodeData);
});
} // First, we create all layers and enqueue nodes to be processed.
var name = config['name'];
var layersFromConfig = config['layers'];
for (var _iterator28 = _createForOfIteratorHelperLoose(layersFromConfig), _step28; !(_step28 = _iterator28()).done;) {
var _layerData = _step28.value;
processLayer(_layerData);
} // Then we process nodes in order of layer depth.
// Nodes that cannot yet be processed (if the inbound node
// does not yet exist) are re-enqueued, and the process
// is repeated until all nodes are processed.
while (!isObjectEmpty(unprocessedNodes)) {
for (var _iterator29 = _createForOfIteratorHelperLoose(layersFromConfig), _step29; !(_step29 = _iterator29()).done;) {
var layerData = _step29.value;
var layer = createdLayers[layerData['name']];
if (layer.name in unprocessedNodes) {
var currentUnprocessedNodesForLayer = unprocessedNodes[layer.name];
delete unprocessedNodes[layer.name];
for (var _iterator30 = _createForOfIteratorHelperLoose(currentUnprocessedNodesForLayer), _step30; !(_step30 = _iterator30()).done;) {
var nodeData = _step30.value;
processNode(layer, nodeData);
}
}
}
}
var inputTensors = [];
var outputTensors = [];
var inputLayersFromConfig = config['inputLayers'];
for (var _iterator31 = _createForOfIteratorHelperLoose(inputLayersFromConfig), _step31; !(_step31 = _iterator31()).done;) {
var _layerData2 = _step31.value;
var layerName = _layerData2[0];
var nodeIndex = _layerData2[1];
var tensorIndex = _layerData2[2];
assert$1(layerName in createdLayers);
var _layer11 = createdLayers[layerName];
var layerOutputTensors = _layer11.inboundNodes[nodeIndex].outputTensors;
inputTensors.push(layerOutputTensors[tensorIndex]);
}
var outputLayersFromConfig = config['outputLayers'];
for (var _iterator32 = _createForOfIteratorHelperLoose(outputLayersFromConfig), _step32; !(_step32 = _iterator32()).done;) {
var _layerData3 = _step32.value;
var _layerName = _layerData3[0];
var _nodeIndex8 = _layerData3[1];
var _tensorIndex6 = _layerData3[2];
assert$1(_layerName in createdLayers);
var _layer12 = createdLayers[_layerName];
var _layerOutputTensors = _layer12.inboundNodes[_nodeIndex8].outputTensors;
outputTensors.push(_layerOutputTensors[_tensorIndex6]);
}
return new cls({
inputs: inputTensors,
outputs: outputTensors,
name: name
});
}
/**
* Determine whether the container is stateful.
*
* Porting Note: this is the equivalent of the stateful @property of
* the Container class in PyKeras.
*/
;
/**
* Reset the state of all stateful constituent layers (if any).
*
* Examples of stateful layers include RNN layers whose `stateful` property
* is set as `true`.
*/
_proto.resetStates = function resetStates() {
var _this5 = this;
tidy(function () {
_this5.layers.forEach(function (layer) {
// tslint:disable:no-any
if (layer.stateful) {
layer.resetStates();
} // tslint:enable:no-any
});
});
};
_createClass(Container, [{
key: "trainable",
get: function get() {
return this.trainable_;
},
set: function set(trainable) {
this.layers.forEach(function (layer) {
// tslint:disable-next-line:no-any
layer._trainableWeights.forEach(function (w) {
return w.trainable = trainable;
});
});
this.trainable_ = trainable;
}
}, {
key: "trainableWeights",
get: function get() {
// Porting Note: This check below is to prevent errors where the
// _trainableWeights inherited from the parent class (Layer) gets
// inadvertently used.
if (this._trainableWeights.length > 0) {
throw new ValueError('Container instance unexpectedly contains _trainableWeights. ' + 'The trainable weights of a Container are a union of the ' + 'trainable weights of its constituent Layers. Its own ' + '_trainableWeights must remain an empty Array.');
}
if (!this.trainable) {
return [];
}
var weights = [];
for (var _iterator33 = _createForOfIteratorHelperLoose(this.layers), _step33; !(_step33 = _iterator33()).done;) {
var layer = _step33.value;
weights = weights.concat(layer.trainableWeights);
}
return weights;
}
}, {
key: "nonTrainableWeights",
get: function get() {
var weights = [];
for (var _iterator34 = _createForOfIteratorHelperLoose(this.layers), _step34; !(_step34 = _iterator34()).done;) {
var _layer13 = _step34.value;
weights.push.apply(weights, _layer13.nonTrainableWeights);
}
if (!this.trainable) {
var trainableWeights = [];
for (var _iterator35 = _createForOfIteratorHelperLoose(this.layers), _step35; !(_step35 = _iterator35()).done;) {
var layer = _step35.value;
trainableWeights.push.apply(trainableWeights, layer.trainableWeights);
}
return trainableWeights.concat(weights);
}
return weights;
}
}, {
key: "weights",
get: function get() {
return this.trainableWeights.concat(this.nonTrainableWeights);
}
}, {
key: "stateful",
get: function get() {
// Porting Note: This check is to prevent inadvertent setting of the
// _stateful property of the Container instance.
if (this._stateful) {
throw new ValueError('Container instance unexpectedly has _stateful = true. The ' + 'statefulness of a Container is determined by the Layers it ' + 'contains. Its _stateful property must remain the default false.');
}
for (var _iterator36 = _createForOfIteratorHelperLoose(this.layers), _step36; !(_step36 = _iterator36()).done;) {
var layer = _step36.value;
if (layer.stateful) {
return true;
}
}
return false;
}
}]);
return Container;
}(Layer);
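/*
 * Usage sketch (public API): `tf.model()` builds a LayersModel, which extends
 * this Container class with training routines.
 *
 *   const input  = tf.input({shape: [8]});
 *   const hidden = tf.layers.dense({units: 4, activation: 'relu'}).apply(input);
 *   const output = tf.layers.dense({units: 1}).apply(hidden);
 *   const model  = tf.model({inputs: input, outputs: output});
 */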
function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
var numOutputs = outputNames.length;
if (xWeight == null || Array.isArray(xWeight) && xWeight.length === 0) {
return outputNames.map(function (name) {
return null;
});
}
if (numOutputs === 1) {
if (Array.isArray(xWeight) && xWeight.length === 1) {
return xWeight;
} else if (typeof xWeight === 'object' && outputNames[0] in xWeight) {
return [xWeight[outputNames[0]]];
} else {
return [xWeight];
}
}
if (Array.isArray(xWeight)) {
if (xWeight.length !== numOutputs) {
throw new Error("Provided " + weightType + " is an array of " + xWeight.length + " " + ("element(s), but the model has " + numOutputs + " outputs. ") + "Make sure a set of weights is provided for each model output.");
}
return xWeight;
} else if (typeof xWeight === 'object' && Object.keys(xWeight).length > 0 && typeof xWeight[Object.keys(xWeight)[0]] === 'object') {
var output = [];
outputNames.forEach(function (outputName) {
if (outputName in xWeight) {
output.push(xWeight[outputName]);
} else {
output.push(null);
}
});
return output;
} else {
throw new Error("The model has multiple (" + numOutputs + ") outputs, " + ("so " + weightType + " must be either an array with ") + (numOutputs + " elements or an object with " + outputNames + " keys. ") + ("Provided " + weightType + " not understood: " + JSON.stringify(xWeight)));
}
}
/**
* Standardize class weighting objects.
*
* This function takes a single class-weighting object, an array of them,
* or a map from output name to class-weighting object. It compares it to the
 * output name(s) of the model, based on which it outputs an array of
* class-weighting objects of which the length matches the number of outputs.
*
* @param classWeight Input class-weighting object(s).
* @param outputNames All output name(s) of the model.
* @return An array of class-weighting objects. The length of the array matches
* the model's number of outputs.
*/
function standardizeClassWeights(classWeight, outputNames) {
return standardizeSampleOrClassWeights(classWeight, outputNames, 'classWeight');
}
function standardizeSampleWeights(classWeight, outputNames) {
return standardizeSampleOrClassWeights(classWeight, outputNames, 'sampleWeight');
}
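/*
 * Behavior sketch (comment only; these are internal helpers and the output
 * names below are hypothetical):
 *
 *   standardizeClassWeights({0: 1, 1: 5}, ['out1']);
 *   // => [{0: 1, 1: 5}]            (single output: one entry per output)
 *   standardizeClassWeights({out1: {0: 1}}, ['out1', 'out2']);
 *   // => [{0: 1}, null]            (per-output map; missing outputs -> null)
 */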
/**
* Standardize by-sample and/or by-class weights for training.
*
* Note that this function operates on one model output at a time. For a model
* with multiple outputs, you must call this function multiple times.
*
* @param y The target tensor that the by-sample and/or by-class weight is for.
* The values of y are assumed to encode the classes, either directly
* as an integer index, or as one-hot encoding.
* @param sampleWeight By-sample weights.
* @param classWeight By-class weights: an object mapping class indices
* (integers) to a weight (float) to apply to the model's loss for the
* samples from this class during training. This can be useful to tell the
* model to "pay more attention" to samples from an under-represented class.
* @param sampleWeightMode The mode for the sample weights.
* @return A Promise of weight tensor, of which the size of the first dimension
* matches that of `y`.
*/
function standardizeWeights(_x, _x2, _x3, _x4) {
return _standardizeWeights.apply(this, arguments);
}
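/**
 * A minimal illustrative sketch (not executed by this bundle) of
 * `standardizeWeights` with a class-weight map; the tensor values are
 * hypothetical.
 *
 * ```js
 * // y encodes classes [0, 1, 1] as one-hot rows.
 * const y = tf.tensor2d([[1, 0], [0, 1], [0, 1]]);
 * const w = await standardizeWeights(y, null, {0: 1, 1: 4});
 * // w is a float32 tensor1d with values [1, 4, 4], one weight per sample.
 * ```
 */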
/**
* Apply per-sample weights on the loss values from a number of samples.
*
* @param losses Loss tensor of shape `[batchSize]`.
* @param sampleWeights Per-sample weight tensor of shape `[batchSize]`.
 * @returns Tensor of the same shape as `losses`.
*/
function _standardizeWeights() {
_standardizeWeights = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(y, sampleWeight, classWeight, sampleWeightMode) {
var yClasses, yClassIndices, classSampleWeight;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!(sampleWeight != null || sampleWeightMode != null)) {
_context.next = 2;
break;
}
throw new Error('Support for sampleWeight is not implemented yet');
case 2:
if (!(classWeight != null)) {
_context.next = 15;
break;
}
// Apply class weights per sample.
yClasses = tidy(function () {
if (y.shape.length === 1) {
// Assume class indices.
return clone(y);
} else if (y.shape.length === 2) {
if (y.shape[1] > 1) {
// Assume one-hot encoding of classes.
var axis = 1;
return argMax(y, axis);
} else if (y.shape[1] === 1) {
// Class index.
return reshape(y, [y.shape[0]]);
} else {
throw new Error("Encountered unexpected last-dimension size (" + y.shape[1] + ") " + "during handling of class weights. The size is expected to be " + ">= 1.");
}
} else {
throw new Error("Unexpected rank of target (y) tensor (" + y.rank + ") during " + "handling of class weights. The rank is expected to be 1 or 2.");
}
});
_context.t0 = Array;
_context.next = 7;
return yClasses.data();
case 7:
_context.t1 = _context.sent;
yClassIndices = _context.t0.from.call(_context.t0, _context.t1);
dispose(yClasses);
classSampleWeight = [];
yClassIndices.forEach(function (classIndex) {
if (classWeight[classIndex] == null) {
throw new Error("classWeight must contain all classes in the training data. " + ("The class " + classIndex + " exists in the data but not in ") + "classWeight");
} else {
classSampleWeight.push(classWeight[classIndex]);
}
});
return _context.abrupt("return", tensor1d(classSampleWeight, 'float32'));
case 15:
return _context.abrupt("return", null);
case 16:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _standardizeWeights.apply(this, arguments);
}
function computeWeightedLoss$1(losses, sampleWeights) {
return mul(losses, sampleWeights);
}
var DEFAULT_VALIDATION_BATCH_SIZE = 32;
/**
* Standardize the output of a dataset iterator for use by
* LayersModel.fitDataset().
*
* @param model: A `tf.LayersModel` object.
* @param iteratorOut The output of a dataset iterator. It is required to be
* an object of the form `{xs: TensorOrArrayOrMap, ys:
* TensorOrArrayOrMap}`, where `TensorOrArrayOrMap` is a single `tf.Tensor`,
* a `tf.Tensor[]`, or a flat map from string names to `tf.Tensor`s.
* @returns A flat array of `tf.Tensor` objects: the input `tf.Tensor`s
* followed by the target `tf.Tensor`s. When `tf.Tensor`s are provided
* as a map, the order in the resulting array is taken from the `inputNames`
* and `outputNames` of the model.
*/
function standardizeDataIteratorOutput( // Type `model` as `any` here to avoid circular dependency w/
// training.ts.
// tslint:disable-next-line:no-any
model, iteratorOut) {
var xs;
var ys;
var iteratorOutObj = iteratorOut;
xs = iteratorOutObj['xs'];
ys = iteratorOutObj['ys'];
assert(xs != null && ys != null, function () {
return 'A Dataset iterator for fitDataset() is expected to generate ' + 'objects of the form `{xs: xVal, ys: yVal}`, where the two ' + 'values may be `tf.Tensor`, an array of Tensors, or a map of ' + 'string to Tensor. The provided Dataset instead generates ' + ("" + iteratorOut);
});
var flattenedXs = flattenTensorOrArrayOrMap('input', model.inputNames, xs);
var flattenedYs = flattenTensorOrArrayOrMap('output', model.outputNames, ys);
var batchSize = flattenedXs[0].shape[0];
assert(flattenedXs.length === model.inputs.length, function () {
return "LayersModel has " + model.inputs.length + " inputs, but the dataset " + ("provides " + flattenedXs.length + " inputs. (Expected input keys: ") + (JSON.stringify(model.inputNames) + ")");
});
assert(flattenedYs.length === model.outputs.length, function () {
return "LayersModel has " + model.outputs.length + " outputs, but the dataset " + ("provides " + flattenedYs.length + " outputs. (Expected output keys: ") + (JSON.stringify(model.outputNames) + ")");
});
var _loop = function _loop(xIndex) {
assert(flattenedXs[xIndex].shape[0] === batchSize, function () {
return "Batch size mismatch: input " + (model.inputNames[xIndex] + " has " + flattenedXs[xIndex].shape[0] + "; ") + ("expected " + batchSize + " based on input " + model.inputNames[0] + ".");
});
};
for (var xIndex = 0; xIndex < flattenedXs.length; xIndex++) {
_loop(xIndex);
}
var _loop2 = function _loop2(yIndex) {
assert(flattenedYs[yIndex].shape[0] === batchSize, function () {
return "Batch size mismatch: output " + (model.outputNames[yIndex] + " has " + flattenedYs[yIndex].shape[0] + "; ") + ("expected " + batchSize + " based on input " + model.inputNames[0] + ".");
});
};
for (var yIndex = 0; yIndex < flattenedYs.length; yIndex++) {
_loop2(yIndex);
}
return {
xs: flattenedXs,
ys: flattenedYs
};
}
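/**
 * A minimal illustrative sketch (not executed by this bundle) of the iterator
 * output accepted above. The model, with `inputNames` `['input_1']` and
 * `outputNames` `['dense_1']`, and the tensor shapes are hypothetical.
 *
 * ```js
 * // Each call to iterator.next() may yield bare tensors or named maps:
 * const batchA = {xs: tf.ones([4, 10]), ys: tf.ones([4, 1])};
 * const batchB = {
 *   xs: {input_1: tf.ones([4, 10])},
 *   ys: {dense_1: tf.ones([4, 1])}
 * };
 * // Either form is flattened to {xs: [Tensor], ys: [Tensor]}, and all
 * // tensors must agree on the batch size (4 here).
 * ```
 */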
function flattenTensorOrArrayOrMap(inputOrOutput, names, values) {
if (values instanceof Tensor) {
return [values];
} else if (Array.isArray(values)) {
assert(values.length === names.length, function () {
return "Received an array of " + values.length + " Tensors, but expected " + names.length + " to match the " + inputOrOutput + " keys " + names + ".";
});
return values;
} else {
var result = []; // Check that all the required keys are available.
for (var _iterator = _createForOfIteratorHelperLoose(names), _step; !(_step = _iterator()).done;) {
var name = _step.value;
if (values[name] == null) {
throw new ValueError("The feature data generated by the dataset lacks the required " + (inputOrOutput + " key '" + name + "'."));
}
result.push(values[name]);
}
return result;
}
}
function standardizeTensorValidationData(data) {
if (data.length === 3) {
throw new NotImplementedError('Validation with sample weights is not implemented yet.');
}
return {
xs: data[0],
ys: data[1]
};
}
function fitDataset(_x, _x2, _x3) {
return _fitDataset.apply(this, arguments);
}
/** Helper function that determines number of steps (batches) per epoch. */
function _fitDataset() {
_fitDataset = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee( // Type `model` as `any` here to avoid circular dependency w/
// training.ts.
// tslint:disable-next-line:no-any
model, dataset, args) {
var hasBatchesPerEpoch, doValidation, valXs, valYs, validationData, trainFunction, outLabels, callbackMetrics, callbacks, verbose, _configureCallbacks, callbackList, history, epoch, dataIterator, epochLogs, stepsDone, batchIndex, iteratorOut, _standardizeDataItera, xs, ys, batchLogs, sampleWeights, standardClassWeights, i, ins, outs, _i, label, out, valOuts, _i2;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
hasBatchesPerEpoch = args.batchesPerEpoch != null;
assert(model.optimizer != null, function () {
return 'You must compile a model before training/testing. Use ' + 'LayersModel.compile(modelCompileConfig).';
});
assert(args != null, function () {
return "For fitDataset(), the 2nd argument (config) is required, " + "but it is not provided in this call.";
});
assert(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), function () {
return "For fitDataset(), config.epochs is expected to be a positive " + ("integer, but got " + args.epochs);
});
assert(!hasBatchesPerEpoch || args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch), function () {
return "For fitDataset(), config.batchesPerEpoch is expected to be a " + ("positive integer if specified, but got " + args.batchesPerEpoch);
});
assert( // tslint:disable-next-line:no-any
args['validationSplit'] == null, function () {
return '`validationSplit` is not supported by `fitDataset()`. ' + 'Use validationData instead.';
});
if (!model.isTraining) {
_context.next = 8;
break;
}
throw new Error('Cannot start training because another fit() call is ongoing.');
case 8:
model.isTraining = true;
_context.prev = 9;
doValidation = args.validationData != null;
if (doValidation) {
if (isDatasetObject(args.validationData)) {
assert(args.validationBatches == null || args.validationBatches > 0 && Number.isInteger(args.validationBatches), function () {
return "For fitDataset() with dataset-based validation, " + "config.validationBatches is expected not to be provided, " + "or to be a positive integer, " + ("but got " + args.validationBatches);
});
} else {
validationData = standardizeTensorValidationData(args.validationData);
valXs = validationData.xs;
valYs = validationData.ys;
}
}
trainFunction = model.makeTrainFunction();
outLabels = model.getDedupedMetricsNames();
if (doValidation) {
callbackMetrics = outLabels.slice().concat(outLabels.map(function (n) {
return 'val_' + n;
}));
} else {
callbackMetrics = outLabels.slice();
}
callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
verbose = args.verbose == null ? 1 : args.verbose;
_configureCallbacks = configureCallbacks(callbacks, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null, // Batch size determined by the dataset itself.
doValidation, callbackMetrics), callbackList = _configureCallbacks.callbackList, history = _configureCallbacks.history;
callbackList.setModel(model);
model.history = history;
_context.next = 22;
return callbackList.onTrainBegin();
case 22:
model.stopTraining_ = false;
epoch = args.initialEpoch == null ? 0 : args.initialEpoch;
_context.next = 26;
return dataset.iterator();
case 26:
dataIterator = _context.sent;
case 27:
if (!(epoch < args.epochs)) {
_context.next = 98;
break;
}
epochLogs = {};
_context.next = 31;
return callbackList.onEpochBegin(epoch);
case 31:
stepsDone = 0;
batchIndex = 0;
if (hasBatchesPerEpoch) {
_context.next = 37;
break;
}
_context.next = 36;
return dataset.iterator();
case 36:
dataIterator = _context.sent;
case 37:
if (!(hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true)) {
_context.next = 91;
break;
}
_context.next = 40;
return dataIterator.next();
case 40:
iteratorOut = _context.sent;
if (!(hasBatchesPerEpoch && iteratorOut.done)) {
_context.next = 44;
break;
}
console.warn('You provided `batchesPerEpoch` as ' + (args.batchesPerEpoch + ", ") + 'but your dataset iterator ran out of data after ' + (stepsDone + " batches; ") + 'interrupting training. Make sure that your ' + 'dataset can generate at least `batchesPerEpoch * epochs` ' + 'batches (in this case, ' + (args.batchesPerEpoch * args.epochs + " batches). ") + 'You may need to use the repeat() function when building ' + 'your dataset.');
return _context.abrupt("break", 91);
case 44:
if (!(iteratorOut.value != null)) {
_context.next = 73;
break;
}
_standardizeDataItera = standardizeDataIteratorOutput(model, iteratorOut.value), xs = _standardizeDataItera.xs, ys = _standardizeDataItera.ys;
batchLogs = {};
batchLogs['batch'] = batchIndex;
batchLogs['size'] = xs[0].shape[0];
_context.next = 51;
return callbackList.onBatchBegin(batchIndex, batchLogs);
case 51:
sampleWeights = [];
if (!(args.classWeight != null)) {
_context.next = 64;
break;
}
standardClassWeights = standardizeClassWeights(args.classWeight, model.outputNames);
i = 0;
case 55:
if (!(i < standardClassWeights.length)) {
_context.next = 64;
break;
}
_context.t0 = sampleWeights;
_context.next = 59;
return standardizeWeights(ys[i], null, standardClassWeights[i]);
case 59:
_context.t1 = _context.sent;
_context.t0.push.call(_context.t0, _context.t1);
case 61:
++i;
_context.next = 55;
break;
case 64:
// Train on batch.
ins = xs.concat(ys).concat(sampleWeights);
outs = trainFunction(ins);
dispose(ins);
for (_i = 0; _i < outLabels.length; ++_i) {
label = outLabels[_i];
out = outs[_i];
batchLogs[label] = out;
keep(out);
}
_context.next = 70;
return callbackList.onBatchEnd(batchIndex, batchLogs);
case 70:
disposeTensorsInLogs(batchLogs);
batchIndex++;
stepsDone++;
case 73:
if (!(hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch : iteratorOut.done)) {
_context.next = 87;
break;
}
if (!doValidation) {
_context.next = 86;
break;
}
valOuts = void 0;
if (!isDatasetObject(args.validationData)) {
_context.next = 84;
break;
}
_context.t2 = toList;
_context.next = 80;
return model.evaluateDataset(args.validationData, {
batches: args.validationBatches
});
case 80:
_context.t3 = _context.sent;
valOuts = (0, _context.t2)(_context.t3);
_context.next = 85;
break;
case 84:
valOuts = toList(model.evaluate(valXs, valYs, {
batchSize: args.validationBatchSize == null ? DEFAULT_VALIDATION_BATCH_SIZE : args.validationBatchSize,
verbose: 0
}));
case 85:
for (_i2 = 0; _i2 < model.metricsNames.length; ++_i2) {
epochLogs["val_" + model.metricsNames[_i2]] = valOuts[_i2];
}
case 86:
return _context.abrupt("break", 91);
case 87:
if (!model.stopTraining_) {
_context.next = 89;
break;
}
return _context.abrupt("break", 91);
case 89:
_context.next = 37;
break;
case 91:
_context.next = 93;
return callbackList.onEpochEnd(epoch, epochLogs);
case 93:
epoch++;
if (!model.stopTraining_) {
_context.next = 96;
break;
}
return _context.abrupt("break", 98);
case 96:
_context.next = 27;
break;
case 98:
_context.next = 100;
return callbackList.onTrainEnd();
case 100:
_context.next = 102;
return model.history.syncData();
case 102:
return _context.abrupt("return", model.history);
case 103:
_context.prev = 103;
model.isTraining = false;
return _context.finish(103);
case 106:
case "end":
return _context.stop();
}
}
}, _callee, null, [[9,, 103, 106]]);
}));
return _fitDataset.apply(this, arguments);
}
function getStepsPerEpoch(dataset, args) {
// Attempt to determine # of batches in an epoch.
var stepsPerEpoch = null;
if (args.batchesPerEpoch != null) {
stepsPerEpoch = args.batchesPerEpoch;
} else if (Number.isFinite(dataset.size)) {
stepsPerEpoch = dataset.size;
}
return stepsPerEpoch;
} // Check if provided object is a Dataset object by checking its .iterator
// element.
function isDatasetObject(dataset) {
return typeof dataset.iterator === 'function';
} // Check if provided object is a LazyIterator object by checking its .next
// element.
function isLazyIteratorObject(iterator) {
return typeof iterator.next === 'function';
}
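/**
 * A minimal illustrative sketch (not executed by this bundle): both checks
 * above are duck-typed, so any object exposing the expected method passes.
 * The objects shown are hypothetical.
 *
 * ```js
 * isDatasetObject({iterator: async () => null});             // true
 * isLazyIteratorObject({next: async () => ({done: true})});  // true
 * isDatasetObject({next: async () => ({done: true})});       // false
 * ```
 */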
function evaluateDataset(_x4, _x5, _x6) {
return _evaluateDataset.apply(this, arguments);
}
function _evaluateDataset() {
_evaluateDataset = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2( // Type `model` as `any` here to avoid circular dependency w/
// training.ts.
// tslint:disable-next-line:no-any
model, dataset, args) {
var hasBatches, f, outs, dataIterator, numExamples, batch, _loop3, _ret, i, oldScalar;
return regeneratorRuntime.wrap(function _callee2$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
args = args || {};
hasBatches = args.batches != null;
f = model.testFunction;
outs = [];
if (!(args.verbose > 0)) {
_context3.next = 6;
break;
}
throw new NotImplementedError('Verbose mode is not implemented yet.');
case 6:
assert(!hasBatches || args.batches > 0 && Number.isInteger(args.batches), function () {
return 'Test loop expects `batches` to be a positive integer, but ' + ("received " + JSON.stringify(args.batches));
});
if (!isLazyIteratorObject(dataset)) {
_context3.next = 11;
break;
}
_context3.t0 = dataset;
_context3.next = 14;
break;
case 11:
_context3.next = 13;
return dataset.iterator();
case 13:
_context3.t0 = _context3.sent;
case 14:
dataIterator = _context3.t0;
// Keeps track of number of examples used in this evaluation.
numExamples = 0;
batch = 0;
_loop3 = /*#__PURE__*/regeneratorRuntime.mark(function _loop3() {
var iteratorOut;
return regeneratorRuntime.wrap(function _loop3$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return dataIterator.next();
case 2:
iteratorOut = _context2.sent;
outs = tidy(function () {
if (iteratorOut.value) {
(function () {
// TODO(cais): Once real dataset is available, use
// `map(x => standardizeDataIteratorOutput(model, x).map(f)`.
var _standardizeDataItera2 = standardizeDataIteratorOutput(model, iteratorOut.value),
xs = _standardizeDataItera2.xs,
ys = _standardizeDataItera2.ys;
var xsAndYs = xs.concat(ys);
var batchOuts = tidy(function () {
return f(xsAndYs);
});
dispose(xsAndYs);
if (batch === 0) {
for (var _i3 = 0; _i3 < batchOuts.length; ++_i3) {
outs.push(scalar(0));
}
}
var batchSize = xsAndYs[0].shape[0];
var _loop4 = function _loop4(_i4) {
var batchOut = batchOuts[_i4];
var oldScalar = outs[_i4];
outs[_i4] = tidy(function () {
return add$1(outs[_i4], mul(batchSize, batchOut));
});
if (batch > 0) {
dispose(oldScalar);
}
};
for (var _i4 = 0; _i4 < batchOuts.length; ++_i4) {
_loop4(_i4);
}
dispose(batchOuts);
numExamples += batchSize;
++batch;
})();
}
return outs;
});
if (!iteratorOut.done) {
_context2.next = 7;
break;
}
if (hasBatches) {
console.warn('Your dataset iterator ran out of data during evaluateDataset(). ' + 'Interrupting evaluation. Make sure that your ' + 'dataset can generate at least `batches` ' + ("batches (in this case, " + args.batches + " batches). ") + 'You may need to use the repeat() function when building ' + 'your dataset.');
}
return _context2.abrupt("return", "break");
case 7:
case "end":
return _context2.stop();
}
}
}, _loop3);
});
case 18:
if (!(hasBatches ? batch < args.batches : true)) {
_context3.next = 25;
break;
}
return _context3.delegateYield(_loop3(), "t1", 20);
case 20:
_ret = _context3.t1;
if (!(_ret === "break")) {
_context3.next = 23;
break;
}
return _context3.abrupt("break", 25);
case 23:
_context3.next = 18;
break;
case 25:
for (i = 0; i < outs.length; ++i) {
oldScalar = outs[i];
outs[i] = div(outs[i], numExamples);
dispose(oldScalar);
}
return _context3.abrupt("return", singletonOrArray(outs));
case 27:
case "end":
return _context3.stop();
}
}
}, _callee2);
}));
return _evaluateDataset.apply(this, arguments);
}
function checkBatchSize(batchSize) {
assert(batchSize > 0 && Number.isInteger(batchSize), function () {
return "batchSize is required to be a positive integer, but got " + batchSize;
});
}
/**
* Slice a Tensor or an Array of Tensors, by start and stop indices.
*
* Porting Note: The `_slice_arrays` function in PyKeras is covered by this
* function and `sliceArraysByIndices()` together.
*
* @param arrays: the input.
* @param start: the starting index (inclusive).
* @param stop: the stopping index (exclusive).
* @returns The result of the slicing. If `arrays` is an `Array` of
* `tf.Tensor`s, the slicing will be applied to all elements of the `Array`
* in the same way.
*/
function sliceArrays(arrays, start, stop) {
if (arrays == null) {
return [null];
} else if (Array.isArray(arrays)) {
return arrays.map(function (array) {
return sliceAlongFirstAxis(array, start, stop - start);
});
} else {
// Tensor.
return sliceAlongFirstAxis(arrays, start, stop - start);
}
}
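/**
 * A minimal illustrative sketch (not executed by this bundle) of slicing
 * along the batch (first) axis; shapes are hypothetical.
 *
 * ```js
 * const x = tf.ones([10, 3]);
 * sliceArrays(x, 2, 6);       // a single tensor of shape [4, 3]
 * sliceArrays([x, x], 2, 6);  // an array of two tensors of shape [4, 3]
 * sliceArrays(null, 2, 6);    // [null]
 * ```
 */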
/**
* Slice a Tensor or an Array of Tensors, by random-order indices.
*
* Porting Note: The `_slice_arrays` function in PyKeras is covered by this
* function and `sliceArrays()` together.
*
* @param arrays The input `tf.Tensor` or `Array` of `tf.Tensor`s to slice.
* If an `Array` of `tf.Tensor`s, all `tf.Tensor`s will be sliced in the
* same fashion.
* @param indices The indices to use for slicing along the first (batch)
* dimension.
* @returns Result(s) of the slicing.
*/
function sliceArraysByIndices(arrays, indices) {
return tidy(function () {
if (arrays == null) {
return null;
} else if (Array.isArray(arrays)) {
return arrays.map(function (array) {
return sliceArraysByIndices(array, indices);
});
} else {
// TODO(cais): indices should be a pre-constructed Tensor1D to avoid
// tensor1d() calls.
return gather$1(arrays, indices.dtype === 'int32' ? indices : cast(indices, 'int32'));
}
});
}
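/**
 * A minimal illustrative sketch (not executed by this bundle) of gathering
 * rows by (possibly shuffled) indices; values are hypothetical.
 *
 * ```js
 * const x = tf.tensor2d([[10], [20], [30], [40]]);
 * const picked = sliceArraysByIndices(x, tf.tensor1d([3, 0], 'int32'));
 * // picked has values [[40], [10]].
 * ```
 */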
/**
* Returns a list of batch indices (tuples of indices).
* @param size: Integer, total size of the data to slice into batches.
* @param batchSize: Integer, batch size.
* @returns An Array of [batchStart, batchEnd] tuples. batchStart is
* inclusive; batchEnd is exclusive. I.e., each batch consists of indices x
* that satisfy batchStart <= x < batchEnd.
*/
function makeBatches(size, batchSize) {
var output = [];
var batchStart = 0;
var batchEnd = null;
while (batchStart < size) {
batchEnd = batchStart + batchSize;
if (batchEnd >= size) {
batchEnd = size;
}
output.push([batchStart, batchEnd]);
batchStart = batchEnd;
}
return output;
}
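/**
 * A minimal illustrative sketch (not executed by this bundle):
 *
 * ```js
 * makeBatches(10, 4);  // => [[0, 4], [4, 8], [8, 10]]
 * makeBatches(8, 4);   // => [[0, 4], [4, 8]]
 * ```
 */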
/**
* Abstract fit function for `f(ins)`.
* @param f A Function returning a list of tensors. For training, this
* function is expected to perform the updates to the variables.
* @param ins List of tensors to be fed to `f`.
* @param outLabels List of strings, display names of the outputs of `f`.
 * @param batchSize Integer batch size or `== null` if unknown. Default: 32.
 * @param epochs Number of times to iterate over the data. Default: 1.
* @param verbose Verbosity mode: 0, 1, or 2. Default: 1.
* @param callbacks List of callbacks to be called during training.
* @param valF Function to call for validation.
* @param valIns List of tensors to be fed to `valF`.
* @param shuffle Whether to shuffle the data at the beginning of every
 * epoch. Default: true.
* @param callbackMetrics List of strings, the display names of the metrics
* passed to the callbacks. They should be the concatenation of the
* display names of the outputs of `f` and the list of display names
* of the outputs of `valF`.
* @param initialEpoch Epoch at which to start training (useful for
 * resuming a previous training run). Default: 0.
* @param stepsPerEpoch Total number of steps (batches on samples) before
* declaring one epoch finished and starting the next epoch. Ignored with
* the default value of `undefined` or `null`.
* @param validationSteps Number of steps to run validation for (only if
* doing validation from data tensors). Not applicable for tfjs-layers.
* @returns A `History` object.
*/
function fitLoop(_x, _x2, _x3, _x4, _x5, _x6, _x7, _x8, _x9, _x10, _x11, _x12, _x13, _x14, _x15) {
return _fitLoop.apply(this, arguments);
}
function _fitLoop() {
_fitLoop = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2( // Type `model` as `any` here to avoid circular dependency w/ training.ts.
// tslint:disable-next-line:no-any
model, f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle$1, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) {
var doValidation, numTrainSamples, indexArray, _configureCallbacks, callbackList, history, _loop, epoch, _ret;
return regeneratorRuntime.wrap(function _callee2$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
if (batchSize == null) {
batchSize = 32;
}
if (epochs == null) {
epochs = 1;
}
if (shuffle$1 == null) {
shuffle$1 = true;
}
if (initialEpoch == null) {
initialEpoch = 0;
} // TODO(cais): Change const to let below when implementing validation.
doValidation = false;
if (valF != null && valIns != null) {
doValidation = true; // TODO(cais): verbose message.
}
if (!(validationSteps != null)) {
_context4.next = 10;
break;
}
doValidation = true;
if (!(stepsPerEpoch == null)) {
_context4.next = 10;
break;
}
throw new ValueError('Can only use `validationSteps` when doing step-wise training, ' + 'i.e., `stepsPerEpoch` must be set.');
case 10:
numTrainSamples = model.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch');
if (numTrainSamples != null) {
indexArray = range$1(0, numTrainSamples);
}
if (verbose == null) {
verbose = 1;
}
_configureCallbacks = configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics), callbackList = _configureCallbacks.callbackList, history = _configureCallbacks.history;
callbackList.setModel(model);
model.history = history;
_context4.next = 18;
return callbackList.onTrainBegin();
case 18:
model.stopTraining_ = false; // TODO(cais): Take care of callbacks.validation_data as in PyKeras.
// TODO(cais): Pre-convert feeds for performance as in PyKeras.
_loop = /*#__PURE__*/regeneratorRuntime.mark(function _loop(epoch) {
var epochLogs;
return regeneratorRuntime.wrap(function _loop$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
_context3.next = 2;
return callbackList.onEpochBegin(epoch);
case 2:
epochLogs = {};
if (!(stepsPerEpoch != null)) {
_context3.next = 7;
break;
}
throw new NotImplementedError('stepsPerEpoch mode is not implemented yet.');
case 7:
return _context3.delegateYield( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var epochIndexArray1D, batches, _loop2, batchIndex, _ret2;
return regeneratorRuntime.wrap(function _callee$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
if (!(shuffle$1 === 'batch')) {
_context2.next = 4;
break;
}
throw new NotImplementedError('batch shuffling is not implemented yet');
case 4:
if (shuffle$1) {
shuffle(indexArray);
}
case 5:
// Convert the potentially shuffled indices to Tensor1D, to avoid the
// cost of repeated creation of Array1Ds later on.
epochIndexArray1D = tensor1d(indexArray);
batches = makeBatches(numTrainSamples, batchSize);
_loop2 = /*#__PURE__*/regeneratorRuntime.mark(function _loop2(batchIndex) {
var batchLogs;
return regeneratorRuntime.wrap(function _loop2$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
batchLogs = {};
_context.next = 3;
return callbackList.onBatchBegin(batchIndex, batchLogs);
case 3:
tidy(function () {
var batchStart = batches[batchIndex][0];
var batchEnd = batches[batchIndex][1];
var batchIds = sliceAlongFirstAxis(epochIndexArray1D, batchStart, batchEnd - batchStart);
batchLogs['batch'] = batchIndex;
batchLogs['size'] = batchEnd - batchStart; // TODO(cais): In ins, train flag can be a number, instead of a
// Tensor? Do we need to handle this in tfjs-layers?
var insBatch = sliceArraysByIndices(ins, batchIds);
var outs = f(insBatch);
for (var i = 0; i < outLabels.length; ++i) {
var label = outLabels[i];
var out = outs[i];
batchLogs[label] = out;
keep(out); // TODO(cais): Use scope() to avoid ownership.
}
if (batchIndex === batches.length - 1) {
// Last batch.
if (doValidation) {
var valOuts = model.testLoop(valF, valIns, batchSize); // Porting Notes: In tfjs-layers, valOuts is always an Array.
for (var _i = 0; _i < outLabels.length; ++_i) {
var _label = outLabels[_i];
var _out = valOuts[_i];
keep(_out); // TODO(cais): Use scope() to avoid ownership.
epochLogs['val_' + _label] = _out;
}
}
}
});
_context.next = 6;
return callbackList.onBatchEnd(batchIndex, batchLogs);
case 6:
disposeTensorsInLogs(batchLogs);
if (!model.stopTraining_) {
_context.next = 9;
break;
}
return _context.abrupt("return", "break");
case 9:
case "end":
return _context.stop();
}
}
}, _loop2);
});
batchIndex = 0;
case 9:
if (!(batchIndex < batches.length)) {
_context2.next = 17;
break;
}
return _context2.delegateYield(_loop2(batchIndex), "t0", 11);
case 11:
_ret2 = _context2.t0;
if (!(_ret2 === "break")) {
_context2.next = 14;
break;
}
return _context2.abrupt("break", 17);
case 14:
++batchIndex;
_context2.next = 9;
break;
case 17:
epochIndexArray1D.dispose();
case 18:
case "end":
return _context2.stop();
}
}
}, _callee);
})(), "t0", 8);
case 8:
_context3.next = 10;
return callbackList.onEpochEnd(epoch, epochLogs);
case 10:
if (!model.stopTraining_) {
_context3.next = 12;
break;
}
return _context3.abrupt("return", "break");
case 12:
case "end":
return _context3.stop();
}
}
}, _loop);
});
epoch = initialEpoch;
case 21:
if (!(epoch < epochs)) {
_context4.next = 29;
break;
}
return _context4.delegateYield(_loop(epoch), "t0", 23);
case 23:
_ret = _context4.t0;
if (!(_ret === "break")) {
_context4.next = 26;
break;
}
return _context4.abrupt("break", 29);
case 26:
++epoch;
_context4.next = 21;
break;
case 29:
_context4.next = 31;
return callbackList.onTrainEnd();
case 31:
_context4.next = 33;
return model.history.syncData();
case 33:
return _context4.abrupt("return", model.history);
case 34:
case "end":
return _context4.stop();
}
}
}, _callee2);
}));
return _fitLoop.apply(this, arguments);
}
function fitTensors(_x16, _x17, _x18, _x19) {
return _fitTensors.apply(this, arguments);
}
/**
* Ensure tensors all have a rank of at least 2.
*
* If a tensor has a rank of 1, it is dimension-expanded to rank 2.
* If any tensor has a rank of 0 (i.e., is a scalar), an error will be thrown.
*/
function _fitTensors() {
_fitTensors = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3( // Type `model` as `any` here to avoid circular dependency w/ training.ts.
// tslint:disable-next-line:no-any
model, x, y, args) {
var inputs, targets, inputValX, inputValY, valX, valY, sampleWeights, batchSize, checkBatchAxis, standardizedOuts, doValidation, valIns, _checkBatchAxis, valStandardized, splitAt, originalBatchSize, ins, trainFunction, outLabels, valFunction, callbackMetrics, callbacks, out;
return regeneratorRuntime.wrap(function _callee3$(_context5) {
while (1) {
switch (_context5.prev = _context5.next) {
case 0:
if (args === void 0) {
args = {};
}
if (!model.isTraining) {
_context5.next = 3;
break;
}
throw new Error('Cannot start training because another fit() call is ongoing.');
case 3:
model.isTraining = true;
_context5.prev = 4;
batchSize = args.batchSize == null ? 32 : args.batchSize;
checkBatchSize(batchSize); // Validate user data.
// TODO(cais): Support sampleWeight.
checkBatchAxis = false;
_context5.next = 10;
return model.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize);
case 10:
standardizedOuts = _context5.sent;
inputs = standardizedOuts[0];
targets = standardizedOuts[1];
sampleWeights = standardizedOuts[2]; // Prepare validation data.
doValidation = false;
if (!(args.validationData != null && args.validationData.length > 0)) {
_context5.next = 36;
break;
}
doValidation = true;
if (!(args.validationData.length === 2)) {
_context5.next = 22;
break;
}
// config.validationData consists of valX and valY.
inputValX = args.validationData[0];
inputValY = args.validationData[1];
_context5.next = 27;
break;
case 22:
if (!(args.validationData.length === 3)) {
_context5.next = 26;
break;
}
throw new NotImplementedError('validationData including sample weights is not supported yet.');
case 26:
throw new ValueError("When passing validation data, it must contain 2 (valX, valY) " + "or 3 (valX, valY, valSampleWeight) items; " + (args.validationData + " is invalid."));
case 27:
_checkBatchAxis = true;
_context5.next = 30;
return model.standardizeUserData(inputValX, inputValY, null,
/** Unused sample weights. */
null,
/** Unused class weights. */
_checkBatchAxis, batchSize);
case 30:
valStandardized = _context5.sent;
valX = valStandardized[0];
valY = valStandardized[1];
valIns = valX.concat(valY); // TODO(cais): Add useLearningPhase data properly.
_context5.next = 37;
break;
case 36:
if (args.validationSplit != null && args.validationSplit > 0 && args.validationSplit < 1) {
doValidation = true; // Porting Note: In tfjs-layers, inputs[0] is always a Tensor.
splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit));
originalBatchSize = inputs[0].shape[0];
valX = sliceArrays(inputs, splitAt, originalBatchSize);
inputs = sliceArrays(inputs, 0, splitAt);
valY = sliceArrays(targets, splitAt, originalBatchSize);
targets = sliceArrays(targets, 0, splitAt); // TODO(cais): Once sampleWeights becomes available, slice it to get
// valSampleWeights.
valIns = valX.concat(valY); // TODO(cais): Add useLearningPhase data properly.
} else if (args.validationSteps != null) {
doValidation = true; // TODO(cais): Add useLearningPhase.
}
case 37:
ins = inputs.concat(targets).concat(sampleWeights);
model.checkTrainableWeightsConsistency(); // TODO(cais): Handle use_learning_phase and learning_phase?
// Porting Note: Here we see a key deviation of tfjs-layers from
// Keras.
// Due to the imperative nature of tfjs-layers' backend (tfjs-core),
// we do not construct symbolic computation graphs to embody the
// training process. Instead, we define a function that performs the
// training action. In PyKeras, the data (inputs and targets) are fed
// through graph placeholders. In tfjs-layers, the data are fed as
// function arguments. Since the functions are defined below in this
// scope, we don't have equivalents of PyKeras's
// `_make_train_function`.
trainFunction = model.makeTrainFunction();
outLabels = model.getDedupedMetricsNames();
if (doValidation) {
model.makeTestFunction();
valFunction = model.testFunction;
callbackMetrics = outLabels.slice().concat(outLabels.map(function (n) {
return 'val_' + n;
}));
} else {
valFunction = null;
valIns = [];
callbackMetrics = outLabels.slice();
}
callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
_context5.next = 45;
return fitLoop(model, trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null);
case 45:
out = _context5.sent;
return _context5.abrupt("return", out);
case 47:
_context5.prev = 47;
model.isTraining = false; // Memory clean up.
disposeNewTensors(inputs, x);
disposeNewTensors(targets, y);
disposeNewTensors(valX, inputValX);
disposeNewTensors(valY, inputValY);
if (sampleWeights != null) {
dispose(sampleWeights);
}
return _context5.finish(47);
case 55:
case "end":
return _context5.stop();
}
}
}, _callee3, null, [[4,, 47, 55]]);
}));
return _fitTensors.apply(this, arguments);
}
function ensureTensorsRank2OrHigher(tensors) {
var outs = [];
if (tensors instanceof Tensor) {
tensors = [tensors];
} // Make Tensors at least 2D.
for (var i = 0; i < tensors.length; ++i) {
var tensor = tensors[i];
if (tensor.rank === 1) {
outs.push(expandDims$1(tensor, 1));
} else if (tensor.rank === 0) {
throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' + '(scalar).');
} else {
outs.push(tensor);
}
}
return outs;
}
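/**
 * A minimal illustrative sketch (not executed by this bundle); shapes are
 * hypothetical.
 *
 * ```js
 * ensureTensorsRank2OrHigher(tf.tensor1d([1, 2, 3]));
 * // => [tensor of shape [3, 1]]
 * ensureTensorsRank2OrHigher([tf.zeros([2, 4])]);
 * // => [the same tensor, shape [2, 4]]
 * // A rank-0 (scalar) tensor causes an Error to be thrown.
 * ```
 */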
/**
* Compare a set of tensors with a reference (old) set, discard the ones
* in the new set that are not present in the reference set.
*
 * This method is used for memory cleanup during calls such as
* LayersModel.fit().
*
* @param tensors New set which may contain Tensors not present in
* `refTensors`.
* @param refTensors Reference Tensor set.
*/
// TODO(cais, kangyizhang): Deduplicate with tfjs-data.
function disposeNewTensors(tensors, refTensors) {
if (tensors == null) {
return;
}
var oldTensorIds = [];
if (refTensors instanceof Tensor) {
oldTensorIds.push(refTensors.id);
} else if (Array.isArray(refTensors)) {
refTensors.forEach(function (t) {
return oldTensorIds.push(t.id);
});
} else if (refTensors != null) {
// `refTensors` is a map from string name to Tensor.
for (var name in refTensors) {
var oldTensor = refTensors[name];
oldTensorIds.push(oldTensor.id);
}
}
var tensorsToDispose = [];
if (tensors instanceof Tensor) {
if (oldTensorIds.indexOf(tensors.id) === -1) {
tensorsToDispose.push(tensors);
}
} else if (Array.isArray(tensors)) {
tensors.forEach(function (t) {
if (oldTensorIds.indexOf(t.id) === -1) {
tensorsToDispose.push(t);
}
});
} else if (tensors != null) {
// `tensors` is a map from string name to Tensor.
for (var _name in tensors) {
var tensor = tensors[_name];
if (oldTensorIds.indexOf(tensor.id) === -1) {
tensorsToDispose.push(tensor);
}
}
}
tensorsToDispose.forEach(function (t) {
if (!t.isDisposed) {
t.dispose();
}
});
}
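/**
 * A minimal illustrative sketch (not executed by this bundle): only tensors
 * absent from the reference set are disposed. The variable names are
 * hypothetical.
 *
 * ```js
 * const userX = tf.ones([2, 2]);                   // provided by the caller
 * const standardized = [userX, tf.zeros([2, 2])];  // one newly created tensor
 * disposeNewTensors(standardized, userX);
 * // The zeros tensor is disposed; userX is left untouched.
 * ```
 */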
/**
* Helper function for polymorphic input data: 1. singleton Tensor.
*/
function isDataTensor(x) {
return x instanceof Tensor;
}
/**
* Helper function for polymorphic input data: 2. Array of Tensor.
*/
function isDataArray(x) {
return Array.isArray(x);
}
/**
* Helper function for polymorphic input data: 3. "dict" of Tensor.
*/
function isDataDict(x) {
return !isDataTensor(x) && !isDataArray(x);
}
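/**
 * A minimal illustrative sketch (not executed by this bundle):
 *
 * ```js
 * isDataTensor(tf.scalar(1));          // true
 * isDataArray([tf.scalar(1)]);         // true
 * isDataDict({input_1: tf.scalar(1)}); // true (neither Tensor nor Array)
 * ```
 */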
/**
* Normalizes inputs and targets provided by users.
* @param data User-provided input data (polymorphic).
* @param names An Array of expected Tensor names.
* @param shapes Optional Array of expected Tensor shapes.
* @param checkBatchAxis Whether to check that the batch axis of the arrays
* match the expected value found in `shapes`.
* @param exceptionPrefix String prefix used for exception formatting.
* @returns List of standardized input Tensors (one Tensor per model input).
* @throws ValueError: in case of improperly formatted user data.
*/
function standardizeInputData(data, names, shapes, checkBatchAxis, exceptionPrefix) {
if (checkBatchAxis === void 0) {
checkBatchAxis = true;
}
if (exceptionPrefix === void 0) {
exceptionPrefix = '';
}
if (names == null || names.length === 0) {
// Check for the case where the model expected no data, but some data got
// sent.
if (data != null) {
var gotUnexpectedData = false;
if (isDataArray(data) && data.length > 0) {
gotUnexpectedData = true;
} else if (isDataDict(data)) {
for (var key in data) {
if (data.hasOwnProperty(key)) {
gotUnexpectedData = true;
break;
}
}
} else {
// `data` is a singleton Tensor in this case.
gotUnexpectedData = true;
}
if (gotUnexpectedData) {
throw new ValueError("Error when checking model " + exceptionPrefix + " expected no data, " + ("but got " + data));
}
}
return [];
}
if (data == null) {
return names.map(function (name) {
return null;
});
}
var arrays;
if (isDataDict(data)) {
data = data;
arrays = [];
for (var _iterator = _createForOfIteratorHelperLoose(names), _step; !(_step = _iterator()).done;) {
var name = _step.value;
if (data[name] == null) {
throw new ValueError("No data provided for \"" + name + "\". Need data for each key in: " + ("" + names));
}
arrays.push(data[name]);
}
} else if (isDataArray(data)) {
data = data;
if (data.length !== names.length) {
throw new ValueError("Error when checking model " + exceptionPrefix + ": the Array of " + "Tensors that you are passing to your model is not the size the " + ("model expected. Expected to see " + names.length + " Tensor(s), but ") + ("instead got the following list of Tensor(s): " + data));
}
arrays = data;
} else {
data = data;
if (names.length > 1) {
throw new ValueError("The model " + exceptionPrefix + " expects " + names.length + " Tensor(s), " + ("but only received one Tensor. Found: Tensor with shape " + data.shape));
}
arrays = [data];
}
arrays = ensureTensorsRank2OrHigher(arrays); // Check shape compatibility.
if (shapes != null) {
for (var i = 0; i < names.length; ++i) {
if (shapes[i] == null) {
continue;
}
var array = arrays[i];
if (array.shape.length !== shapes[i].length) {
throw new ValueError("Error when checking " + exceptionPrefix + ": expected " + names[i] + " " + ("to have " + shapes[i].length + " dimension(s). but got array with ") + ("shape " + array.shape));
}
for (var j = 0; j < shapes[i].length; ++j) {
if (j === 0 && !checkBatchAxis) {
// Skip the first (batch) axis.
continue;
}
var dim = array.shape[j];
var refDim = shapes[i][j];
if (refDim != null && refDim >= 0 && dim !== refDim) {
throw new ValueError(exceptionPrefix + " expected a batch of elements where each " + ("example has shape [" + shapes[i].slice(1, shapes[i].length) + "] ") + ("(i.e., tensor shape [*," + shapes[i].slice(1, shapes[i].length) + "])") + (" but the " + exceptionPrefix + " received an input with " + array.shape[0]) + (" examples, each with shape [" + array.shape.slice(1, array.shape.length) + "]") + (" (tensor shape [" + array.shape + "])"));
}
}
}
}
return arrays;
}
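/**
 * A minimal illustrative sketch (not executed by this bundle): a named map is
 * reordered to match `names`, and rank-1 tensors are expanded to rank 2. The
 * names `a` and `b` and the shapes are hypothetical.
 *
 * ```js
 * standardizeInputData(
 *     {b: tf.tensor1d([1, 2]), a: tf.zeros([2, 3])}, ['a', 'b']);
 * // => [tensor of shape [2, 3], tensor of shape [2, 1]]
 * ```
 */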
/**
* User input validation for Tensors.
* @param inputs `Array` of `tf.Tensor`s for inputs.
* @param targets `Array` of `tf.Tensor`s for targets.
* @param weights Optional `Array` of `tf.Tensor`s for sample weights.
* @throws ValueError: in case of incorrectly formatted data.
*/
function checkArrayLengths(inputs, targets, weights) {
var setX = unique$1(inputs.map(function (input) {
return input.shape[0];
}));
setX.sort();
var setY = unique$1(targets.map(function (target) {
return target.shape[0];
}));
setY.sort(); // TODO(cais): Check `weights` as well.
if (setX.length > 1) {
throw new ValueError("All input Tensors (x) should have the same number of samples. " + "Got array shapes: " + ("" + JSON.stringify(inputs.map(function (input) {
return input.shape;
}))));
}
if (setY.length > 1) {
throw new ValueError("All target Tensors (y) should have the same number of samples. " + "Got array shapes: " + ("" + JSON.stringify(targets.map(function (target) {
return target.shape;
}))));
}
if (setX.length > 0 && setY.length > 0 && !arraysEqual(setX, setY)) {
throw new ValueError("Input Tensors should have the same number of samples as target " + ("Tensors. Found " + setX[0] + " input sample(s) and " + setY[0] + " target ") + "sample(s).");
}
}
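/**
 * A minimal illustrative sketch (not executed by this bundle): inputs and
 * targets must agree on the number of samples along the batch axis; shapes
 * are hypothetical.
 *
 * ```js
 * checkArrayLengths([tf.zeros([8, 3])], [tf.zeros([8, 1])]);  // OK
 * checkArrayLengths([tf.zeros([8, 3])], [tf.zeros([6, 1])]);  // ValueError
 * ```
 */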
/**
 * Validation on the compatibility of targets and loss functions.
*
* This helps prevent users from using loss functions incorrectly.
*
* @param targets `Array` of `tf.Tensor`s of targets.
* @param lossFns `Array` of loss functions.
* @param outputShapes `Array` of shapes of model outputs.
*/
function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
// TODO(cais): Dedicated test coverage?
var keyLosses = [meanSquaredError$1, binaryCrossentropy, categoricalCrossentropy];
for (var i = 0; i < targets.length; ++i) {
var y = targets[i];
var loss = lossFns[i];
var shape = outputShapes[i];
if (loss == null) {
continue;
}
if (loss === categoricalCrossentropy) {
if (y.shape[y.shape.length - 1] === 1) {
throw new ValueError("You are passing a target array of shape " + y.shape + " while using " + "a loss 'categorical_crossentropy'. 'categorical_crossentropy'" + "expects targets to be binary matrices (1s and 0s) of shape " + "[samples, classes]."); // TODO(cais): Example code in error message.
}
}
if (keyLosses.indexOf(loss) !== -1) {
var slicedYShape = y.shape.slice(1);
var slicedShape = shape.slice(1);
for (var j = 0; j < slicedYShape.length; ++j) {
var targetDim = slicedYShape[j];
var outDim = slicedShape[j];
if (outDim != null && targetDim !== outDim) {
throw new ValueError("A target Tensor with shape " + y.shape + " was passed for an " + ("output of shape " + shape + ", while using a loss function that ") + "expects targets to have the same shape as the output.");
}
}
}
}
}
/**
* Check inputs provided by the user.
*
* Porting Note: This corresponds to _standardize_input_data() in Python
* Keras. Because of the strong typing in TF.js, we do not need to convert
* the data. Specifically:
* 1) in PyKeras, `data` can be `DataFrame` instances from pandas, for
* example. We don't need to worry about that here because there is no
 * widely popular JavaScript/TypeScript equivalent of pandas (so far).
* If one becomes available in the future, we can add support.
* 2) in PyKeras, inputs can be Python dict. But here we are stipulating
* that the data is either a single `tf.Tensor` or an Array of `tf.Tensor`s. We
* may add support for `Object` data inputs in the future when the need
* arises.
*
* Instead, we perform basic checks for number of parameters and shapes.
*
* @param data: The input data.
* @param names: Name for the inputs, from the model.
* @param shapes: Expected shapes for the input data, from the model.
* @param checkBatchAxis: Whether the size along the batch axis (i.e., the
* first dimension) will be checked for matching.
 * @param exceptionPrefix: Exception prefix message, used in generating error
* messages.
* @throws ValueError: on incorrect number of inputs or mismatches in shapes.
*/
function checkInputData(data, names, shapes, checkBatchAxis, exceptionPrefix) {
if (checkBatchAxis === void 0) {
checkBatchAxis = true;
}
if (exceptionPrefix === void 0) {
exceptionPrefix = '';
}
var arrays;
if (Array.isArray(data)) {
if (data.length !== names.length) {
throw new ValueError("Error when checking model " + exceptionPrefix + ": the Array of " + "Tensors that you are passing to your model is not the size the " + ("the model expected. Expected to see " + names.length + " Tensor(s),") + (" but instead got " + data.length + " Tensors(s)."));
}
arrays = data;
} else {
if (names.length > 1) {
throw new ValueError("The model expects " + names.length + " " + exceptionPrefix + " Tensors, " + "but only received one Tensor. Found: array with shape " + (JSON.stringify(data.shape) + "."));
}
arrays = [data];
}
if (shapes != null) {
for (var i = 0; i < names.length; ++i) {
if (shapes[i] == null) {
continue;
}
var array = arrays[i];
if (array.shape.length !== shapes[i].length) {
throw new ValueError("Error when checking " + exceptionPrefix + ": expected " + names[i] + " " + ("to have " + shapes[i].length + " dimension(s), but got array with ") + ("shape " + JSON.stringify(array.shape)));
}
for (var j = 0; j < shapes[i].length; ++j) {
if (j === 0 && !checkBatchAxis) {
continue;
}
var dim = array.shape[j];
var refDim = shapes[i][j];
if (refDim != null) {
if (refDim !== dim) {
throw new ValueError("Error when checking " + exceptionPrefix + ": expected " + (names[i] + " to have shape " + JSON.stringify(shapes[i]) + " but ") + ("got array with shape " + JSON.stringify(array.shape) + "."));
}
}
}
}
}
}
/**
* Maps metric functions to model outputs.
 * @param metrics A shortcut string name, a metric function, an `Array`, or a dict
* (`Object`) of metric functions.
* @param outputNames An `Array` of the names of model outputs.
* @returns An `Array` (one entry per model output) of `Array` of metric
* functions. For instance, if the model has 2 outputs, and for the first
* output we want to compute `binaryAccuracy` and `binaryCrossentropy`,
* and just `binaryAccuracy` for the second output, the `Array` would look
* like:
* `[[binaryAccuracy, binaryCrossentropy], [binaryAccuracy]]`
* @throws TypeError: incompatible metrics format.
*/
function collectMetrics(metrics, outputNames) {
if (metrics == null || Array.isArray(metrics) && metrics.length === 0) {
return outputNames.map(function (name) {
return [];
});
}
var wrappedMetrics;
if (typeof metrics === 'string' || typeof metrics === 'function') {
wrappedMetrics = [metrics];
} else if (Array.isArray(metrics) || typeof metrics === 'object') {
wrappedMetrics = metrics;
} else {
throw new TypeError('Type of metrics argument not understood. Expected a string, ' + ("function, Array, or Object, found: " + metrics));
}
if (Array.isArray(wrappedMetrics)) {
// We then apply all metrics to all outputs.
return outputNames.map(function (name) {
return wrappedMetrics;
});
} else {
// In this case, metrics is a dict.
var nestedMetrics = [];
for (var _iterator2 = _createForOfIteratorHelperLoose(outputNames), _step2; !(_step2 = _iterator2()).done;) {
var name = _step2.value;
var outputMetrics = wrappedMetrics.hasOwnProperty(name) ? wrappedMetrics[name] : [];
if (!Array.isArray(outputMetrics)) {
outputMetrics = [outputMetrics];
}
nestedMetrics.push(outputMetrics);
}
return nestedMetrics;
}
}
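/**
 * A minimal illustrative sketch (not executed by this bundle); the output
 * names `out_a` and `out_b` are hypothetical.
 *
 * ```js
 * collectMetrics('accuracy', ['out_a', 'out_b']);
 * // => [['accuracy'], ['accuracy']]
 * collectMetrics({out_a: ['accuracy', 'crossentropy']}, ['out_a', 'out_b']);
 * // => [['accuracy', 'crossentropy'], []]
 * ```
 */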
var LAYERS_MODEL_FORMAT_NAME = 'layers-model';
/**
* A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods
* for training, evaluation, prediction and saving.
*
* `tf.LayersModel` is the basic unit of training, inference and evaluation in
 * TensorFlow.js. To create a `tf.LayersModel`, use `tf.model`.
*
* See also:
* `tf.Sequential`, `tf.loadLayersModel`.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
var LayersModel = /*#__PURE__*/function (_Container) {
_inheritsLoose(LayersModel, _Container);
function LayersModel(args) {
var _this;
_this = _Container.call(this, args) || this;
_this.isTraining = false;
return _this;
}
/**
* Print a text summary of the model's layers.
*
* The summary includes
* - Name and type of all layers that comprise the model.
* - Output shape(s) of the layers
* - Number of weight parameters of each layer
* - If the model has non-sequential-like topology, the inputs each layer
* receives
* - The total number of trainable and non-trainable parameters of the model.
*
* ```js
* const input1 = tf.input({shape: [10]});
* const input2 = tf.input({shape: [20]});
* const dense1 = tf.layers.dense({units: 4}).apply(input1);
* const dense2 = tf.layers.dense({units: 8}).apply(input2);
* const concat = tf.layers.concatenate().apply([dense1, dense2]);
* const output =
* tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
*
* const model = tf.model({inputs: [input1, input2], outputs: output});
* model.summary();
* ```
*
* @param lineLength Custom line length, in number of characters.
* @param positions Custom widths of each of the columns, as either
* fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
* of characters (e.g., `[30, 50, 65]`). Each number corresponds to
* right-most (i.e., ending) position of a column.
* @param printFn Custom print function. Can be used to replace the default
* `console.log`. For example, you can use `x => {}` to mute the printed
* messages in the console.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
var _proto = LayersModel.prototype;
_proto.summary = function summary(lineLength, positions, printFn) {
if (printFn === void 0) {
printFn = console.log;
}
if (!this.built) {
throw new ValueError("This model has never been called, thus its weights have not been " + "created yet. So no summary can be displayed. Build the model " + "first (e.g., by calling it on some test data).");
}
printSummary(this, lineLength, positions, printFn);
}
/**
* Configures and prepares the model for training and evaluation. Compiling
* outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
* or `evaluate` on an un-compiled model will throw an error.
*
* @param args a `ModelCompileArgs` specifying the loss, optimizer, and
* metrics to be used for fitting and evaluating this model.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
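/**
 * A minimal illustrative sketch (not executed by this bundle) of compiling a
 * model before fitting or evaluating it; the layer sizes are hypothetical.
 *
 * ```js
 * const model = tf.sequential({
 *   layers: [tf.layers.dense({units: 1, inputShape: [4]})]
 * });
 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
 * ```
 */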
_proto.compile = function compile(args) {
var _this2 = this;
if (args.loss == null) {
args.loss = [];
}
this.loss = args.loss;
if (typeof args.optimizer === 'string') {
this.optimizer_ = getOptimizer(args.optimizer);
this.isOptimizerOwned = true;
} else {
if (!(args.optimizer instanceof Optimizer)) {
throw new ValueError("User-defined optimizer must be an instance of tf.Optimizer.");
}
this.optimizer_ = args.optimizer;
this.isOptimizerOwned = false;
} // TODO(cais): Add lossWeights.
// TODO(cais): Add sampleWeightMode.
// Prepare loss functions.
var lossFunctions = [];
if (!Array.isArray(args.loss) && typeof args.loss !== 'string' && typeof args.loss !== 'function') {
args.loss = args.loss;
for (var name in args.loss) {
if (this.outputNames.indexOf(name) === -1) {
throw new ValueError("Unknown entry in loss dictionary: \"" + name + "\". " + ("Only expected the following keys: " + this.outputNames));
}
}
for (var _iterator3 = _createForOfIteratorHelperLoose(this.outputNames), _step3; !(_step3 = _iterator3()).done;) {
var _name = _step3.value;
if (args.loss[_name] == null) {
console.warn("Output \"" + _name + "\" is missing from loss dictionary. We assume " + "this was done on purpose, and we will not be expecting data " + ("to be passed to " + _name + " during training"));
}
lossFunctions.push(get$3(args.loss[_name]));
}
} else if (Array.isArray(args.loss)) {
if (args.loss.length !== this.outputs.length) {
throw new ValueError("When passing an Array as loss, it should have one entry per " + ("model output. The model has " + this.outputs.length + " output(s), ") + ("but you passed loss=" + args.loss + "."));
}
var theLosses = args.loss;
lossFunctions = theLosses.map(function (l) {
return get$3(l);
});
} else {
var lossFunction = get$3(args.loss);
this.outputs.forEach(function (_) {
lossFunctions.push(lossFunction);
});
}
this.lossFunctions = lossFunctions;
this.feedOutputNames = [];
this.feedOutputShapes = [];
this.feedLossFns = [];
for (var i = 0; i < this.outputs.length; ++i) {
// TODO(cais): Logic for skipping target(s).
var shape = this.internalOutputShapes[i];
var _name2 = this.outputNames[i];
this.feedOutputNames.push(_name2);
this.feedOutputShapes.push(shape);
this.feedLossFns.push(this.lossFunctions[i]);
} // TODO(cais): Add logic for output masks.
// TODO(cais): Add logic for sample weights.
var skipTargetIndices = []; // Prepare metrics.
this.metrics = args.metrics; // TODO(cais): Add weightedMetrics.
this.metricsNames = ['loss'];
this.metricsTensors = []; // Compute total loss.
// Porting Note: In PyKeras, metrics_tensors are symbolic tensor objects.
// Here, metricsTensors are TypeScript functions. This difference is due
// to the difference in symbolic/imperative property of the backends.
nameScope('loss', function () {
for (var _i = 0; _i < _this2.outputs.length; ++_i) {
if (skipTargetIndices.indexOf(_i) !== -1) {
continue;
} // TODO(cais): Add weightedLoss, sampleWeight and mask.
// The following line should be weightedLoss
var weightedLoss = _this2.lossFunctions[_i];
if (_this2.outputs.length > 1) {
_this2.metricsTensors.push([weightedLoss, _i]);
_this2.metricsNames.push(_this2.outputNames[_i] + '_loss');
}
} // Porting Note: Due to the imperative nature of the backend, we calculate
// the regularizer penalties in the totalLossFunction, instead of here.
});
var nestedMetrics = collectMetrics(args.metrics, this.outputNames); // TODO(cais): Add nestedWeightedMetrics.
/**
* Helper function used in loop below.
*/
var appendMetric = function appendMetric(outputIndex, metricName, metricTensor) {
if (_this2.outputNames.length > 1) {
metricName = _this2.outputNames[outputIndex] + '_' + metricName;
}
_this2.metricsNames.push(metricName);
_this2.metricsTensors.push([metricTensor, outputIndex]);
};
nameScope('metric', function () {
var _loop = function _loop(_i2) {
if (skipTargetIndices.indexOf(_i2) !== -1) {
return "continue";
}
var outputMetrics = nestedMetrics[_i2]; // TODO(cais): Add weights and outputWeightedMetrics.
// TODO(cais): Add optional arg `weights` to the following function.
var handleMetrics = function handleMetrics(metrics) {
var metricNamePrefix = '';
var metricName;
var accFn;
var weightedMetricFn; // TODO(cais): Use 'weights_' for weighted metrics.
for (var _iterator4 = _createForOfIteratorHelperLoose(metrics), _step4; !(_step4 = _iterator4()).done;) {
var metric = _step4.value;
if (typeof metric === 'string' && ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !== -1) {
var outputShape = _this2.internalOutputShapes[_i2];
if (outputShape[outputShape.length - 1] === 1 || _this2.lossFunctions[_i2] === binaryCrossentropy) {
// case: binary accuracy/crossentropy.
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
accFn = binaryAccuracy;
} else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
accFn = binaryCrossentropy$1;
}
} else if (_this2.lossFunctions[_i2] === sparseCategoricalCrossentropy) {
// case: categorical accuracy / crossentropy with sparse
// targets.
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
accFn = sparseCategoricalAccuracy;
} else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
accFn = sparseCategoricalCrossentropy$1;
}
} else {
// case: categorical accuracy / crossentropy.
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
accFn = categoricalAccuracy;
} else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
accFn = categoricalCrossentropy$1;
}
}
var suffix = void 0;
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
suffix = 'acc';
} else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
suffix = 'ce';
} // TODO(cais): Add weighting actually.
weightedMetricFn = accFn;
metricName = metricNamePrefix + suffix;
} else {
var metricFn = get$4(metric); // TODO(cais): Add weighting actually.
weightedMetricFn = metricFn;
metricName = metricNamePrefix + getLossOrMetricName(metric);
} // TODO(cais): Add weighting and masking to metricResult.
var metricResult = void 0;
nameScope(metricName, function () {
metricResult = weightedMetricFn;
});
appendMetric(_i2, metricName, metricResult);
}
};
handleMetrics(outputMetrics); // TODO(cais): Call handleMetrics with weights.
};
for (var _i2 = 0; _i2 < _this2.outputs.length; ++_i2) {
var _ret = _loop(_i2);
if (_ret === "continue") continue;
}
}); // Porting Notes: Given the imperative backend of tfjs-core,
// there is no need for constructing the symbolic graph and placeholders.
this.collectedTrainableWeights = this.trainableWeights;
}
/**
* Check trainable weights count consistency.
*
* This will raise a warning if `this.trainableWeights` and
* `this.collectedTrainableWeights` are inconsistent (i.e., have different
* numbers of parameters).
* Inconsistency will typically arise when one modifies `model.trainable`
* without calling `model.compile()` again.
*/
;
_proto.checkTrainableWeightsConsistency = function checkTrainableWeightsConsistency() {
if (this.collectedTrainableWeights == null) {
return;
}
if (this.trainableWeights.length !== this.collectedTrainableWeights.length) {
      console.warn('Discrepancy between trainable weights and collected trainable ' + 'weights. Did you set `model.trainable` without calling ' + '`model.compile()` afterwards?');
}
}
/**
* Returns the loss value & metrics values for the model in test mode.
*
* Loss and metrics are specified during `compile()`, which needs to happen
* before calls to `evaluate()`.
*
* Computation is done in batches.
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
* const result = model.evaluate(
* tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
* result.print();
* ```
*
* @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
* model has multiple inputs.
* @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
* model has multiple outputs.
* @param args A `ModelEvaluateArgs`, containing optional fields.
*
* @return `Scalar` test loss (if the model has a single output and no
* metrics) or `Array` of `Scalar`s (if the model has multiple outputs
* and/or metrics). The attribute `model.metricsNames`
* will give you the display labels for the scalar outputs.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.evaluate = function evaluate(x, y, args) {
if (args === void 0) {
args = {};
}
var batchSize = args.batchSize == null ? 32 : args.batchSize;
checkBatchSize(batchSize); // TODO(cais): Standardize `config.sampleWeights` as well.
// Validate user data.
var checkBatchAxis = true;
var standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
try {
// TODO(cais): If uses `useLearningPhase`, set the corresponding element
// of the input to 0.
var ins = standardizedOuts[0].concat(standardizedOuts[1]);
this.makeTestFunction();
var f = this.testFunction;
var testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps);
return singletonOrArray(testOuts);
} finally {
disposeNewTensors(standardizedOuts[0], x);
disposeNewTensors(standardizedOuts[1], y);
}
} // TODO(cais): Add code snippet below once real dataset objects are
// available.
/**
* Evaluate model using a dataset object.
*
   * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
*
* @param dataset A dataset object. Its `iterator()` method is expected
* to generate a dataset iterator object, the `next()` method of which
* is expected to produce data batches for evaluation. The return value
* of the `next()` call ought to contain a boolean `done` field and a
* `value` field. The `value` field is expected to be an array of two
* `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
   * case is for models with exactly one input and one output (e.g.,
* a sequential model). The latter case is for models with multiple
* inputs and/or multiple outputs. Of the two items in the array, the
* first is the input feature(s) and the second is the output target(s).
* @param args A configuration object for the dataset-based evaluation.
* @returns Loss and metric values as an Array of `Scalar` objects.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
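  // Usage sketch for `evaluateDataset()`. This assumes the bundled `tf.data`
  // API (`tf.data.array`, `tf.data.zip`) and the `{xs, ys}` element form used
  // in the TensorFlow.js dataset examples:
  //
  //   const model = tf.sequential(
  //       {layers: [tf.layers.dense({units: 1, inputShape: [1]})]});
  //   model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
  //   const xs = tf.data.array([[1], [2], [3], [4]]);
  //   const ys = tf.data.array([1, 3, 5, 7]);
  //   const ds = tf.data.zip({xs: xs, ys: ys}).batch(2);
  //   const result = await model.evaluateDataset(ds, {});
  //   result.print();  // Scalar loss averaged over the whole dataset.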
;
_proto.evaluateDataset =
/*#__PURE__*/
function () {
var _evaluateDataset2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(dataset, args) {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
this.makeTestFunction();
return _context.abrupt("return", evaluateDataset(this, dataset, args));
case 2:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function evaluateDataset$1(_x, _x2) {
return _evaluateDataset2.apply(this, arguments);
}
return evaluateDataset$1;
}()
/**
* Get number of samples provided for training, evaluation or prediction.
*
* @param ins Input `tf.Tensor`.
* @param batchSize Integer batch size, optional.
* @param steps Total number of steps (batches of samples) before
* declaring loop finished. Optional.
* @param stepsName The public API's parameter name for `steps`.
* @returns Number of samples provided.
*/
;
_proto.checkNumSamples = function checkNumSamples(ins, batchSize, steps, stepsName) {
if (stepsName === void 0) {
stepsName = 'steps';
}
var numSamples;
if (steps != null) {
numSamples = null;
if (batchSize != null) {
throw new ValueError("If " + stepsName + " is set, batchSize must be null or undefined." + ("Got batchSize = " + batchSize));
}
} else if (ins != null) {
if (Array.isArray(ins)) {
numSamples = ins[0].shape[0];
} else {
numSamples = ins.shape[0];
}
} else {
throw new ValueError("Either the input data should have a defined shape, or " + (stepsName + " shoud be specified."));
}
return numSamples;
}
/**
* Execute internal tensors of the model with input data feed.
* @param inputs Input data feed. Must match the inputs of the model.
* @param outputs Names of the output tensors to be fetched. Must match
* names of the SymbolicTensors that belong to the graph.
* @returns Fetched values for `outputs`.
*/
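  // Sketch of fetching an intermediate activation with `execute()`. This
  // assumes a functional model built with `tf.input()` / `tf.model()`; the
  // output name is read from the hidden `SymbolicTensor`'s `name` property.
  //
  //   const input = tf.input({shape: [4]});
  //   const hidden = tf.layers.dense({units: 3}).apply(input);
  //   const output = tf.layers.dense({units: 1}).apply(hidden);
  //   const model = tf.model({inputs: input, outputs: output});
  //   // Fetch the value of the hidden symbolic tensor for a batch of inputs.
  //   model.execute(tf.ones([2, 4]), hidden.name).print();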
;
_proto.execute = function execute$1(inputs, outputs) {
if (Array.isArray(outputs) && outputs.length === 0) {
throw new ValueError('`outputs` is an empty Array, which is not allowed.');
}
var outputsIsArray = Array.isArray(outputs);
var outputNames = outputsIsArray ? outputs : [outputs];
var outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames); // Format the input into a FeedDict.
var feedDict = new FeedDict();
if (inputs instanceof Tensor) {
inputs = [inputs];
}
if (Array.isArray(inputs)) {
if (inputs.length !== this.inputs.length) {
throw new ValueError("The number of inputs provided (" + inputs.length + ") " + "does not match the number of inputs of this model " + ("(" + this.inputs.length + ")."));
}
for (var i = 0; i < this.inputs.length; ++i) {
feedDict.add(this.inputs[i], inputs[i]);
}
} else {
for (var _iterator5 = _createForOfIteratorHelperLoose(this.inputs), _step5; !(_step5 = _iterator5()).done;) {
var input = _step5.value;
var tensorValue = inputs[input.name];
if (tensorValue == null) {
throw new ValueError("No value is provided for the model's input " + input.name);
}
feedDict.add(input, tensorValue);
}
} // Run execution.
var executeOutputs = execute(outputSymbolicTensors, feedDict);
return outputsIsArray ? executeOutputs : executeOutputs[0];
}
/**
* Retrieve the model's internal symbolic tensors from symbolic-tensor names.
*/
;
_proto.retrieveSymbolicTensors = function retrieveSymbolicTensors(symbolicTensorNames) {
var outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length);
var outputsRemaining = symbolicTensorNames.length;
for (var _iterator6 = _createForOfIteratorHelperLoose(this.layers), _step6; !(_step6 = _iterator6()).done;) {
var layer = _step6.value;
var layerOutputs = Array.isArray(layer.output) ? layer.output : [layer.output];
var layerOutputNames = layerOutputs.map(function (output) {
return output.name;
});
for (var i = 0; i < symbolicTensorNames.length; ++i) {
var index = layerOutputNames.indexOf(symbolicTensorNames[i]);
if (index !== -1) {
outputSymbolicTensors[i] = layerOutputs[index];
outputsRemaining--;
}
if (outputsRemaining === 0) {
break;
}
}
if (outputsRemaining === 0) {
break;
}
}
if (outputsRemaining > 0) {
var remainingNames = [];
outputSymbolicTensors.forEach(function (tensor, i) {
if (tensor == null) {
remainingNames.push(symbolicTensorNames[i]);
}
});
throw new ValueError("Cannot find SymbolicTensors for output name(s): " + ("" + JSON.stringify(remainingNames)));
}
return outputSymbolicTensors;
}
/**
* Helper method to loop over some data in batches.
*
* Porting Note: Not using the functional approach in the Python equivalent
* due to the imperative backend.
* Porting Note: Does not support step mode currently.
*
* @param ins: input data
* @param batchSize: integer batch size.
   * @param verbose: verbosity mode.
   * @returns: Predictions as `tf.Tensor` (if a single output) or an `Array` of
   *   `tf.Tensor` (if multiple outputs).
*/
;
_proto.predictLoop = function predictLoop(ins, batchSize, verbose) {
var _this3 = this;
if (batchSize === void 0) {
batchSize = 32;
}
if (verbose === void 0) {
verbose = false;
}
return tidy(function () {
var numSamples = _this3.checkNumSamples(ins);
if (verbose) {
throw new NotImplementedError('Verbose predictLoop() is not implemented yet.');
} // Sample-based predictions.
// Porting Note: Tensor currently does not support sliced assignments as
// in numpy, e.g., x[1:3] = y. Therefore we use concatenation while
// iterating over the batches.
var batches = makeBatches(numSamples, batchSize);
var outsBatches = _this3.outputs.map(function (output) {
return [];
}); // TODO(cais): Can the scope() be pushed down inside the for loop?
var _loop2 = function _loop2(batchIndex) {
var batchOuts = tidy(function () {
var batchStart = batches[batchIndex][0];
          var batchEnd = batches[batchIndex][1]; // TODO(cais): Take care of the case where the last element is a flag for
// training/test.
var insBatch = sliceArrays(ins, batchStart, batchEnd); // Construct the feeds for execute();
var feeds = [];
if (Array.isArray(insBatch)) {
for (var i = 0; i < insBatch.length; ++i) {
feeds.push({
key: _this3.inputs[i],
value: insBatch[i]
});
}
} else {
feeds.push({
key: _this3.inputs[0],
value: insBatch
});
}
var feedDict = new FeedDict(feeds);
return execute(_this3.outputs, feedDict);
});
batchOuts.forEach(function (batchOut, i) {
return outsBatches[i].push(batchOut);
});
};
for (var batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
_loop2(batchIndex);
}
return singletonOrArray(outsBatches.map(function (batches) {
return concat(batches, 0);
}));
});
}
/**
* Generates output predictions for the input samples.
*
* Computation is done in batches.
*
* Note: the "step" mode of predict() is currently not supported.
* This is because the TensorFlow.js core backend is imperative only.
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.predict(tf.ones([8, 10]), {batchSize: 4}).print();
* ```
*
* @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
* the model has multiple inputs.
* @param args A `ModelPredictArgs` object containing optional fields.
*
* @return Prediction results as a `tf.Tensor`(s).
*
* @exception ValueError In case of mismatch between the provided input data
* and the model's expectations, or in case a stateful model receives a
* number of samples that is not a multiple of the batch size.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.predict = function predict(x, args) {
if (args === void 0) {
args = {};
}
var xsRank2OrHigher = ensureTensorsRank2OrHigher(x);
checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false);
try {
// TODO(cais): Take care of stateful models.
// if (this.stateful) ...
// TODO(cais): Take care of the learning_phase boolean flag.
// if (this.useLearningPhase) ...
var batchSize = args.batchSize == null ? 32 : args.batchSize;
checkBatchSize(batchSize);
return this.predictLoop(xsRank2OrHigher, batchSize);
} finally {
disposeNewTensors(xsRank2OrHigher, x);
}
}
/**
* Returns predictions for a single batch of samples.
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.predictOnBatch(tf.ones([8, 10])).print();
* ```
* @param x: Input samples, as a Tensor (for models with exactly one
* input) or an array of Tensors (for models with more than one input).
* @return Tensor(s) of predictions
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.predictOnBatch = function predictOnBatch(x) {
checkInputData(x, this.inputNames, this.feedInputShapes, true); // TODO(cais): Take care of the learning_phase boolean flag.
// if (this.useLearningPhase) ...
var batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
return this.predictLoop(x, batchSize);
};
_proto.standardizeUserDataXY = function standardizeUserDataXY(x, y, checkBatchAxis, batchSize) {
if (checkBatchAxis === void 0) {
checkBatchAxis = true;
}
// TODO(cais): Add sampleWeight, classWeight
if (this.optimizer_ == null) {
throw new RuntimeError('You must compile a model before training/testing. Use ' + 'LayersModel.compile(modelCompileArgs).');
}
var outputShapes = [];
for (var i = 0; i < this.feedOutputShapes.length; ++i) {
var outputShape = this.feedOutputShapes[i];
var lossFn = this.feedLossFns[i];
if (lossFn === sparseCategoricalCrossentropy) {
outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1]));
} else {
// Porting Note: Because of strong typing `lossFn` must be a function.
outputShapes.push(outputShape);
}
}
x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input');
y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target'); // TODO(cais): Standardize sampleWeights & classWeights.
checkArrayLengths(x, y, null); // TODO(cais): Check sampleWeights as well.
checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes);
if (this.stateful && batchSize != null && batchSize > 0) {
if (x[0].shape[0] % batchSize !== 0) {
throw new ValueError("In a stateful network, you should only pass inputs with a " + "number of samples that is divisible by the batch size " + (batchSize + ". Found: " + x[0].shape[0] + " sample(s)."));
}
}
return [x, y];
};
_proto.standardizeUserData = /*#__PURE__*/function () {
var _standardizeUserData = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(x, y, sampleWeight, classWeight, checkBatchAxis, batchSize) {
var _this$standardizeUser, standardXs, standardYs, standardSampleWeights, classWeights, i;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
if (checkBatchAxis === void 0) {
checkBatchAxis = true;
}
_this$standardizeUser = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize), standardXs = _this$standardizeUser[0], standardYs = _this$standardizeUser[1]; // TODO(cais): Handle sampleWeights.
if (!(sampleWeight != null)) {
_context2.next = 4;
break;
}
throw new Error('sample weight is not supported yet.');
case 4:
standardSampleWeights = null;
if (!(classWeight != null)) {
_context2.next = 18;
break;
}
classWeights = standardizeClassWeights(classWeight, this.outputNames);
standardSampleWeights = [];
i = 0;
case 9:
if (!(i < classWeights.length)) {
_context2.next = 18;
break;
}
_context2.t0 = standardSampleWeights;
_context2.next = 13;
return standardizeWeights(standardYs[i], null, classWeights[i]);
case 13:
_context2.t1 = _context2.sent;
_context2.t0.push.call(_context2.t0, _context2.t1);
case 15:
++i;
_context2.next = 9;
break;
case 18:
return _context2.abrupt("return", [standardXs, standardYs, standardSampleWeights]);
case 19:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function standardizeUserData(_x3, _x4, _x5, _x6, _x7, _x8) {
return _standardizeUserData.apply(this, arguments);
}
return standardizeUserData;
}()
/**
* Loop over some test data in batches.
* @param f A Function returning a list of tensors.
* @param ins Array of tensors to be fed to `f`.
* @param batchSize Integer batch size or `null` / `undefined`.
* @param verbose verbosity mode.
* @param steps Total number of steps (batches of samples) before
* declaring test finished. Ignored with the default value of `null` /
* `undefined`.
* @returns Array of Scalars.
*/
;
_proto.testLoop = function testLoop(f, ins, batchSize, verbose, steps) {
var _this4 = this;
if (verbose === void 0) {
verbose = 0;
}
return tidy(function () {
var numSamples = _this4.checkNumSamples(ins, batchSize, steps, 'steps');
var outs = [];
if (verbose > 0) {
throw new NotImplementedError('Verbose mode is not implemented yet.');
} // TODO(cais): Use `indicesForConversionToDense' to prevent slow down.
if (steps != null) {
throw new NotImplementedError('steps mode in testLoop() is not implemented yet');
} else {
var batches = makeBatches(numSamples, batchSize);
var indexArray = tensor1d(range$1(0, numSamples));
for (var batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
var batchStart = batches[batchIndex][0];
var batchEnd = batches[batchIndex][1];
          var batchIds = sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart); // TODO(cais): In ins, the train flag can be a number, instead of a
// Tensor? Do we need to handle this in tfjs-layers?
var insBatch = sliceArraysByIndices(ins, batchIds);
var batchOuts = f(insBatch);
if (batchIndex === 0) {
for (var i = 0; i < batchOuts.length; ++i) {
outs.push(scalar(0));
}
}
for (var _i3 = 0; _i3 < batchOuts.length; ++_i3) {
var batchOut = batchOuts[_i3];
outs[_i3] = add$1(outs[_i3], mul(batchEnd - batchStart, batchOut));
}
}
for (var _i4 = 0; _i4 < outs.length; ++_i4) {
outs[_i4] = div(outs[_i4], numSamples);
}
}
return outs;
});
};
_proto.getDedupedMetricsNames = function getDedupedMetricsNames() {
var outLabels = this.metricsNames; // Rename duplicated metrics names (can happen with an output layer
// shared among multiple dataflows).
var dedupedOutLabels = [];
for (var i = 0; i < outLabels.length; ++i) {
var label = outLabels[i];
var newLabel = label;
if (count(outLabels, label) > 1) {
var dupIndex = count(outLabels.slice(0, i), label);
newLabel += "_" + dupIndex;
}
dedupedOutLabels.push(newLabel);
}
return dedupedOutLabels;
}
/**
* Creates a function that performs the following actions:
*
* 1. computes the losses
* 2. sums them to get the total loss
   * 3. calls the optimizer to compute the gradients of the LayersModel's
   *    trainable weights w.r.t. the total loss and updates the variables
* 4. calculates the metrics
* 5. returns the values of the losses and metrics.
*/
;
_proto.makeTrainFunction = function makeTrainFunction() {
var _this5 = this;
return function (data) {
var lossValues = [];
var inputs = data.slice(0, _this5.inputs.length);
var targets = data.slice(_this5.inputs.length, _this5.inputs.length + _this5.outputs.length);
var sampleWeights = data.slice(_this5.inputs.length + _this5.outputs.length, _this5.inputs.length + _this5.outputs.length * 2);
var metricsValues = []; // Create a function that computes the total loss based on the
// inputs. This function is used for obtaining gradients through
// backprop.
var totalLossFunction = function totalLossFunction() {
var feeds = [];
for (var i = 0; i < _this5.inputs.length; ++i) {
feeds.push({
key: _this5.inputs[i],
value: inputs[i]
});
}
var feedDict = new FeedDict(feeds);
var outputs = execute(_this5.outputs, feedDict, {
'training': true
}); // TODO(cais): Take care of the case of multiple outputs from a
// single layer?
var totalLoss;
for (var _i5 = 0; _i5 < _this5.lossFunctions.length; ++_i5) {
var lossFunction = _this5.lossFunctions[_i5];
var loss = lossFunction(targets[_i5], outputs[_i5]);
if (sampleWeights[_i5] != null) {
loss = computeWeightedLoss$1(loss, sampleWeights[_i5]);
} // TODO(cais): push Scalar instead.
var meanLoss = mean(loss); // TODO(cais): Use a scope() instead, to avoid ownership.
lossValues.push(meanLoss);
if (_i5 === 0) {
totalLoss = loss;
} else {
totalLoss = add$1(totalLoss, loss);
}
} // Compute the metrics.
// TODO(cais): These should probably be calculated outside
// totalLossFunction to benefit speed?
for (var _i6 = 0; _i6 < _this5.metricsTensors.length; ++_i6) {
var weightedMetric = void 0;
if (_this5.outputs.length > 1 && _i6 < _this5.outputs.length) {
weightedMetric = lossValues[_i6];
} else {
var metric = _this5.metricsTensors[_i6][0];
var outputIndex = _this5.metricsTensors[_i6][1];
weightedMetric = mean(metric(targets[outputIndex], outputs[outputIndex]));
}
keep(weightedMetric); // TODO(cais): Use a scope() instead, to avoid ownership.
metricsValues.push(weightedMetric);
}
totalLoss = mean(totalLoss); // Add regularizer penalties.
_this5.calculateLosses().forEach(function (regularizerLoss) {
totalLoss = add$1(totalLoss, regularizerLoss);
});
return totalLoss;
};
var variables = _this5.collectedTrainableWeights.map(function (param) {
return param.read();
});
var returnCost = true;
var totalLossValue = _this5.optimizer_.minimize(totalLossFunction, returnCost, variables);
return [totalLossValue].concat(metricsValues);
};
}
/**
* Create a function which, when invoked with an array of `tf.Tensor`s as a
* batch of inputs, returns the prespecified loss and metrics of the model
* under the batch of input data.
*/
;
_proto.makeTestFunction = function makeTestFunction() {
var _this6 = this;
this.testFunction = function (data) {
return tidy(function () {
var valOutputs = [];
var totalLoss;
var inputs = data.slice(0, _this6.inputs.length);
var targets = data.slice(_this6.inputs.length, _this6.inputs.length + _this6.outputs.length);
var feeds = [];
for (var i = 0; i < _this6.inputs.length; ++i) {
feeds.push({
key: _this6.inputs[i],
value: inputs[i]
});
}
var feedDict = new FeedDict(feeds);
var outputs = execute(_this6.outputs, feedDict); // Compute total loss.
for (var _i7 = 0; _i7 < _this6.lossFunctions.length; ++_i7) {
var lossFunction = _this6.lossFunctions[_i7]; // TODO(cais): Add sample weighting and replace the simple
// averaging.
var loss = mean(lossFunction(targets[_i7], outputs[_i7]));
if (_i7 === 0) {
totalLoss = loss;
} else {
totalLoss = add$1(totalLoss, loss);
}
valOutputs.push(totalLoss);
} // Compute the metrics.
for (var _i8 = 0; _i8 < _this6.metricsTensors.length; ++_i8) {
var metric = _this6.metricsTensors[_i8][0];
var outputIndex = _this6.metricsTensors[_i8][1]; // TODO(cais): Replace K.mean() with a proper weighting function.
var meanMetric = mean(metric(targets[outputIndex], outputs[outputIndex]));
valOutputs.push(meanMetric);
}
return valOutputs;
});
};
}
/**
* Trains the model for a fixed number of epochs (iterations on a
* dataset).
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
* for (let i = 1; i < 5 ; ++i) {
* const h = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
* batchSize: 4,
* epochs: 3
* });
* console.log("Loss after Epoch " + i + " : " + h.history.loss[0]);
* }
* ```
*
* @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
* model has multiple inputs. If all inputs in the model are named, you
* can also pass a dictionary mapping input names to `tf.Tensor`s.
* @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
* the model has multiple outputs. If all outputs in the model are named,
* you can also pass a dictionary mapping output names to `tf.Tensor`s.
* @param args A `ModelFitArgs`, containing optional fields.
*
* @return A `History` instance. Its `history` attribute contains all
* information collected during training.
*
* @exception ValueError In case of mismatch between the provided input
* data and what the model expects.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.fit =
/*#__PURE__*/
function () {
var _fit = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(x, y, args) {
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
if (args === void 0) {
args = {};
}
return _context3.abrupt("return", fitTensors(this, x, y, args));
case 2:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function fit(_x9, _x10, _x11) {
return _fit.apply(this, arguments);
}
return fit;
}() // TODO(cais): Add code snippet below when it's possible to instantiate
// actual dataset objects.
/**
* Trains the model using a dataset object.
*
* @param dataset A dataset object. Its `iterator()` method is expected
* to generate a dataset iterator object, the `next()` method of which
* is expected to produce data batches for training. The return value
* of the `next()` call ought to contain a boolean `done` field and a
* `value` field. The `value` field is expected to be an array of two
* `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
   * case is for models with exactly one input and one output (e.g.,
* a sequential model). The latter case is for models with multiple
* inputs and/or multiple outputs.
* Of the two items in the array, the first is the input feature(s) and
* the second is the output target(s).
* @param args A `ModelFitDatasetArgs`, containing optional fields.
*
* @return A `History` instance. Its `history` attribute contains all
* information collected during training.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
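  // Usage sketch for `fitDataset()`, assuming the bundled `tf.data` API and
  // the `{xs, ys}` element form used in the TensorFlow.js dataset examples:
  //
  //   const xDataset = tf.data.array([[1], [2], [3], [4]]);
  //   const yDataset = tf.data.array([1, 3, 5, 7]);
  //   const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset}).batch(4);
  //   const model = tf.sequential(
  //       {layers: [tf.layers.dense({units: 1, inputShape: [1]})]});
  //   model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
  //   const history = await model.fitDataset(xyDataset, {epochs: 4});
  //   console.log(history.history.loss);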
;
_proto.fitDataset =
/*#__PURE__*/
function () {
var _fitDataset2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(dataset, args) {
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
return _context4.abrupt("return", fitDataset(this, dataset, args));
case 1:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function fitDataset$1(_x12, _x13) {
return _fitDataset2.apply(this, arguments);
}
return fitDataset$1;
}()
/**
* Runs a single gradient update on a single batch of data.
*
* This method differs from `fit()` and `fitDataset()` in the following
* regards:
* - It operates on exactly one batch of data.
   *   - It returns only the loss and metric values, instead of
* returning the batch-by-batch loss and metric values.
* - It doesn't support fine-grained options such as verbosity and
* callbacks.
*
* @param x Input data. It could be one of the following:
* - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
* multiple inputs).
* - An Object mapping input names to corresponding `tf.Tensor` (if the
* model has named inputs).
   * @param y Target data. It could be either a `tf.Tensor` or an Array of
   *   multiple `tf.Tensor`s. It should be consistent with `x`.
* @returns Training loss or losses (in case the model has
* multiple outputs), along with metrics (if any), as numbers.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
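  // Usage sketch for `trainOnBatch()`: one gradient update on a single batch.
  // With a single-output model and no metrics, the resolved value is a single
  // number rather than a Tensor.
  //
  //   const model = tf.sequential(
  //       {layers: [tf.layers.dense({units: 1, inputShape: [10]})]});
  //   model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
  //   const loss = await model.trainOnBatch(tf.ones([8, 10]), tf.ones([8, 1]));
  //   console.log(loss);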
;
_proto.trainOnBatch =
/*#__PURE__*/
function () {
var _trainOnBatch = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee5(x, y) {
var standardizeOut, inputs, targets, trainFunction, losses, lossValues, _iterator7, _step7, loss, v;
return regeneratorRuntime.wrap(function _callee5$(_context5) {
while (1) {
switch (_context5.prev = _context5.next) {
case 0:
_context5.next = 2;
return this.standardizeUserData(x, y);
case 2:
standardizeOut = _context5.sent;
inputs = standardizeOut[0];
targets = standardizeOut[1];
trainFunction = this.makeTrainFunction();
losses = trainFunction(inputs.concat(targets));
lossValues = [];
_iterator7 = _createForOfIteratorHelperLoose(losses);
case 9:
if ((_step7 = _iterator7()).done) {
_context5.next = 17;
break;
}
loss = _step7.value;
_context5.next = 13;
return loss.data();
case 13:
v = _context5.sent;
lossValues.push(v[0]);
case 15:
_context5.next = 9;
break;
case 17:
dispose(losses);
return _context5.abrupt("return", singletonOrArray(lossValues));
case 19:
case "end":
return _context5.stop();
}
}
}, _callee5, this);
}));
function trainOnBatch(_x14, _x15) {
return _trainOnBatch.apply(this, arguments);
}
return trainOnBatch;
}()
/**
* Extract weight values of the model.
*
* @param config: An instance of `io.SaveConfig`, which specifies
* model-saving options such as whether only trainable weights are to be
* saved.
   * @returns An `Array` of `{name, tensor}` entries, in which `name` is the
   *   original (i.e., non-uniqueified) weight name and `tensor` is its value.
*/
;
_proto.getNamedWeights = function getNamedWeights(config) {
var namedWeights = [];
var trainableOnly = config != null && config.trainableOnly;
var weights = trainableOnly ? this.trainableWeights : this.weights;
var weightValues = this.getWeights(trainableOnly);
for (var i = 0; i < weights.length; ++i) {
if (trainableOnly && !weights[i].trainable) {
// Optionally skip non-trainable weights.
continue;
}
namedWeights.push({
name: weights[i].originalName,
tensor: weightValues[i]
});
}
return namedWeights;
}
/**
* Setter used for force stopping of LayersModel.fit() (i.e., training).
*
* Example:
*
* ```js
* const input = tf.input({shape: [10]});
* const output = tf.layers.dense({units: 1}).apply(input);
* const model = tf.model({inputs: [input], outputs: [output]});
* model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
* const xs = tf.ones([8, 10]);
* const ys = tf.zeros([8, 1]);
*
* const history = await model.fit(xs, ys, {
* epochs: 10,
* callbacks: {
* onEpochEnd: async (epoch, logs) => {
* if (epoch === 2) {
* model.stopTraining = true;
* }
* }
* }
* });
*
   * // There should be only 3 values in the loss array, instead of 10,
   * // due to the stopping after 3 epochs.
* console.log(history.history.loss);
* ```
*/
;
_proto.dispose = function dispose() {
var result = _Container.prototype.dispose.call(this);
if (result.refCountAfterDispose === 0 && this.optimizer != null && this.isOptimizerOwned) {
      var numTensorsBeforeOptimizerDisposal = memory().numTensors;
      this.optimizer_.dispose();
      result.numDisposedVariables += numTensorsBeforeOptimizerDisposal - memory().numTensors;
}
return result;
};
_proto.getLossIdentifiers = function getLossIdentifiers() {
var lossNames;
if (typeof this.loss === 'string') {
lossNames = toSnakeCase(this.loss);
} else if (Array.isArray(this.loss)) {
for (var _iterator8 = _createForOfIteratorHelperLoose(this.loss), _step8; !(_step8 = _iterator8()).done;) {
var loss = _step8.value;
if (typeof loss !== 'string') {
throw new Error('Serialization of non-string loss is not supported.');
}
}
lossNames = this.loss.map(function (name) {
return toSnakeCase(name);
});
} else {
var outputNames = Object.keys(this.loss);
lossNames = {};
var _losses = this.loss;
for (var _i9 = 0, _outputNames = outputNames; _i9 < _outputNames.length; _i9++) {
var outputName = _outputNames[_i9];
if (typeof _losses[outputName] === 'string') {
lossNames[outputName] = toSnakeCase(_losses[outputName]);
} else {
throw new Error('Serialization of non-string loss is not supported.');
}
}
}
return lossNames;
};
_proto.getMetricIdentifiers = function getMetricIdentifiers() {
if (typeof this.metrics === 'string' || typeof this.metrics === 'function') {
return [toSnakeCase(getLossOrMetricName(this.metrics))];
} else if (Array.isArray(this.metrics)) {
return this.metrics.map(function (metric) {
return toSnakeCase(getLossOrMetricName(metric));
});
} else {
var metricsIdentifiers = {};
for (var key in this.metrics) {
metricsIdentifiers[key] = toSnakeCase(getLossOrMetricName(this.metrics[key]));
}
return metricsIdentifiers;
}
};
_proto.getTrainingConfig = function getTrainingConfig() {
return {
loss: this.getLossIdentifiers(),
metrics: this.getMetricIdentifiers(),
optimizer_config: {
class_name: this.optimizer.getClassName(),
config: this.optimizer.getConfig()
}
}; // TODO(cais): Add weight_metrics when they are supported.
// TODO(cais): Add sample_weight_mode when it's supported.
// TODO(cais): Add loss_weights when it's supported.
};
_proto.loadTrainingConfig = function loadTrainingConfig(trainingConfig) {
if (trainingConfig.weighted_metrics != null) {
throw new Error('Loading weight_metrics is not supported yet.');
}
if (trainingConfig.loss_weights != null) {
throw new Error('Loading loss_weights is not supported yet.');
}
if (trainingConfig.sample_weight_mode != null) {
throw new Error('Loading sample_weight_mode is not supported yet.');
}
var tsConfig = convertPythonicToTs(trainingConfig.optimizer_config);
var optimizer = deserialize$1(tsConfig);
var loss;
if (typeof trainingConfig.loss === 'string') {
loss = toCamelCase(trainingConfig.loss);
} else if (Array.isArray(trainingConfig.loss)) {
loss = trainingConfig.loss.map(function (lossEntry) {
return toCamelCase(lossEntry);
});
} else if (trainingConfig.loss != null) {
loss = {};
for (var key in trainingConfig.loss) {
loss[key] = toCamelCase(trainingConfig.loss[key]);
}
}
var metrics;
if (Array.isArray(trainingConfig.metrics)) {
metrics = trainingConfig.metrics.map(function (metric) {
return toCamelCase(metric);
});
} else if (trainingConfig.metrics != null) {
metrics = {};
for (var _key in trainingConfig.metrics) {
metrics[_key] = toCamelCase(trainingConfig.metrics[_key]);
}
}
this.compile({
loss: loss,
metrics: metrics,
optimizer: optimizer
});
}
/**
* Save the configuration and/or weights of the LayersModel.
*
* An `IOHandler` is an object that has a `save` method of the proper
* signature defined. The `save` method manages the storing or
* transmission of serialized data ("artifacts") that represent the
* model's topology and weights onto or via a specific medium, such as
* file downloads, local storage, IndexedDB in the web browser and HTTP
* requests to a server. TensorFlow.js provides `IOHandler`
* implementations for a number of frequently used saving mediums, such as
* `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
* for more details.
*
* This method also allows you to refer to certain types of `IOHandler`s
* as URL-like string shortcuts, such as 'localstorage://' and
* 'indexeddb://'.
*
* Example 1: Save `model`'s topology and weights to browser [local
* storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
* then load it back.
*
* ```js
* const model = tf.sequential(
* {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
* console.log('Prediction from original model:');
* model.predict(tf.ones([1, 3])).print();
*
* const saveResults = await model.save('localstorage://my-model-1');
*
* const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
* console.log('Prediction from loaded model:');
* loadedModel.predict(tf.ones([1, 3])).print();
* ```
*
* Example 2. Saving `model`'s topology and weights to browser
* [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
* then load it back.
*
* ```js
* const model = tf.sequential(
* {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
* console.log('Prediction from original model:');
* model.predict(tf.ones([1, 3])).print();
*
* const saveResults = await model.save('indexeddb://my-model-1');
*
* const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
* console.log('Prediction from loaded model:');
* loadedModel.predict(tf.ones([1, 3])).print();
* ```
*
* Example 3. Saving `model`'s topology and weights as two files
* (`my-model-1.json` and `my-model-1.weights.bin`) downloaded from
* browser.
*
* ```js
* const model = tf.sequential(
* {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
* const saveResults = await model.save('downloads://my-model-1');
* ```
*
* Example 4. Send `model`'s topology and weights to an HTTP server.
* See the documentation of `tf.io.http` for more details
* including specifying request parameters and implementation of the
* server.
*
* ```js
* const model = tf.sequential(
* {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
* const saveResults = await model.save('http://my-server/model/upload');
* ```
*
* @param handlerOrURL An instance of `IOHandler` or a URL-like,
* scheme-based string shortcut for `IOHandler`.
* @param config Options for saving the model.
* @returns A `Promise` of `SaveResult`, which summarizes the result of
* the saving, such as byte sizes of the saved artifacts for the model's
* topology and weight values.
*
* @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
*/
;
_proto.save =
/*#__PURE__*/
function () {
var _save = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee6(handlerOrURL, config) {
var handlers, weightDataAndSpecs, returnString, unusedArg, modelConfig, modelArtifacts, includeOptimizer, _weightDataAndSpecs$s, weightType, _yield$io$encodeWeigh, optimizerWeightData, optimizerWeightSpecs, checkSize;
return regeneratorRuntime.wrap(function _callee6$(_context6) {
while (1) {
switch (_context6.prev = _context6.next) {
case 0:
if (!(typeof handlerOrURL === 'string')) {
_context6.next = 9;
break;
}
handlers = getSaveHandlers(handlerOrURL);
if (!(handlers.length === 0)) {
_context6.next = 6;
break;
}
throw new ValueError("Cannot find any save handlers for URL '" + handlerOrURL + "'");
case 6:
if (!(handlers.length > 1)) {
_context6.next = 8;
break;
}
throw new ValueError("Found more than one (" + handlers.length + ") save handlers for " + ("URL '" + handlerOrURL + "'"));
case 8:
handlerOrURL = handlers[0];
case 9:
if (!(handlerOrURL.save == null)) {
_context6.next = 11;
break;
}
throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' + 'provided does not have the `save` attribute defined.');
case 11:
_context6.next = 13;
return encodeWeights(this.getNamedWeights(config));
case 13:
weightDataAndSpecs = _context6.sent;
returnString = false;
unusedArg = null;
modelConfig = this.toJSON(unusedArg, returnString);
modelArtifacts = {
modelTopology: modelConfig,
format: LAYERS_MODEL_FORMAT_NAME,
generatedBy: "TensorFlow.js tfjs-layers v" + version$2,
convertedBy: null
};
includeOptimizer = config == null ? false : config.includeOptimizer;
if (!(includeOptimizer && this.optimizer != null)) {
_context6.next = 34;
break;
}
modelArtifacts.trainingConfig = this.getTrainingConfig();
weightType = 'optimizer';
_context6.t0 = io;
_context6.next = 25;
return this.optimizer.getWeights();
case 25:
_context6.t1 = _context6.sent;
_context6.t2 = weightType;
_context6.next = 29;
return _context6.t0.encodeWeights.call(_context6.t0, _context6.t1, _context6.t2);
case 29:
_yield$io$encodeWeigh = _context6.sent;
optimizerWeightData = _yield$io$encodeWeigh.data;
optimizerWeightSpecs = _yield$io$encodeWeigh.specs;
(_weightDataAndSpecs$s = weightDataAndSpecs.specs).push.apply(_weightDataAndSpecs$s, optimizerWeightSpecs);
weightDataAndSpecs.data = concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]);
case 34:
if (this.userDefinedMetadata != null) {
// Check serialized size of user-defined metadata.
checkSize = true;
checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
}
modelArtifacts.weightData = weightDataAndSpecs.data;
modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
return _context6.abrupt("return", handlerOrURL.save(modelArtifacts));
case 38:
case "end":
return _context6.stop();
}
}
}, _callee6, this);
}));
function save(_x16, _x17) {
return _save.apply(this, arguments);
}
return save;
}()
/**
* Set user-defined metadata.
*
* The set metadata will be serialized together with the topology
* and weights of the model during `save()` calls.
*
   * @param userDefinedMetadata The user-defined metadata to be set.
*/
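  // Usage sketch: attach a plain, JSON-serializable object; it is written into
  // the model artifacts by the next `save()` call.
  //
  //   const model = tf.sequential(
  //       {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
  //   model.setUserDefinedMetadata({language: 'en', threshold: 0.5});
  //   console.log(model.getUserDefinedMetadata());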
;
_proto.setUserDefinedMetadata = function setUserDefinedMetadata(userDefinedMetadata) {
checkUserDefinedMetadata(userDefinedMetadata, this.name);
this.userDefinedMetadata = userDefinedMetadata;
}
/**
* Get user-defined metadata.
*
* The metadata is supplied via one of the two routes:
* 1. By calling `setUserDefinedMetadata()`.
* 2. Loaded during model loading (if the model is constructed
* via `tf.loadLayersModel()`.)
*
* If no user-defined metadata is available from either of the
* two routes, this function will return `undefined`.
*/
;
_proto.getUserDefinedMetadata = function getUserDefinedMetadata() {
return this.userDefinedMetadata;
};
_createClass(LayersModel, [{
key: "stopTraining",
get: function get() {
return this.stopTraining_;
},
set: function set(stop) {
this.stopTraining_ = stop;
}
}, {
key: "optimizer",
get: function get() {
return this.optimizer_;
},
set: function set(optimizer) {
if (this.optimizer_ !== optimizer) {
this.optimizer_ = optimizer;
this.isOptimizerOwned = false;
}
}
}]);
return LayersModel;
}(Container); // The class name is 'Model' rather than 'LayersModel' for backwards
// compatibility since this class name shows up in the serialization format.
/** @nocollapse */
LayersModel.className = 'Model';
registerClass(LayersModel);
/**
* A `tf.Functional` is an alias to `tf.LayersModel`.
*
* See also:
* `tf.LayersModel`, `tf.Sequential`, `tf.loadLayersModel`.
*/
/** @doc {heading: 'Models', subheading: 'Classes'} */
var Functional = /*#__PURE__*/function (_LayersModel) {
_inheritsLoose(Functional, _LayersModel);
function Functional() {
return _LayersModel.apply(this, arguments) || this;
}
return Functional;
}(LayersModel);
Functional.className = 'Functional';
registerClass(Functional);
/**
* Parses a JSON model configuration file and returns a model instance.
*
* ```js
* // This example shows how to serialize a model using `toJSON()` and
* // deserialize it as another model using `tf.models.modelFromJSON()`.
* // Note: this example serializes and deserializes only the topology
* // of the model; the weights of the loaded model will be different
   * // from those of the original model, due to random weight
* // initialization.
* // To load the topology and weights of a model, use `tf.loadLayersModel()`.
* const model1 = tf.sequential();
* model1.add(tf.layers.repeatVector({inputShape: [2], n: 4}));
* // Serialize `model1` as a JSON object.
* const model1JSON = model1.toJSON(null, false);
* model1.summary();
*
* const model2 = await tf.models.modelFromJSON(model1JSON);
* model2.summary();
* ```
*
* @param modelAndWeightsConfig JSON object or string encoding a model and
* weights configuration. It can also be only the topology JSON of the
* model, in which case the weights will not be loaded.
* @param custom_objects Optional dictionary mapping names
* (strings) to custom classes or functions to be
* considered during deserialization.
* @returns A TensorFlow.js Layers `tf.LayersModel` instance (uncompiled).
*/
function modelFromJSON(_x, _x2) {
return _modelFromJSON.apply(this, arguments);
}
/**
* Load a model, including its topology and optionally weights. See the
* Tutorial named "How to import a Keras Model" for usage examples.
*
* Example 1: Save `model`'s topology and weights to browser [local
* storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
* then load it back.
*
* ```js
* const model = tf.sequential(
* {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
* console.log('Prediction from original model:');
* model.predict(tf.ones([1, 3])).print();
*
* const saveResults = await model.save('localstorage://my-model-1');
*
* const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
* console.log('Prediction from loaded model:');
* loadedModel.predict(tf.ones([1, 3])).print();
* ```
*
* Example 2. Saving `model`'s topology and weights to browser
* [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
* then load it back.
*
* ```js
* const model = tf.sequential(
* {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
* console.log('Prediction from original model:');
* model.predict(tf.ones([1, 3])).print();
*
* const saveResults = await model.save('indexeddb://my-model-1');
*
* const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
* console.log('Prediction from loaded model:');
* loadedModel.predict(tf.ones([1, 3])).print();
* ```
*
* Example 3. Load a model from user-selected files from HTML
* [file input
* elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file).
*
* ```js
* // Note: this code snippet will not work without the HTML elements in the
* // page
* const jsonUpload = document.getElementById('json-upload');
* const weightsUpload = document.getElementById('weights-upload');
*
* const model = await tf.loadLayersModel(
* tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]]));
* ```
*
* Example 4. Load a model from an HTTP server.
*
* ```js
* const model = await
* tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
* model.summary();
* ```
*
* @param pathOrIOHandler Can be either of the two formats
* 1. A string path to the `ModelAndWeightsConfig` JSON describing
* the model in the canonical TensorFlow.js format. This path will be
* interpreted as a relative HTTP path, to which `fetch` will be used to
* request the model topology and weight manifest JSON.
* The content of the JSON file is assumed to be a JSON object with the
* following fields and values:
* - 'modelTopology': A JSON object that can be either of:
* 1. a model architecture JSON consistent with the format of the return
* value of `keras.Model.to_json()`
* 2. a full model JSON in the format of `keras.models.save_model()`.
* - 'weightsManifest': A TensorFlow.js weights manifest.
* See the Python converter function `save_model()` for more details.
* It is also assumed that model weights can be accessed from relative
* paths described by the `paths` fields in weights manifest.
   *   2. A `tf.io.IOHandler` object that loads model artifacts with its `load`
* method.
* @param options Optional configuration arguments for the model loading,
* including:
* - `strict`: Require that the provided weights exactly match those required
* by the layers. Default true. Passing false means that both extra
* weights and missing weights will be silently ignored.
* - `onProgress`: A progress callback of the form:
* `(fraction: number) => void`. This callback can be used to monitor the
* model-loading process.
* @returns A `Promise` of `tf.LayersModel`, with the topology and weights
* loaded.
*/
function _modelFromJSON() {
_modelFromJSON = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee5(modelAndWeightsConfig, customObjects) {
var modelTopology, tsConfig, model, weightValues, uniqueWeightValues, _iterator4, _step4, weight;
return regeneratorRuntime.wrap(function _callee5$(_context5) {
while (1) {
switch (_context5.prev = _context5.next) {
case 0:
if (!('modelTopology' in modelAndWeightsConfig)) {
modelAndWeightsConfig = {
modelTopology: modelAndWeightsConfig
};
}
modelAndWeightsConfig = modelAndWeightsConfig;
modelTopology = modelAndWeightsConfig.modelTopology;
if (modelTopology['model_config'] != null) {
// If the model-topology JSON contains a 'model_config' field, then it is
// a full model JSON (e.g., from `keras.Model.save()`), which contains
// not only the model's architecture in its 'model_config' field, but
// additional information such as the model's optimizer. We use only the
// 'model_config' field currently.
modelTopology = modelTopology['model_config'];
}
tsConfig = convertPythonicToTs(modelTopology);
model = deserialize$1(tsConfig, customObjects);
if (!(modelAndWeightsConfig.weightsManifest != null)) {
_context5.next = 14;
break;
}
_context5.next = 9;
return loadWeights(modelAndWeightsConfig.weightsManifest, modelAndWeightsConfig.pathPrefix, model.weights.map(function (weight) {
return weight.originalName;
}));
case 9:
weightValues = _context5.sent;
// Map the weights to the unique tensor names generated during model loading
uniqueWeightValues = {};
for (_iterator4 = _createForOfIteratorHelperLoose(model.weights); !(_step4 = _iterator4()).done;) {
weight = _step4.value;
uniqueWeightValues[weight.originalName] = weightValues[weight.originalName];
}
model.loadWeights(uniqueWeightValues); // Dispose temporary weight values.
dispose(weightValues);
case 14:
return _context5.abrupt("return", model);
case 15:
case "end":
return _context5.stop();
}
}
}, _callee5);
}));
return _modelFromJSON.apply(this, arguments);
}
function loadLayersModelInternal(_x3, _x4) {
return _loadLayersModelInternal.apply(this, arguments);
}
/**
* Load a model and optionally its weights, using an IOHandler object.
*
* @param handler The instance of `IOHandler` to be used during the model
* loading.
* @param customObjects Any optional custom objects to be used during model
* loading.
* @param strict Whether the weight loading will be done in strict mode.
* Default: `true`.
*/
function _loadLayersModelInternal() {
_loadLayersModelInternal = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee6(pathOrIOHandler, options) {
var handlers;
return regeneratorRuntime.wrap(function _callee6$(_context6) {
while (1) {
switch (_context6.prev = _context6.next) {
case 0:
if (options == null) {
options = {};
}
if (!(typeof pathOrIOHandler === 'string')) {
_context6.next = 10;
break;
}
handlers = getLoadHandlers(pathOrIOHandler, options);
if (!(handlers.length === 0)) {
_context6.next = 7;
break;
}
// For backward compatibility: if no load handler can be found,
// assume it is a relative http path.
// TODO(cais): Reformat the args into a single `LoadOptions` once the core
// is refactored.
handlers.push(browserHTTPRequest(pathOrIOHandler, options));
_context6.next = 9;
break;
case 7:
if (!(handlers.length > 1)) {
_context6.next = 9;
break;
}
throw new ValueError("Found more than one (" + handlers.length + ") load handlers for " + ("URL '" + pathOrIOHandler + "'"));
case 9:
pathOrIOHandler = handlers[0];
case 10:
return _context6.abrupt("return", loadLayersModelFromIOHandler(pathOrIOHandler, undefined, options));
case 11:
case "end":
return _context6.stop();
}
}
}, _callee6);
}));
return _loadLayersModelInternal.apply(this, arguments);
}
function loadLayersModelFromIOHandler(_x5, _x6, _x7) {
return _loadLayersModelFromIOHandler.apply(this, arguments);
}
function _loadLayersModelFromIOHandler() {
_loadLayersModelFromIOHandler = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee7(handler, customObjects, options) {
var artifacts, modelTopology, strict, fastWeightInit, model, trainingConfig, _decodeModelAndOptimi, modelWeights, optimizerWeights;
return regeneratorRuntime.wrap(function _callee7$(_context7) {
while (1) {
switch (_context7.prev = _context7.next) {
case 0:
if (options == null) {
options = {};
}
if (!(handler.load == null)) {
_context7.next = 3;
break;
}
throw new ValueError('Cannot proceed with model loading because the IOHandler provided ' + 'does not have the `load` method implemented.');
case 3:
_context7.next = 5;
return handler.load();
case 5:
artifacts = _context7.sent;
modelTopology = artifacts.modelTopology;
if (modelTopology['model_config'] != null) {
modelTopology = modelTopology['model_config'];
}
strict = options.strict == null ? true : options.strict; // If weights are provided and the weight-loading mode is strict, use
// fast weight initialization. This skips costly initializers such as
// 'orthogonal' and saves unnecessary computation in cases where
// the initialized weight values will immediately be overwritten by
// loaded weight values.
fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict;
model = deserialize$1(convertPythonicToTs(modelTopology), customObjects, fastWeightInit);
trainingConfig = artifacts.trainingConfig;
if (trainingConfig != null) {
model.loadTrainingConfig(trainingConfig);
}
if (artifacts.userDefinedMetadata != null) {
model.setUserDefinedMetadata(artifacts.userDefinedMetadata);
} // If weightData is present, load the weights into the model.
if (!(artifacts.weightData != null)) {
_context7.next = 24;
break;
}
if (!(artifacts.weightSpecs == null)) {
_context7.next = 17;
break;
}
throw new ValueError('LayersModel artifacts contains weight data, but not weight specs. ' + 'Therefore loading of weights cannot proceed.');
case 17:
_decodeModelAndOptimi = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs), modelWeights = _decodeModelAndOptimi.modelWeights, optimizerWeights = _decodeModelAndOptimi.optimizerWeights;
model.loadWeights(modelWeights, strict);
if (!(model.optimizer != null && optimizerWeights.length > 0)) {
_context7.next = 22;
break;
}
_context7.next = 22;
return model.optimizer.setWeights(optimizerWeights);
case 22:
// Dispose temporary weight values.
dispose(modelWeights);
dispose(optimizerWeights.map(function (w) {
return w.tensor;
}));
case 24:
return _context7.abrupt("return", model);
case 25:
case "end":
return _context7.stop();
}
}
}, _callee7);
}));
return _loadLayersModelFromIOHandler.apply(this, arguments);
}
function decodeModelAndOptimizerWeights(buffer, specs) {
var name2Tensor = decodeWeights(buffer, specs);
var modelWeights = {};
var optimizerWeights = [];
specs.forEach(function (spec) {
if (spec.group === 'optimizer') {
optimizerWeights.push({
name: spec.name,
tensor: name2Tensor[spec.name]
});
} else {
modelWeights[spec.name] = name2Tensor[spec.name];
}
});
return {
modelWeights: modelWeights,
optimizerWeights: optimizerWeights
};
}
/**
* A model with a stack of layers, feeding linearly from one to the next.
*
* `tf.sequential` is a factory function that creates an instance of
* `tf.Sequential`.
*
* ```js
* // Define a model for linear regression.
* const model = tf.sequential();
* model.add(tf.layers.dense({units: 1, inputShape: [1]}));
*
* // Prepare the model for training: Specify the loss and the optimizer.
* model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
*
* // Generate some synthetic data for training.
* const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
* const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
*
* // Train the model using the data then do inference on a data point the
* // model hasn't seen:
* await model.fit(xs, ys);
* model.predict(tf.tensor2d([5], [1, 1])).print();
* ```
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
var Sequential = /*#__PURE__*/function (_LayersModel) {
_inheritsLoose(Sequential, _LayersModel);
function Sequential(args) {
var _this;
_this = _LayersModel.call(this, {
inputs: [],
outputs: []
}) || this;
args = args || {};
_this.trainable = true;
_this.built = false; // Set model name.
_this.name = args.name != null ? args.name : getUid('sequential_'); // Add to the model any layers passed to the constructor.
if (args.layers != null) {
for (var _iterator = _createForOfIteratorHelperLoose(args.layers), _step; !(_step = _iterator()).done;) {
var layer = _step.value;
_this.add(layer);
}
}
return _this;
  } // Helper function for Sequential.add(). Throws if the new output shape would be
// invalid.
var _proto = Sequential.prototype;
_proto.checkShape = function checkShape(layer) {
var shape = layer.inboundNodes[0].outputTensors[0].shape;
if (shape.some(function (x) {
return x < 0;
})) {
throw new ValueError('Negative dimension size caused by adding layer ' + (layer.name + " with input shape [") + (layer.inboundNodes[0].inputTensors[0].shape + "]"));
}
}
/**
* Adds a layer instance on top of the layer stack.
*
* ```js
* const model = tf.sequential();
* model.add(tf.layers.dense({units: 8, inputShape: [1]}));
* model.add(tf.layers.dense({units: 4, activation: 'relu6'}));
* model.add(tf.layers.dense({units: 1, activation: 'relu6'}));
* // Note that the untrained model is random at this point.
* model.predict(tf.randomNormal([10, 1])).print();
* ```
* @param layer Layer instance.
*
* @exception ValueError In case the `layer` argument does not know its
* input shape.
* @exception ValueError In case the `layer` argument has multiple output
* tensors, or is already connected somewhere else (forbidden in
* `Sequential` models).
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.add = function add(layer) {
var isLayerModelInstance = layer instanceof Sequential || layer instanceof LayersModel;
var modelLayer;
if (isLayerModelInstance) {
modelLayer = layer;
if (modelLayer.outputs.length !== 1) {
throw new ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.');
}
if (modelLayer.inputs.length !== 1) {
throw new ValueError('All layers in a Sequential model ' + 'should have a single input tensor. ' + 'For multi-input layers, ' + 'use the functional API.');
}
}
if (this.outputs.length === 0) {
// first layer in model: check that it is an input layer
if (layer.inboundNodes.length === 0) {
// create an input layer
if (layer.batchInputShape == null) {
throw new ValueError('The first layer in a Sequential model must ' + 'get an `inputShape` or `batchInputShape` argument.');
} // Instantiate the input layer.
var x = Input({
batchShape: layer.batchInputShape,
dtype: layer.dtype,
name: layer.name + '_input'
}); // This will build the current layer and create the node connecting
// the current layer to the input layer we just created.
layer.apply(x);
}
if (isLayerModelInstance) {
this.outputs = modelLayer.outputs;
this.inputs = modelLayer.inputs;
} else {
if (layer.inboundNodes.length !== 1) {
throw new ValueError('A layer added to a Sequential model must not already be ' + ("connected somewhere else. LayersModel received layer " + layer.name + " ") + ("which has " + layer.inboundNodes.length + " pre-existing inbound ") + 'connections.');
}
if (layer.inboundNodes[0].outputTensors.length !== 1) {
throw new ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.');
}
this.checkShape(layer);
this.outputs = [layer.inboundNodes[0].outputTensors[0]];
this.inputs = getSourceInputs(this.outputs[0]);
}
this.inboundNodes = []; // We create an input node, which we will keep updated
// as we add more layers.
// (This call has side effects.)
// tslint:disable-next-line:no-unused-expression
new Node({
outboundLayer: this,
inboundLayers: [],
nodeIndices: [],
tensorIndices: [],
inputTensors: this.inputs,
outputTensors: this.outputs,
// no model-level masking for now
inputMasks: pyListRepeat(null, this.inputs.length),
outputMasks: [null],
inputShapes: this.inputs.map(function (x) {
return x.shape;
}),
outputShapes: this.outputs[0].shape
});
} else {
var outputTensor = layer.apply(this.outputs[0]);
if (Array.isArray(outputTensor)) {
throw new TypeError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.');
}
this.checkShape(layer);
this.outputs = [outputTensor]; // update self.inbound_nodes
this.inboundNodes[0].outputTensors = this.outputs;
this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
}
this.layers.push(layer);
this.built = false;
}
/**
* Removes the last layer in the model.
*
* @exception TypeError if there are no layers in the model.
*/
;
_proto.pop = function pop() {
if (this.layers.length === 0) {
throw new TypeError('There are no layers in the model.');
}
this.layers.pop();
if (this.layers.length === 0) {
this.outputs = [];
this.inboundNodes = [];
this.outboundNodes = [];
} else {
var lastLayerIndex = this.layers.length - 1;
this.layers[lastLayerIndex].outboundNodes = [];
this.outputs = [this.layers[lastLayerIndex].output]; // update self.inbound_nodes
this.inboundNodes[0].outputTensors = this.outputs;
this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
}
};
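/*
 * Illustrative sketch (not part of the library source): using `pop()` to
 * remove the last layer of a Sequential model. Assumes the `tf` namespace
 * exported by this bundle.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({units: 8, inputShape: [4]}));
 * model.add(tf.layers.dense({units: 1}));
 * console.log(model.layers.length); // 2
 * model.pop();
 * console.log(model.layers.length); // 1
 * // The model's output is now the first dense layer's output.
 * console.log(JSON.stringify(model.outputs[0].shape)); // [null, 8]
 * ```
 */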
_proto.call = function call(inputs, kwargs) {
if (this.model == null) {
this.build();
}
return this.model.call(inputs, kwargs);
};
_proto.build = function build(inputShape) {
// Call `getExactlyOneShape` without using its return value,
// to verify that exactly one input shape is provided.
getExactlyOneShape(inputShape);
if (this.inputs.length === 0 || this.outputs.length === 0) {
throw new TypeError('Sequential model cannot be built: model is empty.' + ' Add some layers first.');
} // actually create the model
this.model = new LayersModel({
inputs: this.inputs,
outputs: this.outputs[0],
name: this.name + '_model'
});
this.model.trainable = this.trainable; // mirror model attributes
this.supportsMasking = this.model.supportsMasking; // TODO(michaelterry): Add caches
this.inputLayers = this.model.inputLayers;
this.inputLayersNodeIndices = this.model.inputLayersNodeIndices;
this.inputLayersTensorIndices = this.model.inputLayersTensorIndices;
this.outputLayers = this.model.outputLayers;
this.outputLayersNodeIndices = this.model.outputLayersNodeIndices;
this.outputLayersTensorIndices = this.model.outputLayersTensorIndices;
this.nodesByDepth = this.model.nodesByDepth;
this.containerNodes = this.model.containerNodes;
this.outputNames = this.model.outputNames;
this.inputNames = this.model.inputNames; // TODO(michaelterry): Add feedInputNames, feedInputs, if needed.
// TODO(michaelterry): Add callbackModel if needed.
this.built = true;
};
_proto.countParams = function countParams() {
if (!this.built) {
this.build();
}
return _LayersModel.prototype.countParams.call(this);
}
/**
* Print a text summary of the Sequential model's layers.
*
* The summary includes
* - Name and type of all layers that comprise the model.
* - Output shape(s) of the layers
* - Number of weight parameters of each layer
* - The total number of trainable and non-trainable parameters of the
* model.
*
* ```js
* const model = tf.sequential();
* model.add(
* tf.layers.dense({units: 100, inputShape: [10], activation: 'relu'}));
* model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
*
* model.summary();
* ```
*
* @param lineLength Custom line length, in number of characters.
* @param positions Custom widths of each of the columns, as either
* fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or absolute number
* of characters (e.g., `[30, 50, 65]`). Each number corresponds to
* right-most (i.e., ending) position of a column.
* @param printFn Custom print function. Can be used to replace the default
* `console.log`. For example, you can use `x => {}` to mute the printed
* messages in the console.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.summary = function summary(lineLength, positions, printFn) {
if (printFn === void 0) {
printFn = console.log;
}
if (!this.built) {
this.build();
}
_LayersModel.prototype.summary.call(this, lineLength, positions, printFn);
}
/**
* Sets the weights of the model.
*
* @param weights Should be a list of Tensors with shapes and types matching
* the output of `model.getWeights()`.
*/
;
_proto.setWeights = function setWeights(weights) {
if (this.model == null) {
this.build();
}
this.model.setWeights(weights);
}
/**
* Returns the loss value & metrics values for the model in test mode.
*
* Loss and metrics are specified during `compile()`, which needs to happen
* before calls to `evaluate()`.
*
* Computation is done in batches.
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
* const result = model.evaluate(tf.ones([8, 10]), tf.ones([8, 1]), {
* batchSize: 4,
* });
* result.print();
* ```
*
* @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
* model has multiple inputs.
* @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
* model has multiple outputs.
* @param args A `ModelEvaluateConfig`, containing optional fields.
*
* @return `Scalar` test loss (if the model has a single output and no
* metrics) or `Array` of `Scalar`s (if the model has multiple outputs
* and/or metrics). The attribute `model.metricsNames`
* will give you the display labels for the scalar outputs.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.evaluate = function evaluate(x, y, args) {
if (args === void 0) {
args = {};
}
if (!this.built) {
throw new RuntimeError('The model needs to be compiled before being used.');
}
return this.model.evaluate(x, y, args);
} // TODO(cais): Add code snippet below once real dataset objects are
// available.
/**
* Evaluate model using a dataset object.
*
 * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
*
* @param dataset A dataset object. Its `iterator()` method is expected
* to generate a dataset iterator object, the `next()` method of which
* is expected to produce data batches for evaluation. The return value
* of the `next()` call ought to contain a boolean `done` field and a
* `value` field. The `value` field is expected to be an array of two
* `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
 * case is for models with exactly one input and one output (e.g.,
* a sequential model). The latter case is for models with multiple
* inputs and/or multiple outputs. Of the two items in the array, the
* first is the input feature(s) and the second is the output target(s).
* @param args A configuration object for the dataset-based evaluation.
* @returns Loss and metric values as an Array of `Scalar` objects.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.evaluateDataset =
/*#__PURE__*/
function () {
var _evaluateDataset = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(dataset, args) {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (this.built) {
_context.next = 2;
break;
}
throw new RuntimeError('The model needs to be compiled before being used.');
case 2:
return _context.abrupt("return", this.model.evaluateDataset(dataset, args));
case 3:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function evaluateDataset(_x8, _x9) {
return _evaluateDataset.apply(this, arguments);
}
return evaluateDataset;
}()
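/*
 * Illustrative sketch (not part of the library source): evaluating with a
 * dataset built from in-memory arrays, assuming the `tf.data` API exported
 * elsewhere in this bundle (same pattern as the `fitDataset()` example
 * further below).
 *
 * ```js
 * const model = tf.sequential({
 *   layers: [tf.layers.dense({units: 1, inputShape: [4]})]
 * });
 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
 * const xData = tf.data.array([[0, 0, 0, 0], [1, 1, 1, 1]]);
 * const yData = tf.data.array([0, 1]);
 * const dataset = tf.data.zip({xs: xData, ys: yData}).batch(2);
 * const result = await model.evaluateDataset(dataset, {});
 * // Scalar loss, since the model has one output and no metrics.
 * result.print();
 * ```
 */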
/**
* Generates output predictions for the input samples.
*
* Computation is done in batches.
*
* Note: the "step" mode of predict() is currently not supported.
* This is because the TensorFlow.js core backend is imperative only.
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.predict(tf.ones([2, 10])).print();
* ```
*
* @param x The input data, as a Tensor, or an `Array` of `tf.Tensor`s if
* the model has multiple inputs.
 * @param args A `ModelPredictConfig` object containing optional fields.
*
* @return `tf.Tensor`(s) of predictions.
*
* @exception ValueError In case of mismatch between the provided input data
* and the model's expectations, or in case a stateful model receives a
* number of samples that is not a multiple of the batch size.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.predict = function predict(x, args) {
if (args === void 0) {
args = {};
}
if (this.model == null) {
this.build();
}
return this.model.predict(x, args);
}
/**
* Returns predictions for a single batch of samples.
*
* @param x: Input samples, as a Tensor, or list of Tensors (if the model
* has multiple inputs).
 * @return Tensor(s) of predictions.
*/
;
_proto.predictOnBatch = function predictOnBatch(x) {
if (this.model == null) {
this.build();
}
return this.model.predictOnBatch(x);
}
/**
* See `LayersModel.compile`.
*
* @param args
*/
;
_proto.compile = function compile(args) {
this.build();
this.model.compile(args);
this.optimizer_ = this.model.optimizer; // tslint:disable-next-line:no-any
this.isOptimizerOwned = this.model.isOptimizerOwned;
this.loss = this.model.loss;
this.metrics = this.model.metrics; // TODO(cais): Add this.lossWeights, this.sampleWeightMode,
// this.weightedMetrics, this.targets.
this.metricsTensors = this.model.metricsTensors;
this.metricsNames = this.model.metricsNames; // TODO(cais): Add sampleWeights.
};
/**
* Trains the model for a fixed number of epochs (iterations on a dataset).
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [10]})]
* });
* model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
* const history = await model.fit(tf.ones([8, 10]), tf.ones([8, 1]), {
* batchSize: 4,
* epochs: 3
* });
* console.log(history.history.loss[0]);
* ```
*
* @param x `tf.Tensor` of training data, or an array of `tf.Tensor`s if the
* model has multiple inputs. If all inputs in the model are named, you can
* also pass a dictionary mapping input names to `tf.Tensor`s.
* @param y `tf.Tensor` of target (label) data, or an array of `tf.Tensor`s if
* the model has multiple outputs. If all outputs in the model are named, you
* can also pass a dictionary mapping output names to `tf.Tensor`s.
* @param args A `ModelFitConfig`, containing optional fields.
*
* @return A `History` instance. Its `history` attribute contains all
* information collected during training.
*
* @exception ValueError In case of mismatch between the provided input data
* and what the model expects.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
_proto.fit =
/*#__PURE__*/
function () {
var _fit = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(x, y, args) {
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
if (args === void 0) {
args = {};
}
if (this.built) {
_context2.next = 3;
break;
}
throw new RuntimeError('The model needs to be compiled before ' + 'being used.');
case 3:
return _context2.abrupt("return", this.model.fit(x, y, args));
case 4:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function fit(_x10, _x11, _x12) {
return _fit.apply(this, arguments);
}
return fit;
}()
/**
* Trains the model using a dataset object.
*
* ```js
* const xArray = [
* [1, 1, 1, 1, 1, 1, 1, 1, 1],
* [1, 1, 1, 1, 1, 1, 1, 1, 1],
* [1, 1, 1, 1, 1, 1, 1, 1, 1],
* [1, 1, 1, 1, 1, 1, 1, 1, 1],
* ];
* const yArray = [1, 1, 1, 1];
* // Create a dataset from the JavaScript array.
* const xDataset = tf.data.array(xArray);
* const yDataset = tf.data.array(yArray);
* // Zip combines the `x` and `y` Datasets into a single Dataset, the
* // iterator of which will return an object containing of two tensors,
* // corresponding to `x` and `y`. The call to `batch(4)` will bundle
* // four such samples into a single object, with the same keys now pointing
* // to tensors that hold 4 examples, organized along the batch dimension.
* // The call to `shuffle(4)` causes each iteration through the dataset to
* // happen in a different order. The size of the shuffle window is 4.
* const xyDataset = tf.data.zip({xs: xDataset, ys: yDataset})
* .batch(4)
* .shuffle(4);
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 1, inputShape: [9]})]
* });
* model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
* const history = await model.fitDataset(xyDataset, {
* epochs: 4,
* callbacks: {onEpochEnd: (epoch, logs) => console.log(logs.loss)}
* });
* ```
*
* @param dataset A dataset object. Its `iterator()` method is expected to
* generate a dataset iterator object, the `next()` method of which is
* expected to produce data batches for evaluation. The return value of the
* `next()` call ought to contain a boolean `done` field and a `value`
* field.
*
 * The `value` field is expected to be an object with fields
* `xs` and `ys`, which point to the feature tensor and the target tensor,
* respectively. This case is for models with exactly one input and one
 * output (e.g., a sequential model). For example:
* ```js
* {value: {xs: xsTensor, ys: ysTensor}, done: false}
* ```
*
* If the model has multiple inputs, the `xs` field of `value` should
* be an object mapping input names to their respective feature tensors.
* For example:
* ```js
* {
* value: {
* xs: {
* input_1: xsTensor1,
* input_2: xsTensor2
* },
* ys: ysTensor
* },
* done: false
* }
* ```
* If the model has multiple outputs, the `ys` field of `value` should
* be an object mapping output names to their respective target tensors.
* For example:
* ```js
* {
* value: {
* xs: xsTensor,
* ys: {
* output_1: ysTensor1,
* output_2: ysTensor2
* },
* },
* done: false
* }
* ```
* @param args A `ModelFitDatasetArgs`, containing optional fields.
*
* @return A `History` instance. Its `history` attribute contains all
* information collected during training.
*
* @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
*/
;
_proto.fitDataset =
/*#__PURE__*/
function () {
var _fitDataset = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(dataset, args) {
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
if (this.built) {
_context3.next = 2;
break;
}
throw new RuntimeError('The model needs to be compiled before ' + 'being used.');
case 2:
return _context3.abrupt("return", this.model.fitDataset(dataset, args));
case 3:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function fitDataset(_x13, _x14) {
return _fitDataset.apply(this, arguments);
}
return fitDataset;
}()
/**
* Runs a single gradient update on a single batch of data.
*
* This method differs from `fit()` and `fitDataset()` in the following
* regards:
* - It operates on exactly one batch of data.
 * - It returns only the loss and metric values, instead of
* returning the batch-by-batch loss and metric values.
* - It doesn't support fine-grained options such as verbosity and
* callbacks.
*
* @param x Input data. It could be one of the following:
* - A `tf.Tensor`, or an Array of `tf.Tensor`s (in case the model has
* multiple inputs).
* - An Object mapping input names to corresponding `tf.Tensor` (if the
* model has named inputs).
 * @param y Target data. It could be either a `tf.Tensor` or multiple
* `tf.Tensor`s. It should be consistent with `x`.
* @returns Training loss or losses (in case the model has
* multiple outputs), along with metrics (if any), as numbers.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.trainOnBatch =
/*#__PURE__*/
function () {
var _trainOnBatch = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(x, y) {
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
return _context4.abrupt("return", this.model.trainOnBatch(x, y));
case 1:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function trainOnBatch(_x15, _x16) {
return _trainOnBatch.apply(this, arguments);
}
return trainOnBatch;
}()
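/*
 * Illustrative sketch (not part of the library source): a minimal custom
 * training loop built on `trainOnBatch()`.
 *
 * ```js
 * const model = tf.sequential({
 *   layers: [tf.layers.dense({units: 1, inputShape: [2]})]
 * });
 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
 * const xs = tf.tensor2d([[0, 0], [1, 1]]);
 * const ys = tf.tensor2d([[0], [1]]);
 * for (let i = 0; i < 5; ++i) {
 *   // With a single output and no metrics, the resolved value is a number.
 *   const loss = await model.trainOnBatch(xs, ys);
 *   console.log(`step ${i}: loss = ${loss}`);
 * }
 * ```
 */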
/* See parent class for JsDoc */
/** @nocollapse */
;
Sequential.fromConfig = function fromConfig(cls, config, customObjects, fastWeightInit) {
if (customObjects === void 0) {
customObjects = {};
}
if (fastWeightInit === void 0) {
fastWeightInit = false;
}
var configArray;
var extraModelConfig = {};
if (config instanceof Array) {
if (!(config[0].className != null) || config[0]['className'] === 'Merge') {
throw new ValueError('Legacy serialization format not supported yet.');
}
configArray = config;
} else {
assert(config['layers'] != null, function () {
return "When the config data for a Sequential model is not an Array, " + "it must be an Object that contains the 'layers' field.";
});
configArray = config['layers'];
delete config['layers'];
extraModelConfig = config;
}
var model = new cls(extraModelConfig);
if (!(model instanceof Sequential)) {
throw new NotImplementedError("Sequential.fromConfig called on non-Sequential input: " + model);
}
for (var _iterator2 = _createForOfIteratorHelperLoose(configArray), _step2; !(_step2 = _iterator2()).done;) {
var conf = _step2.value;
var _customObjects = undefined;
var layer = deserialize$1(conf, _customObjects, fastWeightInit);
if (fastWeightInit) {
layer.setFastWeightInitDuringBuild(true);
}
model.add(layer);
}
return model;
}
/**
 * Setter used to force-stop LayersModel.fit() (i.e., training).
*
* Example:
*
* ```js
* const model = tf.sequential();
* model.add(tf.layers.dense({units: 1, inputShape: [10]}));
* model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
* const xs = tf.ones([8, 10]);
* const ys = tf.zeros([8, 1]);
*
* const history = await model.fit(xs, ys, {
* epochs: 10,
* callbacks: {
* onEpochEnd: async (epoch, logs) => {
* if (epoch === 2) {
* model.stopTraining = true;
* }
* }
* }
* });
*
* // There should be only 3 values in the loss array, instead of 10 values,
* // due to the stopping after 3 epochs.
* console.log(history.history.loss);
* ```
*/
;
// TODO(cais): Override get trainableWeights() here
// tslint:disable-next-line:no-any
_proto.getConfig = function getConfig() {
// NOTE(cais): We override the return type of getConfig() to `any` here,
// because the `Sequential` class is a special case among `Container`
// subtypes in that its getConfig() method returns an Array (not a
// dict).
var layers = [];
for (var _iterator3 = _createForOfIteratorHelperLoose(this.layers), _step3; !(_step3 = _iterator3()).done;) {
var layer = _step3.value;
var dict = {};
dict['className'] = layer.getClassName();
dict['config'] = layer.getConfig();
layers.push(dict);
}
return {
name: this.name,
layers: layers
};
};
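/*
 * Illustrative sketch (not part of the library source): the config object
 * returned by `Sequential.getConfig()`, whose `layers` field is an Array of
 * per-layer configs, as noted above.
 *
 * ```js
 * const model = tf.sequential({
 *   layers: [tf.layers.dense({units: 2, inputShape: [3]})]
 * });
 * const config = model.getConfig();
 * console.log(config.name);                 // e.g., 'sequential_1' (auto-generated)
 * console.log(config.layers.length);        // 1
 * console.log(config.layers[0].className);  // 'Dense'
 * ```
 */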
_createClass(Sequential, [{
key: "optimizer",
get: function get() {
return this.model == null ? undefined : this.model.optimizer;
},
set: function set(optimizer) {
this.model.optimizer = optimizer;
}
}, {
key: "stopTraining",
get: function get() {
if (this.model == null) {
throw new ValueError('Cannot get the stopTraining property of a sequential model before ' + 'it is compiled.');
}
return this.model.stopTraining;
},
set: function set(stop) {
// TODO(cais): When refactoring to remove the composition pattern happens,
// remove this method overriding.
if (this.model == null) {
throw new ValueError('Cannot set the stopTraining property of a sequential model before ' + 'it is compiled.');
}
this.model.stopTraining = stop;
}
}]);
return Sequential;
}(LayersModel);
/** @nocollapse */
Sequential.className = 'Sequential';
registerClass(Sequential);
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
// class; include executable JavaScript code snippets where applicable
// (b/74074458).
// LayersModel and related factory methods.
/**
* A model is a data structure that consists of `Layers` and defines inputs
* and outputs.
*
* The key difference between `tf.model` and `tf.sequential` is that
* `tf.model` is more generic, supporting an arbitrary graph (without
* cycles) of layers. `tf.sequential` is less generic and supports only a linear
* stack of layers.
*
* When creating a `tf.LayersModel`, specify its input(s) and output(s). Layers
* are used to wire input(s) to output(s).
*
* For example, the following code snippet defines a model consisting of
* two `dense` layers, with 10 and 4 units, respectively.
*
* ```js
* // Define input, which has a size of 5 (not including batch dimension).
* const input = tf.input({shape: [5]});
*
* // First dense layer uses relu activation.
* const denseLayer1 = tf.layers.dense({units: 10, activation: 'relu'});
* // Second dense layer uses softmax activation.
* const denseLayer2 = tf.layers.dense({units: 4, activation: 'softmax'});
*
* // Obtain the output symbolic tensor by applying the layers on the input.
* const output = denseLayer2.apply(denseLayer1.apply(input));
*
* // Create the model based on the inputs.
* const model = tf.model({inputs: input, outputs: output});
*
* // The model can be used for training, evaluation and prediction.
* // For example, the following line runs prediction with the model on
* // some fake data.
* model.predict(tf.ones([2, 5])).print();
* ```
* See also:
* `tf.sequential`, `tf.loadLayersModel`.
*
* @doc {heading: 'Models', subheading: 'Creation'}
*/
function model(args) {
return new LayersModel(args);
}
/**
* Creates a `tf.Sequential` model. A sequential model is any model where the
* outputs of one layer are the inputs to the next layer, i.e. the model
* topology is a simple 'stack' of layers, with no branching or skipping.
*
* This means that the first layer passed to a `tf.Sequential` model should have
 * a defined input shape. In other words, it should have received an
 * `inputShape` or `batchInputShape` argument, or for some types of layers
* (recurrent, Dense...) an `inputDim` argument.
*
* The key difference between `tf.model` and `tf.sequential` is that
* `tf.sequential` is less generic, supporting only a linear stack of layers.
* `tf.model` is more generic and supports an arbitrary graph (without
* cycles) of layers.
*
* Examples:
*
* ```js
* const model = tf.sequential();
*
* // First layer must have an input shape defined.
* model.add(tf.layers.dense({units: 32, inputShape: [50]}));
* // Afterwards, TF.js does automatic shape inference.
* model.add(tf.layers.dense({units: 4}));
*
* // Inspect the inferred shape of the model's output, which equals
* // `[null, 4]`. The 1st dimension is the undetermined batch dimension; the
* // 2nd is the output size of the model's last layer.
* console.log(JSON.stringify(model.outputs[0].shape));
* ```
*
* It is also possible to specify a batch size (with potentially undetermined
* batch dimension, denoted by "null") for the first layer using the
* `batchInputShape` key. The following example is equivalent to the above:
*
* ```js
* const model = tf.sequential();
*
* // First layer must have a defined input shape
* model.add(tf.layers.dense({units: 32, batchInputShape: [null, 50]}));
* // Afterwards, TF.js does automatic shape inference.
* model.add(tf.layers.dense({units: 4}));
*
* // Inspect the inferred shape of the model's output.
* console.log(JSON.stringify(model.outputs[0].shape));
* ```
*
* You can also use an `Array` of already-constructed `Layer`s to create
* a `tf.Sequential` model:
*
* ```js
* const model = tf.sequential({
* layers: [tf.layers.dense({units: 32, inputShape: [50]}),
* tf.layers.dense({units: 4})]
* });
* console.log(JSON.stringify(model.outputs[0].shape));
* ```
*
* @doc {heading: 'Models', subheading: 'Creation'}
*/
function sequential(config) {
return new Sequential(config);
}
/**
* Load a model composed of Layer objects, including its topology and optionally
* weights. See the Tutorial named "How to import a Keras Model" for usage
* examples.
*
* This method is applicable to:
*
* 1. Models created with the `tf.layers.*`, `tf.sequential`, and
* `tf.model` APIs of TensorFlow.js and later saved with the
* `tf.LayersModel.save` method.
* 2. Models converted from Keras or TensorFlow tf.keras using the
* [tensorflowjs_converter](https://github.com/tensorflow/tfjs/tree/master/tfjs-converter).
*
 * This method is *not* applicable to TensorFlow `SavedModel`s or their converted
* forms. For those models, use `tf.loadGraphModel`.
*
* Example 1. Load a model from an HTTP server.
*
* ```js
* const model = await tf.loadLayersModel(
* 'https://storage.googleapis.com/tfjs-models/tfjs/iris_v1/model.json');
* model.summary();
* ```
*
* Example 2: Save `model`'s topology and weights to browser [local
* storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
* then load it back.
*
* ```js
* const model = tf.sequential(
* {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
* console.log('Prediction from original model:');
* model.predict(tf.ones([1, 3])).print();
*
* const saveResults = await model.save('localstorage://my-model-1');
*
* const loadedModel = await tf.loadLayersModel('localstorage://my-model-1');
* console.log('Prediction from loaded model:');
* loadedModel.predict(tf.ones([1, 3])).print();
* ```
*
* Example 3. Saving `model`'s topology and weights to browser
* [IndexedDB](https://developer.mozilla.org/en-US/docs/Web/API/IndexedDB_API);
* then load it back.
*
* ```js
* const model = tf.sequential(
* {layers: [tf.layers.dense({units: 1, inputShape: [3]})]});
* console.log('Prediction from original model:');
* model.predict(tf.ones([1, 3])).print();
*
* const saveResults = await model.save('indexeddb://my-model-1');
*
* const loadedModel = await tf.loadLayersModel('indexeddb://my-model-1');
* console.log('Prediction from loaded model:');
* loadedModel.predict(tf.ones([1, 3])).print();
* ```
*
* Example 4. Load a model from user-selected files from HTML
* [file input
* elements](https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input/file).
*
* ```js
* // Note: this code snippet will not work without the HTML elements in the
* // page
* const jsonUpload = document.getElementById('json-upload');
* const weightsUpload = document.getElementById('weights-upload');
*
* const model = await tf.loadLayersModel(
* tf.io.browserFiles([jsonUpload.files[0], weightsUpload.files[0]]));
* ```
*
* @param pathOrIOHandler Can be either of the two formats
* 1. A string path to the `ModelAndWeightsConfig` JSON describing
* the model in the canonical TensorFlow.js format. For file://
 * (tfjs-node-only), http:// and https:// schemes, the path can be
 * either absolute or relative.
 * 2. A `tf.io.IOHandler` object that loads model artifacts with its `load`
* method.
* @param options Optional configuration arguments for the model loading,
* including:
* - `strict`: Require that the provided weights exactly match those required
* by the layers. Default true. Passing false means that both extra
* weights and missing weights will be silently ignored.
 * - `onProgress`: A function of the signature `(fraction: number) => void`,
* that can be used as the progress callback for the model loading.
* @returns A `Promise` of `tf.LayersModel`, with the topology and weights
* loaded.
*
* @doc {heading: 'Models', subheading: 'Loading'}
*/
function loadLayersModel(pathOrIOHandler, options) {
if (options == null) {
options = {};
}
return loadLayersModelInternal(pathOrIOHandler, options);
}
/**
* Used to instantiate an input to a model as a `tf.SymbolicTensor`.
*
* Users should call the `input` factory function for
* consistency with other generator functions.
*
* Example:
*
* ```js
* // Defines a simple logistic regression model with 32 dimensional input
* // and 3 dimensional output.
* const x = tf.input({shape: [32]});
* const y = tf.layers.dense({units: 3, activation: 'softmax'}).apply(x);
* const model = tf.model({inputs: x, outputs: y});
* model.predict(tf.ones([2, 32])).print();
* ```
*
* Note: `input` is only necessary when using `model`. When using
* `sequential`, specify `inputShape` for the first layer or use `inputLayer`
* as the first layer.
*
* @doc {heading: 'Models', subheading: 'Inputs'}
*/
function input(config) {
return Input(config);
}
function registerCallbackConstructor(verbosityLevel, callbackConstructor) {
CallbackConstructorRegistry.registerCallbackConstructor(verbosityLevel, callbackConstructor);
}
/**
* Base class for Activations.
*
* Special note: due to cross-language compatibility reasons, the
* static readonly className field in this family of classes must be set to
* the initialLowerCamelCase name of the activation.
*/
var Activation = /*#__PURE__*/function (_serialization$Serial) {
_inheritsLoose(Activation, _serialization$Serial);
function Activation() {
return _serialization$Serial.apply(this, arguments) || this;
}
var _proto = Activation.prototype;
_proto.getConfig = function getConfig() {
return {};
};
return Activation;
}(Serializable);
/**
* Exponential linear unit (ELU).
* Reference: https://arxiv.org/abs/1511.07289
*/
var Elu$1 = /*#__PURE__*/function (_Activation) {
_inheritsLoose(Elu, _Activation);
function Elu() {
return _Activation.apply(this, arguments) || this;
}
var _proto2 = Elu.prototype;
/**
* Calculate the activation function.
*
* @param x: Input.
 * @param alpha: Scaling factor for the negative section.
* @return Output of the ELU activation.
*/
_proto2.apply = function apply(x, alpha) {
if (alpha === void 0) {
alpha = 1;
}
return elu$1(x, alpha);
};
return Elu;
}(Activation);
/** @nocollapse */
Elu$1.className = 'elu';
registerClass(Elu$1);
/**
* Scaled Exponential Linear Unit. (Klambauer et al., 2017).
* Reference: Self-Normalizing Neural Networks, https://arxiv.org/abs/1706.02515
* Notes:
* - To be used together with the initialization "lecunNormal".
* - To be used together with the dropout variant "AlphaDropout".
*/
var Selu$1 = /*#__PURE__*/function (_Activation2) {
_inheritsLoose(Selu, _Activation2);
function Selu() {
return _Activation2.apply(this, arguments) || this;
}
var _proto3 = Selu.prototype;
_proto3.apply = function apply(x) {
return selu(x);
};
return Selu;
}(Activation);
/** @nocollapse */
Selu$1.className = 'selu';
registerClass(Selu$1);
/**
* Rectified linear unit
*/
var Relu$1 = /*#__PURE__*/function (_Activation3) {
_inheritsLoose(Relu, _Activation3);
function Relu() {
return _Activation3.apply(this, arguments) || this;
}
var _proto4 = Relu.prototype;
_proto4.apply = function apply(x) {
return relu(x);
};
return Relu;
}(Activation);
/** @nocollapse */
Relu$1.className = 'relu';
registerClass(Relu$1);
/**
* Rectified linear unit activation maxing out at 6.0.
*/
var Relu6$1 = /*#__PURE__*/function (_Activation4) {
_inheritsLoose(Relu6, _Activation4);
function Relu6() {
return _Activation4.apply(this, arguments) || this;
}
var _proto5 = Relu6.prototype;
_proto5.apply = function apply(x) {
return tidy(function () {
return minimum(6.0, relu(x));
});
};
return Relu6;
}(Activation);
/** @nocollapse */
Relu6$1.className = 'relu6';
registerClass(Relu6$1);
/** Linear activation (no-op). */
var Linear = /*#__PURE__*/function (_Activation5) {
_inheritsLoose(Linear, _Activation5);
function Linear() {
return _Activation5.apply(this, arguments) || this;
}
var _proto6 = Linear.prototype;
_proto6.apply = function apply(x) {
return x;
};
return Linear;
}(Activation);
/** @nocollapse */
Linear.className = 'linear';
registerClass(Linear);
/**
* Sigmoid activation function.
*/
var Sigmoid$1 = /*#__PURE__*/function (_Activation6) {
_inheritsLoose(Sigmoid, _Activation6);
function Sigmoid() {
return _Activation6.apply(this, arguments) || this;
}
var _proto7 = Sigmoid.prototype;
_proto7.apply = function apply(x) {
return sigmoid(x);
};
return Sigmoid;
}(Activation);
/** @nocollapse */
Sigmoid$1.className = 'sigmoid';
registerClass(Sigmoid$1);
/**
* Segment-wise linear approximation of sigmoid.
*/
var HardSigmoid = /*#__PURE__*/function (_Activation7) {
_inheritsLoose(HardSigmoid, _Activation7);
function HardSigmoid() {
return _Activation7.apply(this, arguments) || this;
}
var _proto8 = HardSigmoid.prototype;
_proto8.apply = function apply(x) {
return hardSigmoid(x);
};
return HardSigmoid;
}(Activation);
/** @nocollapse */
HardSigmoid.className = 'hardSigmoid';
registerClass(HardSigmoid);
/**
* Softplus activation function.
*/
var Softplus$1 = /*#__PURE__*/function (_Activation8) {
_inheritsLoose(Softplus, _Activation8);
function Softplus() {
return _Activation8.apply(this, arguments) || this;
}
var _proto9 = Softplus.prototype;
_proto9.apply = function apply(x) {
return softplus(x);
};
return Softplus;
}(Activation);
/** @nocollapse */
Softplus$1.className = 'softplus';
registerClass(Softplus$1);
/**
* Softsign activation function.
*/
var Softsign = /*#__PURE__*/function (_Activation9) {
_inheritsLoose(Softsign, _Activation9);
function Softsign() {
return _Activation9.apply(this, arguments) || this;
}
var _proto10 = Softsign.prototype;
_proto10.apply = function apply(x) {
return softsign(x);
};
return Softsign;
}(Activation);
/** @nocollapse */
Softsign.className = 'softsign';
registerClass(Softsign);
/**
* Hyperbolic tangent function.
*/
var Tanh$1 = /*#__PURE__*/function (_Activation10) {
_inheritsLoose(Tanh, _Activation10);
function Tanh() {
return _Activation10.apply(this, arguments) || this;
}
var _proto11 = Tanh.prototype;
_proto11.apply = function apply(x) {
return tanh$1(x);
};
return Tanh;
}(Activation);
/** @nocollapse */
Tanh$1.className = 'tanh';
registerClass(Tanh$1);
/**
* Softmax activation function
*/
var Softmax$1 = /*#__PURE__*/function (_Activation11) {
_inheritsLoose(Softmax, _Activation11);
function Softmax() {
return _Activation11.apply(this, arguments) || this;
}
var _proto12 = Softmax.prototype;
/**
* Calculate the activation function.
*
* @param x Tensor.
* @param axis Integer, axis along which the softmax normalization is applied.
* Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be
* an error.
*
* @returns a Tensor of the same shape as x
*
* @throws ValueError: In case `dim(x) < 2`.
*/
_proto12.apply = function apply(x, axis) {
if (axis === void 0) {
axis = -1;
}
return softmax(x, axis);
};
return Softmax;
}(Activation);
/** @nocollapse */
Softmax$1.className = 'softmax';
registerClass(Softmax$1);
/**
* Log softmax activation function
*/
var LogSoftmax$1 = /*#__PURE__*/function (_Activation12) {
_inheritsLoose(LogSoftmax, _Activation12);
function LogSoftmax() {
return _Activation12.apply(this, arguments) || this;
}
var _proto13 = LogSoftmax.prototype;
/**
* Calculate the activation function of log softmax:
* log( exp(x_i) / sum(exp(x)) )
*
* @param x Tensor.
* @param axis Integer, axis along which the softmax normalization is applied.
* Invalid if < 2, as softmax across 1 (the batch dimension) is assumed to be
* an error.
*
* @returns a Tensor of the same shape as x
*
* @throws ValueError: In case `dim(x) < 2`.
*/
_proto13.apply = function apply(x, axis) {
if (axis === void 0) {
axis = -1;
}
return logSoftmax(x, axis);
};
return LogSoftmax;
}(Activation);
/** @nocollapse */
LogSoftmax$1.className = 'logSoftmax';
registerClass(LogSoftmax$1);
/**
* Swish activation function
*/
var Swish = /*#__PURE__*/function (_Activation13) {
_inheritsLoose(Swish, _Activation13);
function Swish() {
return _Activation13.apply(this, arguments) || this;
}
var _proto14 = Swish.prototype;
/**
* Calculate the activation function.
*
* @param x Tensor.
* @param alpha Scaling factor for the sigmoid function.
* @returns a Tensor of the same shape as x
*/
_proto14.apply = function apply(x, alpha) {
if (alpha === void 0) {
alpha = 1;
}
return tidy(function () {
return mul(sigmoid(mul(x, alpha)), x);
});
};
return Swish;
}(Activation);
/** @nocollapse */
Swish.className = 'swish';
registerClass(Swish);
/**
* Mish activation function
*/
var Mish = /*#__PURE__*/function (_Activation14) {
_inheritsLoose(Mish, _Activation14);
function Mish() {
return _Activation14.apply(this, arguments) || this;
}
var _proto15 = Mish.prototype;
/**
* Calculate the activation function.
*
* @param x Tensor.
* @returns a Tensor of the same shape as x
*/
_proto15.apply = function apply(x) {
return tidy(function () {
return mul(x, tanh$1(softplus(x)));
});
};
return Mish;
}(Activation);
/** @nocollapse */
Mish.className = 'mish';
registerClass(Mish);
function serializeActivation(activation) {
return activation.getClassName();
}
function deserializeActivation(config, customObjects) {
if (customObjects === void 0) {
customObjects = {};
}
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'activation');
}
function getActivation(identifier) {
if (identifier == null) {
var config = {};
config['className'] = 'linear';
config['config'] = {};
return deserializeActivation(config);
}
if (typeof identifier === 'string') {
var _config = {};
_config['className'] = identifier;
_config['config'] = {};
return deserializeActivation(_config);
} else if (identifier instanceof Activation) {
return identifier;
} else {
return deserializeActivation(identifier);
}
}
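/*
 * Illustrative sketch (not part of the library source): the activation
 * classes registered above are looked up by their `className` strings, so
 * they can be referenced anywhere a layer accepts an `activation` option.
 * Assumes the `tf.layers.dense` factory exported elsewhere in this bundle
 * and the 'swish' identifier registered above.
 *
 * ```js
 * const layer = tf.layers.dense({units: 4, activation: 'swish', inputShape: [8]});
 * layer.apply(tf.ones([1, 8])).print();
 * ```
 */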
function assertObjectArgs(args) {
if (args != null && typeof args !== 'object') {
throw new Error("Argument to L1L2 regularizer's constructor is expected to be an " + ("object, but received: " + args));
}
}
/**
* Regularizer base class.
*/
var Regularizer = /*#__PURE__*/function (_serialization$Serial) {
_inheritsLoose(Regularizer, _serialization$Serial);
function Regularizer() {
return _serialization$Serial.apply(this, arguments) || this;
}
return Regularizer;
}(Serializable);
var L1L2 = /*#__PURE__*/function (_Regularizer) {
_inheritsLoose(L1L2, _Regularizer);
function L1L2(args) {
var _this;
_this = _Regularizer.call(this) || this;
assertObjectArgs(args);
_this.l1 = args == null || args.l1 == null ? 0.01 : args.l1;
_this.l2 = args == null || args.l2 == null ? 0.01 : args.l2;
_this.hasL1 = _this.l1 !== 0;
_this.hasL2 = _this.l2 !== 0;
return _this;
}
/**
* Porting note: Renamed from __call__.
* @param x Variable of which to calculate the regularization score.
*/
var _proto = L1L2.prototype;
_proto.apply = function apply(x) {
var _this2 = this;
return tidy(function () {
var regularization = zeros([1]);
if (_this2.hasL1) {
regularization = add$1(regularization, sum$1(mul(_this2.l1, abs$8(x))));
}
if (_this2.hasL2) {
regularization = add$1(regularization, sum$1(mul(_this2.l2, square$1(x))));
}
return reshape(regularization, []);
});
};
_proto.getConfig = function getConfig() {
return {
'l1': this.l1,
'l2': this.l2
};
}
/** @nocollapse */
;
L1L2.fromConfig = function fromConfig(cls, config) {
return new cls({
l1: config['l1'],
l2: config['l2']
});
};
return L1L2;
}(Regularizer);
/** @nocollapse */
L1L2.className = 'L1L2';
registerClass(L1L2);
function l1(args) {
assertObjectArgs(args);
return new L1L2({
l1: args != null ? args.l1 : null,
l2: 0
});
}
function l2(args) {
assertObjectArgs(args);
return new L1L2({
l2: args != null ? args.l2 : null,
l1: 0
});
} // Maps the JavaScript-like identifier keys to the corresponding keras symbols.
var REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
'l1l2': 'L1L2'
};
function serializeRegularizer(constraint) {
return serializeKerasObject(constraint);
}
function deserializeRegularizer(config, customObjects) {
if (customObjects === void 0) {
customObjects = {};
}
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'regularizer');
}
function getRegularizer(identifier) {
if (identifier == null) {
return null;
}
if (typeof identifier === 'string') {
var className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ? REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier;
var config = {
className: className,
config: {}
};
return deserializeRegularizer(config);
} else if (identifier instanceof Regularizer) {
return identifier;
} else {
return deserializeRegularizer(identifier);
}
}
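/*
 * Illustrative sketch (not part of the library source): attaching an L2
 * penalty to a layer's kernel via the `tf.regularizers` factory functions
 * exported elsewhere in this bundle, which construct the L1L2 class above.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({
 *   units: 16,
 *   inputShape: [10],
 *   kernelRegularizer: tf.regularizers.l2({l2: 0.01})
 * }));
 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
 * // The regularization term is added to the loss during fit().
 * ```
 */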
var ReLU = /*#__PURE__*/function (_Layer) {
_inheritsLoose(ReLU, _Layer);
function ReLU(args) {
var _this;
_this = _Layer.call(this, args == null ? {} : args) || this;
_this.supportsMasking = true;
if (args != null) {
_this.maxValue = args.maxValue;
}
return _this;
}
var _proto = ReLU.prototype;
_proto.call = function call(inputs, kwargs) {
inputs = getExactlyOneTensor(inputs);
var output = relu(inputs);
if (this.maxValue != null) {
output = clipByValue(output, 0, this.maxValue);
}
return output;
};
_proto.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
};
_proto.getConfig = function getConfig() {
var config = {
maxValue: this.maxValue
};
var baseConfig = _Layer.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return ReLU;
}(Layer);
/** @nocollapse */
ReLU.className = 'ReLU';
registerClass(ReLU);
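/*
 * Illustrative sketch (not part of the library source): using the ReLU layer
 * via its `tf.layers.reLU()` factory, optionally capping activations with
 * `maxValue`.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({units: 8, inputShape: [4]}));
 * model.add(tf.layers.reLU({maxValue: 6}));
 * model.predict(tf.ones([1, 4])).print();
 * ```
 */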
var LeakyReLU = /*#__PURE__*/function (_Layer2) {
_inheritsLoose(LeakyReLU, _Layer2);
function LeakyReLU(args) {
var _this2;
_this2 = _Layer2.call(this, args == null ? {} : args) || this;
_this2.DEFAULT_ALPHA = 0.3;
if (args == null) {
args = {};
}
_this2.alpha = args.alpha == null ? _this2.DEFAULT_ALPHA : args.alpha;
return _this2;
}
var _proto2 = LeakyReLU.prototype;
_proto2.call = function call(inputs, kwargs) {
var x = getExactlyOneTensor(inputs);
return leakyRelu(x, this.alpha);
};
_proto2.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
};
_proto2.getConfig = function getConfig() {
var config = {
alpha: this.alpha
};
var baseConfig = _Layer2.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return LeakyReLU;
}(Layer);
/** @nocollapse */
LeakyReLU.className = 'LeakyReLU';
registerClass(LeakyReLU);
var PReLU = /*#__PURE__*/function (_Layer3) {
_inheritsLoose(PReLU, _Layer3);
function PReLU(args) {
var _this3;
_this3 = _Layer3.call(this, args == null ? {} : args) || this;
_this3.DEFAULT_ALPHA_INITIALIZER = 'zeros';
if (args == null) {
args = {};
}
_this3.supportsMasking = true;
_this3.alphaInitializer = getInitializer(args.alphaInitializer || _this3.DEFAULT_ALPHA_INITIALIZER);
_this3.alphaRegularizer = getRegularizer(args.alphaRegularizer);
_this3.alphaConstraint = getConstraint(args.alphaConstraint);
if (args.sharedAxes == null) {
_this3.sharedAxes = null;
} else if (Array.isArray(args.sharedAxes)) {
_this3.sharedAxes = args.sharedAxes;
} else if (typeof args.sharedAxes === 'number') {
_this3.sharedAxes = [args.sharedAxes];
} else {
throw new ValueError("Expected sharedAxes to be a number or an array of numbers, " + ("but got " + args.sharedAxes));
}
return _this3;
}
var _proto3 = PReLU.prototype;
_proto3.build = function build(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var paramShape = inputShape.slice(1);
if (this.sharedAxes != null) {
for (var _iterator = _createForOfIteratorHelperLoose(this.sharedAxes), _step; !(_step = _iterator()).done;) {
var i = _step.value;
paramShape[i - 1] = 1;
}
}
this.alpha = this.addWeight('alpha', paramShape, 'float32', this.alphaInitializer, this.alphaRegularizer, true, this.alphaConstraint); // Set input spec.
var axes = {};
if (this.sharedAxes != null) {
for (var _i = 1; _i < inputShape.length; ++_i) {
axes[_i] = inputShape[_i];
}
}
this.inputSpec = [new InputSpec({
ndim: inputShape.length,
axes: axes
})];
this.built = true;
};
_proto3.call = function call(inputs, kwargs) {
inputs = getExactlyOneTensor(inputs);
return prelu(inputs, this.alpha.read());
};
_proto3.getConfig = function getConfig() {
var config = {
alphaInitializer: serializeInitializer(this.alphaInitializer),
alphaRegularizer: serializeRegularizer(this.alphaRegularizer),
alphaConstraint: serializeConstraint(this.alphaConstraint),
sharedAxes: this.sharedAxes
};
var baseConfig = _Layer3.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return PReLU;
}(Layer);
/** @nocollapse */
PReLU.className = 'PReLU';
registerClass(PReLU);
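/*
 * Illustrative sketch (not part of the library source): the PReLU layer via
 * `tf.layers.prelu()`. With the default 'zeros' alpha initializer it starts
 * out equivalent to ReLU, and the per-element alphas are then learned. For
 * image-like inputs, `sharedAxes: [1, 2]` shares one alpha per channel.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({units: 8, inputShape: [4]}));
 * model.add(tf.layers.prelu());
 * model.predict(tf.ones([1, 4])).print();
 * ```
 */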
var ELU = /*#__PURE__*/function (_Layer4) {
_inheritsLoose(ELU, _Layer4);
function ELU(args) {
var _this4;
_this4 = _Layer4.call(this, args == null ? {} : args) || this;
_this4.DEFAULT_ALPHA = 1.0;
if (args == null) {
args = {};
}
if (args.alpha != null && args.alpha !== _this4.DEFAULT_ALPHA) {
throw new NotImplementedError("Non-default alpha value (" + args.alpha + ") is not supported by the " + "ELU layer yet.");
}
_this4.alpha = args.alpha == null ? _this4.DEFAULT_ALPHA : args.alpha;
return _this4;
}
var _proto4 = ELU.prototype;
_proto4.call = function call(inputs, kwargs) {
var x = getExactlyOneTensor(inputs);
return elu(x);
};
_proto4.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
};
_proto4.getConfig = function getConfig() {
var config = {
alpha: this.alpha
};
var baseConfig = _Layer4.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return ELU;
}(Layer);
/** @nocollapse */
ELU.className = 'ELU';
registerClass(ELU);
var ThresholdedReLU = /*#__PURE__*/function (_Layer5) {
_inheritsLoose(ThresholdedReLU, _Layer5);
function ThresholdedReLU(args) {
var _this5;
_this5 = _Layer5.call(this, args == null ? {} : args) || this;
_this5.DEFAULT_THETA = 1.0;
if (args == null) {
args = {};
}
_this5.theta = args.theta == null ? _this5.DEFAULT_THETA : args.theta;
return _this5;
}
var _proto5 = ThresholdedReLU.prototype;
_proto5.call = function call(inputs, kwargs) {
var x = getExactlyOneTensor(inputs);
return mul(x, cast(greater(x, this.theta), 'float32'));
};
_proto5.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
};
_proto5.getConfig = function getConfig() {
var config = {
theta: this.theta
};
var baseConfig = _Layer5.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return ThresholdedReLU;
}(Layer);
/** @nocollapse */
ThresholdedReLU.className = 'ThresholdedReLU';
registerClass(ThresholdedReLU);
var Softmax$2 = /*#__PURE__*/function (_Layer6) {
_inheritsLoose(Softmax, _Layer6);
function Softmax(args) {
var _this6;
_this6 = _Layer6.call(this, args == null ? {} : args) || this;
_this6.DEFAULT_AXIS = 1.0;
if (args == null) {
args = {};
}
_this6.softmax = new Softmax$1().apply;
_this6.axis = args.axis == null ? _this6.DEFAULT_AXIS : args.axis;
return _this6;
}
var _proto6 = Softmax.prototype;
_proto6.call = function call(inputs, kwargs) {
var x = getExactlyOneTensor(inputs);
return this.softmax(x, this.axis);
};
_proto6.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
};
_proto6.getConfig = function getConfig() {
var config = {
axis: this.axis
};
var baseConfig = _Layer6.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Softmax;
}(Layer);
/** @nocollapse */
Softmax$2.className = 'Softmax';
registerClass(Softmax$2);
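/*
 * Illustrative sketch (not part of the library source): the Softmax layer via
 * `tf.layers.softmax()`, typically placed after a final dense layer to turn
 * logits into probabilities.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({units: 3, inputShape: [5]}));
 * model.add(tf.layers.softmax());
 * // Each row of the output sums to 1.
 * model.predict(tf.ones([2, 5])).print();
 * ```
 */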
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
 * Transforms a single number or an array of numbers into an array of numbers.
* @param value
* @param n: The size of the tuple to be returned.
* @param name: Name of the parameter, used for generating error messages.
* @returns An array of numbers.
*/
function normalizeArray(value, n, name) {
if (typeof value === 'number') {
return pyListRepeat(value, n);
} else {
if (value.length !== n) {
throw new ValueError("The " + name + " argument must be an integer or tuple of " + n + " integers." + (" Received: " + value.length + " elements."));
}
for (var i = 0; i < n; ++i) {
var singleValue = value[i];
if (!isInteger$1(singleValue)) {
throw new ValueError("The " + name + " argument must be an integer or tuple of " + n + (" integers. Received: " + JSON.stringify(value) + " including a") + (" non-integer number " + singleValue));
}
}
return value;
}
}
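/*
 * Worked example (comment only, not part of the library source), calling the
 * internal helper directly (it is not part of the public tf namespace):
 *
 * ```js
 * normalizeArray(3, 2, 'kernelSize');      // -> [3, 3]
 * normalizeArray([3, 5], 2, 'kernelSize'); // -> [3, 5]
 * normalizeArray([3], 2, 'kernelSize');    // throws ValueError (length mismatch)
 * ```
 */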
/**
* Determines output length of a convolution given input length.
* @param inputLength
* @param filterSize
* @param padding
* @param stride
* @param dilation: dilation rate.
*/
function convOutputLength(inputLength, filterSize, padding, stride, dilation) {
if (dilation === void 0) {
dilation = 1;
}
if (inputLength == null) {
return inputLength;
}
var dilatedFilterSize = filterSize + (filterSize - 1) * (dilation - 1);
var outputLength;
if (padding === 'same') {
outputLength = inputLength;
} else {
// VALID
outputLength = inputLength - dilatedFilterSize + 1;
}
return Math.floor((outputLength + stride - 1) / stride);
}
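/*
 * Worked example (comment only, not part of the library source), using the
 * internal helper above: for a length-10 input, filter size 3, stride 2,
 * dilation 1:
 *   dilatedFilterSize = 3 + (3 - 1) * 0 = 3
 *   'valid': outputLength = 10 - 3 + 1 = 8, floor((8 + 2 - 1) / 2) = 4
 *   'same' : outputLength = 10,             floor((10 + 2 - 1) / 2) = 5
 *
 * ```js
 * convOutputLength(10, 3, 'valid', 2); // 4
 * convOutputLength(10, 3, 'same', 2);  // 5
 * ```
 */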
function deconvLength(dimSize, strideSize, kernelSize, padding) {
if (dimSize == null) {
return null;
}
if (padding === 'valid') {
dimSize = dimSize * strideSize + max$6([kernelSize - strideSize, 0]);
} else if (padding === 'same') {
dimSize = dimSize * strideSize;
} else {
throw new ValueError("Unsupport padding mode: " + padding + ".");
}
return dimSize;
}
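/*
 * Worked example (comment only, not part of the library source): transposed-
 * convolution output length for dimSize 4, stride 2, kernel size 3:
 *   'valid': 4 * 2 + max(3 - 2, 0) = 9
 *   'same' : 4 * 2 = 8
 */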
/**
* Transpose and cast the input before the conv2d.
* @param x Input image tensor.
* @param dataFormat
*/
function preprocessConv2DInput(x, dataFormat) {
// TODO(cais): Cast type to float32 if not.
return tidy(function () {
checkDataFormat(dataFormat);
if (dataFormat === 'channelsFirst') {
return transpose(x, [0, 2, 3, 1]); // NCHW -> NHWC.
} else {
return x;
}
});
}
/**
* Transpose and cast the input before the conv3d.
* @param x Input image tensor.
* @param dataFormat
*/
function preprocessConv3DInput(x, dataFormat) {
return tidy(function () {
checkDataFormat(dataFormat);
if (dataFormat === 'channelsFirst') {
return transpose(x, [0, 2, 3, 4, 1]); // NCDHW -> NDHWC.
} else {
return x;
}
});
}
/**
* 1D-convolution with bias added.
*
* Porting Note: This function does not exist in the Python Keras backend.
 * It is exactly the same as `conv1d`, except the added `bias`.
*
* @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`.
* @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.
 * @param bias Bias, rank-1, of shape `[outDepth]`.
* @param strides
* @param padding Padding mode.
* @param dataFormat Data format.
* @param dilationRate
* @returns The result of the 1D convolution.
* @throws ValueError, if `x`, `kernel` or `bias` is not of the correct rank.
*/
function conv1dWithBias(x, kernel, bias, strides, padding, dataFormat, dilationRate) {
if (strides === void 0) {
strides = 1;
}
if (padding === void 0) {
padding = 'valid';
}
if (dilationRate === void 0) {
dilationRate = 1;
}
return tidy(function () {
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
checkDataFormat(dataFormat); // Check the ranks of x, kernel and bias.
if (x.shape.length !== 3) {
throw new ValueError("The input of a conv1dWithBias operation should be 3, but is " + (x.shape.length + " instead."));
}
if (kernel.shape.length !== 3) {
throw new ValueError("The kernel for a conv1dWithBias operation should be 3, but is " + (kernel.shape.length + " instead"));
}
if (bias != null && bias.shape.length !== 1) {
throw new ValueError("The bias for a conv1dWithBias operation should be 1, but is " + (kernel.shape.length + " instead"));
} // TODO(cais): Support CAUSAL padding mode.
if (dataFormat === 'channelsFirst') {
x = transpose(x, [0, 2, 1]); // NCW -> NWC.
}
if (padding === 'causal') {
throw new NotImplementedError('The support for CAUSAL padding mode in conv1dWithBias is not ' + 'implemented yet.');
}
var y = conv1d(x, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NWC', dilationRate);
if (bias != null) {
y = biasAdd(y, bias);
}
return y;
});
}
/**
* 1D-convolution.
*
* @param x Input tensor, rank-3, of shape `[batchSize, width, inChannels]`.
 * @param kernel Kernel, rank-3, of shape `[filterWidth, inDepth, outDepth]`.
* @param strides
* @param padding Padding mode.
* @param dataFormat Data format.
* @param dilationRate
* @returns The result of the 1D convolution.
* @throws ValueError, if `x`, `kernel` or `bias` is not of the correct rank.
*/
function conv1d$1(x, kernel, strides, padding, dataFormat, dilationRate) {
if (strides === void 0) {
strides = 1;
}
if (padding === void 0) {
padding = 'valid';
}
if (dilationRate === void 0) {
dilationRate = 1;
}
return tidy(function () {
checkDataFormat(dataFormat);
return conv1dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
});
}
/**
* 2D Convolution
* @param x
* @param kernel kernel of the convolution.
* @param strides strides array.
* @param padding padding mode. Default to 'valid'.
* @param dataFormat data format. Defaults to 'channelsLast'.
* @param dilationRate dilation rate array.
 * @returns Result of the 2D convolution.
*/
function conv2d$2(x, kernel, strides, padding, dataFormat, dilationRate) {
if (strides === void 0) {
strides = [1, 1];
}
if (padding === void 0) {
padding = 'valid';
}
return tidy(function () {
checkDataFormat(dataFormat);
return conv2dWithBiasActivation(x, kernel, null, strides, padding, dataFormat, dilationRate);
});
}
/**
* 2D Convolution with an added bias and optional activation.
* Note: This function does not exist in the Python Keras Backend. This function
* is exactly the same as `conv2d`, except the added `bias`.
*/
function conv2dWithBiasActivation(x, kernel, bias, strides, padding, dataFormat, dilationRate, activation) {
if (strides === void 0) {
strides = [1, 1];
}
if (padding === void 0) {
padding = 'valid';
}
if (activation === void 0) {
activation = null;
}
return tidy(function () {
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
checkDataFormat(dataFormat);
if (x.rank !== 3 && x.rank !== 4) {
throw new ValueError("conv2dWithBiasActivation expects input to be of rank 3 or 4, " + ("but received " + x.rank + "."));
}
if (kernel.rank !== 3 && kernel.rank !== 4) {
throw new ValueError("conv2dWithBiasActivation expects kernel to be of rank 3 or 4, " + ("but received " + x.rank + "."));
}
var y = preprocessConv2DInput(x, dataFormat);
if (padding === 'causal') {
throw new NotImplementedError('The support for CAUSAL padding mode in conv2dWithBiasActivation is not ' + 'implemented yet.');
}
y = conv2d$1({
x: y,
filter: kernel,
strides: strides,
pad: padding === 'same' ? 'same' : 'valid',
dilations: dilationRate,
dataFormat: 'NHWC',
bias: bias,
activation: activation
});
if (dataFormat === 'channelsFirst') {
y = transpose(y, [0, 3, 1, 2]);
}
return y;
});
}
/**
* 3D Convolution.
* @param x
* @param kernel kernel of the convolution.
* @param strides strides array.
* @param padding padding mode. Default to 'valid'.
* @param dataFormat data format. Defaults to 'channelsLast'.
* @param dilationRate dilation rate array.
* @returns Result of the 3D convolution.
*/
function conv3d$1(x, kernel, strides, padding, dataFormat, dilationRate) {
if (strides === void 0) {
strides = [1, 1, 1];
}
if (padding === void 0) {
padding = 'valid';
}
return tidy(function () {
checkDataFormat(dataFormat);
return conv3dWithBias(x, kernel, null, strides, padding, dataFormat, dilationRate);
});
}
/**
* 3D Convolution with an added bias.
* Note: This function does not exist in the Python Keras Backend. This function
* is exactly the same as `conv3d`, except the added `bias`.
*/
function conv3dWithBias(x, kernel, bias, strides, padding, dataFormat, dilationRate) {
if (strides === void 0) {
strides = [1, 1, 1];
}
if (padding === void 0) {
padding = 'valid';
}
return tidy(function () {
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
checkDataFormat(dataFormat);
if (x.rank !== 4 && x.rank !== 5) {
throw new ValueError("conv3dWithBias expects input to be of rank 4 or 5, but received " + (x.rank + "."));
}
if (kernel.rank !== 4 && kernel.rank !== 5) {
throw new ValueError("conv3dWithBias expects kernel to be of rank 4 or 5, but received " + (x.rank + "."));
}
var y = preprocessConv3DInput(x, dataFormat);
if (padding === 'causal') {
throw new NotImplementedError('The support for CAUSAL padding mode in conv3dWithBias is not ' + 'implemented yet.');
}
y = conv3d(y, kernel, strides, padding === 'same' ? 'same' : 'valid', 'NDHWC', dilationRate);
if (bias != null) {
y = biasAdd(y, bias);
}
if (dataFormat === 'channelsFirst') {
y = transpose(y, [0, 4, 1, 2, 3]);
}
return y;
});
}
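/*
 * Shape sketch for `conv3dWithBias` (internal API; illustrative values only).
 * With channelsLast the input is NDHWC and the kernel is
 * [kernelD, kernelH, kernelW, inChannels, outChannels]:
 *
 *   x:      [batch, 16, 16, 16, 1]
 *   kernel: [3, 3, 3, 1, 4]
 *   conv3dWithBias(x, kernel, bias, [1, 1, 1], 'same', 'channelsLast', [1, 1, 1])
 *     => [batch, 16, 16, 16, 4]
 */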
/**
* Abstract convolution layer.
*/
var BaseConv = /*#__PURE__*/function (_Layer) {
_inheritsLoose(BaseConv, _Layer);
function BaseConv(rank, args) {
var _this;
_this = _Layer.call(this, args) || this;
_this.bias = null;
_this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
_this.DEFAULT_BIAS_INITIALIZER = 'zeros';
BaseConv.verifyArgs(args);
_this.rank = rank;
assertPositiveInteger(_this.rank, 'rank');
if (_this.rank !== 1 && _this.rank !== 2 && _this.rank !== 3) {
throw new NotImplementedError("Convolution layer for rank other than 1, 2, or 3 (" + _this.rank + ") is " + "not implemented yet.");
}
_this.kernelSize = normalizeArray(args.kernelSize, rank, 'kernelSize');
_this.strides = normalizeArray(args.strides == null ? 1 : args.strides, rank, 'strides');
_this.padding = args.padding == null ? 'valid' : args.padding;
checkPaddingMode(_this.padding);
_this.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat;
checkDataFormat(_this.dataFormat);
_this.activation = getActivation(args.activation);
_this.useBias = args.useBias == null ? true : args.useBias;
_this.biasInitializer = getInitializer(args.biasInitializer || _this.DEFAULT_BIAS_INITIALIZER);
_this.biasConstraint = getConstraint(args.biasConstraint);
_this.biasRegularizer = getRegularizer(args.biasRegularizer);
_this.activityRegularizer = getRegularizer(args.activityRegularizer);
_this.dilationRate = normalizeArray(args.dilationRate == null ? 1 : args.dilationRate, rank, 'dilationRate');
if (_this.rank === 1 && Array.isArray(_this.dilationRate) && _this.dilationRate.length !== 1) {
throw new ValueError("dilationRate must be a number or an array of a single number " + "for 1D convolution, but received " + ("" + JSON.stringify(_this.dilationRate)));
} else if (_this.rank === 2) {
if (typeof _this.dilationRate === 'number') {
_this.dilationRate = [_this.dilationRate, _this.dilationRate];
} else if (_this.dilationRate.length !== 2) {
throw new ValueError("dilationRate must be a number or array of two numbers for 2D " + ("convolution, but received " + JSON.stringify(_this.dilationRate)));
}
} else if (_this.rank === 3) {
if (typeof _this.dilationRate === 'number') {
_this.dilationRate = [_this.dilationRate, _this.dilationRate, _this.dilationRate];
} else if (_this.dilationRate.length !== 3) {
throw new ValueError("dilationRate must be a number or array of three numbers for 3D " + ("convolution, but received " + JSON.stringify(_this.dilationRate)));
}
}
return _this;
}
BaseConv.verifyArgs = function verifyArgs(args) {
// Check config.kernelSize type and shape.
assert$1('kernelSize' in args, "required key 'kernelSize' not in config");
if (typeof args.kernelSize !== 'number' && !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 3)) {
throw new ValueError("BaseConv expects config.kernelSize to be number or number[] with " + ("length 1, 2, or 3, but received " + JSON.stringify(args.kernelSize) + "."));
}
};
var _proto = BaseConv.prototype;
_proto.getConfig = function getConfig() {
var config = {
kernelSize: this.kernelSize,
strides: this.strides,
padding: this.padding,
dataFormat: this.dataFormat,
dilationRate: this.dilationRate,
activation: serializeActivation(this.activation),
useBias: this.useBias,
biasInitializer: serializeInitializer(this.biasInitializer),
biasRegularizer: serializeRegularizer(this.biasRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
biasConstraint: serializeConstraint(this.biasConstraint)
};
var baseConfig = _Layer.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return BaseConv;
}(Layer);
/**
* Abstract nD convolution layer. Ancestor of convolution layers which reduce
* across channels, i.e., Conv1D and Conv2D, but not DepthwiseConv2D.
*/
var Conv = /*#__PURE__*/function (_BaseConv) {
_inheritsLoose(Conv, _BaseConv);
function Conv(rank, args) {
var _this2;
_this2 = _BaseConv.call(this, rank, args) || this;
_this2.kernel = null;
Conv.verifyArgs(args);
_this2.filters = args.filters;
assertPositiveInteger(_this2.filters, 'filters');
_this2.kernelInitializer = getInitializer(args.kernelInitializer || _this2.DEFAULT_KERNEL_INITIALIZER);
_this2.kernelConstraint = getConstraint(args.kernelConstraint);
_this2.kernelRegularizer = getRegularizer(args.kernelRegularizer);
return _this2;
}
var _proto2 = Conv.prototype;
_proto2.build = function build(inputShape) {
var _axes;
inputShape = getExactlyOneShape(inputShape);
var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
if (inputShape[channelAxis] == null) {
throw new ValueError("The channel dimension of the input should be defined. " + ("Found " + inputShape[channelAxis]));
}
var inputDim = inputShape[channelAxis];
var kernelShape = this.kernelSize.concat([inputDim, this.filters]);
this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.filters], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
this.inputSpec = [{
ndim: this.rank + 2,
axes: (_axes = {}, _axes[channelAxis] = inputDim, _axes)
}];
this.built = true;
};
_proto2.call = function call(inputs, kwargs) {
var _this3 = this;
return tidy(function () {
inputs = getExactlyOneTensor(inputs);
var outputs;
var biasValue = _this3.bias == null ? null : _this3.bias.read();
var fusedActivationName = mapActivationToFusedKernel(_this3.activation.getClassName());
if (fusedActivationName != null && _this3.rank === 2) {
outputs = conv2dWithBiasActivation(inputs, _this3.kernel.read(), biasValue, _this3.strides, _this3.padding, _this3.dataFormat, _this3.dilationRate, fusedActivationName);
} else {
if (_this3.rank === 1) {
outputs = conv1dWithBias(inputs, _this3.kernel.read(), biasValue, _this3.strides[0], _this3.padding, _this3.dataFormat, _this3.dilationRate[0]);
} else if (_this3.rank === 2) {
// TODO(cais): Move up to constructor.
outputs = conv2dWithBiasActivation(inputs, _this3.kernel.read(), biasValue, _this3.strides, _this3.padding, _this3.dataFormat, _this3.dilationRate);
} else if (_this3.rank === 3) {
outputs = conv3dWithBias(inputs, _this3.kernel.read(), biasValue, _this3.strides, _this3.padding, _this3.dataFormat, _this3.dilationRate);
} else {
throw new NotImplementedError('convolutions greater than 3D are not implemented yet.');
}
if (_this3.activation != null) {
outputs = _this3.activation.apply(outputs);
}
}
return outputs;
});
};
_proto2.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var newSpace = [];
var space = this.dataFormat === 'channelsLast' ? inputShape.slice(1, inputShape.length - 1) : inputShape.slice(2);
for (var i = 0; i < space.length; ++i) {
var newDim = convOutputLength(space[i], this.kernelSize[i], this.padding, this.strides[i], typeof this.dilationRate === 'number' ? this.dilationRate : this.dilationRate[i]);
newSpace.push(newDim);
}
var outputShape = [inputShape[0]];
if (this.dataFormat === 'channelsLast') {
outputShape = outputShape.concat(newSpace);
outputShape.push(this.filters);
} else {
outputShape.push(this.filters);
outputShape = outputShape.concat(newSpace);
}
return outputShape;
};
_proto2.getConfig = function getConfig() {
var config = {
filters: this.filters,
kernelInitializer: serializeInitializer(this.kernelInitializer),
kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
kernelConstraint: serializeConstraint(this.kernelConstraint)
};
var baseConfig = _BaseConv.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
Conv.verifyArgs = function verifyArgs(args) {
// Check config.filters type, shape, and value.
if (!('filters' in args) || typeof args.filters !== 'number' || args.filters < 1) {
throw new ValueError("Convolution layer expected config.filters to be a 'number' > 0 " + ("but got " + JSON.stringify(args.filters)));
}
};
return Conv;
}(BaseConv);
var Conv2D$1 = /*#__PURE__*/function (_Conv) {
_inheritsLoose(Conv2D, _Conv);
function Conv2D(args) {
var _this4;
_this4 = _Conv.call(this, 2, args) || this;
Conv2D.verifyArgs(args);
return _this4;
}
var _proto3 = Conv2D.prototype;
_proto3.getConfig = function getConfig() {
var config = _Conv.prototype.getConfig.call(this);
delete config['rank'];
return config;
};
Conv2D.verifyArgs = function verifyArgs(args) {
// config.kernelSize must be a number or array of numbers.
if (typeof args.kernelSize !== 'number' && !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 2)) {
throw new ValueError("Conv2D expects config.kernelSize to be number or number[] with " + ("length 1 or 2, but received " + JSON.stringify(args.kernelSize) + "."));
}
};
return Conv2D;
}(Conv);
/** @nocollapse */
Conv2D$1.className = 'Conv2D';
registerClass(Conv2D$1);
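/*
 * Usage sketch for the `Conv2D` layer above, via the public `tf.layers.conv2d()`
 * factory exported by this bundle. Option values and shapes are illustrative
 * assumptions.
 *
 *   const layer = tf.layers.conv2d({
 *     filters: 8,
 *     kernelSize: 3,           // normalized to [3, 3] internally
 *     padding: 'same',
 *     activation: 'relu',      // 'relu' can be fused into the conv kernel
 *     dataFormat: 'channelsLast'
 *   });
 *   const y = layer.apply(tf.zeros([1, 28, 28, 3])); // y.shape: [1, 28, 28, 8]
 */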
var Conv3D$1 = /*#__PURE__*/function (_Conv2) {
_inheritsLoose(Conv3D, _Conv2);
function Conv3D(args) {
var _this5;
_this5 = _Conv2.call(this, 3, args) || this;
Conv3D.verifyArgs(args);
return _this5;
}
var _proto4 = Conv3D.prototype;
_proto4.getConfig = function getConfig() {
var config = _Conv2.prototype.getConfig.call(this);
delete config['rank'];
return config;
};
Conv3D.verifyArgs = function verifyArgs(args) {
// config.kernelSize must be a number or array of numbers.
if (typeof args.kernelSize !== 'number') {
if (!(Array.isArray(args.kernelSize) && (args.kernelSize.length === 1 || args.kernelSize.length === 3))) {
throw new ValueError("Conv3D expects config.kernelSize to be number or" + (" [number, number, number], but received " + JSON.stringify(args.kernelSize) + "."));
}
}
};
return Conv3D;
}(Conv);
/** @nocollapse */
Conv3D$1.className = 'Conv3D';
registerClass(Conv3D$1);
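/*
 * Usage sketch for the `Conv3D` layer above (illustrative values, assuming the
 * public `tf.layers.conv3d()` factory).
 *
 *   const layer = tf.layers.conv3d({filters: 4, kernelSize: 3, padding: 'valid'});
 *   const y = layer.apply(tf.zeros([1, 10, 10, 10, 1])); // y.shape: [1, 8, 8, 8, 4]
 */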
var Conv2DTranspose = /*#__PURE__*/function (_Conv2D) {
_inheritsLoose(Conv2DTranspose, _Conv2D);
function Conv2DTranspose(args) {
var _this6;
_this6 = _Conv2D.call(this, args) || this;
_this6.inputSpec = [new InputSpec({
ndim: 4
})];
if (_this6.padding !== 'same' && _this6.padding !== 'valid') {
throw new ValueError("Conv2DTranspose currently supports only padding modes 'same' " + ("and 'valid', but received padding mode " + _this6.padding));
}
return _this6;
}
var _proto5 = Conv2DTranspose.prototype;
_proto5.build = function build(inputShape) {
var _axes2;
inputShape = getExactlyOneShape(inputShape);
if (inputShape.length !== 4) {
throw new ValueError('Input should have rank 4; Received input shape: ' + JSON.stringify(inputShape));
}
var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
if (inputShape[channelAxis] == null) {
throw new ValueError('The channel dimension of the inputs should be defined. ' + 'Found `None`.');
}
var inputDim = inputShape[channelAxis];
var kernelShape = this.kernelSize.concat([this.filters, inputDim]);
this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
} // Set input spec.
this.inputSpec = [new InputSpec({
ndim: 4,
axes: (_axes2 = {}, _axes2[channelAxis] = inputDim, _axes2)
})];
this.built = true;
};
_proto5.call = function call(inputs, kwargs) {
var _this7 = this;
return tidy(function () {
var input = getExactlyOneTensor(inputs);
if (input.shape.length !== 4) {
throw new ValueError("Conv2DTranspose.call() expects input tensor to be rank-4, but " + ("received a tensor of rank-" + input.shape.length));
}
var inputShape = input.shape;
var batchSize = inputShape[0];
var hAxis;
var wAxis;
if (_this7.dataFormat === 'channelsFirst') {
hAxis = 2;
wAxis = 3;
} else {
hAxis = 1;
wAxis = 2;
}
var height = inputShape[hAxis];
var width = inputShape[wAxis];
var kernelH = _this7.kernelSize[0];
var kernelW = _this7.kernelSize[1];
var strideH = _this7.strides[0];
var strideW = _this7.strides[1]; // Infer the dynamic output shape.
var outHeight = deconvLength(height, strideH, kernelH, _this7.padding);
var outWidth = deconvLength(width, strideW, kernelW, _this7.padding); // Porting Note: We don't branch based on `this.dataFormat` here,
      // because the tfjs-core function `conv2dTranspose` called below
      // always assumes channelsLast.
var outputShape = [batchSize, outHeight, outWidth, _this7.filters];
if (_this7.dataFormat !== 'channelsLast') {
input = transpose(input, [0, 2, 3, 1]);
}
var outputs = conv2dTranspose(input, _this7.kernel.read(), outputShape, _this7.strides, _this7.padding);
if (_this7.dataFormat !== 'channelsLast') {
outputs = transpose(outputs, [0, 3, 1, 2]);
}
if (_this7.bias != null) {
outputs = biasAdd(outputs, _this7.bias.read(), _this7.dataFormat);
}
if (_this7.activation != null) {
outputs = _this7.activation.apply(outputs);
}
return outputs;
});
};
_proto5.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var outputShape = inputShape.slice();
var channelAxis;
var heightAxis;
var widthAxis;
if (this.dataFormat === 'channelsFirst') {
channelAxis = 1;
heightAxis = 2;
widthAxis = 3;
} else {
channelAxis = 3;
heightAxis = 1;
widthAxis = 2;
}
var kernelH = this.kernelSize[0];
var kernelW = this.kernelSize[1];
var strideH = this.strides[0];
var strideW = this.strides[1];
outputShape[channelAxis] = this.filters;
outputShape[heightAxis] = deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding);
outputShape[widthAxis] = deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding);
return outputShape;
};
_proto5.getConfig = function getConfig() {
var config = _Conv2D.prototype.getConfig.call(this);
delete config['dilationRate'];
return config;
};
return Conv2DTranspose;
}(Conv2D$1);
/** @nocollapse */
Conv2DTranspose.className = 'Conv2DTranspose';
registerClass(Conv2DTranspose);
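/*
 * Usage sketch for `Conv2DTranspose` (illustrative values, assuming the public
 * `tf.layers.conv2dTranspose()` factory). With padding 'same' and stride 2 the
 * spatial dimensions are doubled, per `deconvLength` above.
 *
 *   const layer = tf.layers.conv2dTranspose({filters: 8, kernelSize: 3, strides: 2, padding: 'same'});
 *   const y = layer.apply(tf.zeros([1, 14, 14, 32])); // y.shape: [1, 28, 28, 8]
 */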
var Conv3DTranspose = /*#__PURE__*/function (_Conv3D) {
_inheritsLoose(Conv3DTranspose, _Conv3D);
function Conv3DTranspose(args) {
var _this8;
_this8 = _Conv3D.call(this, args) || this;
_this8.inputSpec = [new InputSpec({
ndim: 5
})];
if (_this8.padding !== 'same' && _this8.padding !== 'valid') {
throw new ValueError("Conv3DTranspose currently supports only padding modes 'same' " + ("and 'valid', but received padding mode " + _this8.padding));
}
return _this8;
}
var _proto6 = Conv3DTranspose.prototype;
_proto6.build = function build(inputShape) {
var _axes3;
inputShape = getExactlyOneShape(inputShape);
if (inputShape.length !== 5) {
throw new ValueError('Input should have rank 5; Received input shape: ' + JSON.stringify(inputShape));
}
var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
if (inputShape[channelAxis] == null) {
throw new ValueError('The channel dimension of the inputs should be defined. ' + 'Found `None`.');
}
var inputDim = inputShape[channelAxis];
var kernelShape = this.kernelSize.concat([this.filters, inputDim]);
this.kernel = this.addWeight('kernel', kernelShape, 'float32', this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
} // Set input spec.
this.inputSpec = [new InputSpec({
ndim: 5,
axes: (_axes3 = {}, _axes3[channelAxis] = inputDim, _axes3)
})];
this.built = true;
};
_proto6.call = function call(inputs, kwargs) {
var _this9 = this;
return tidy(function () {
var input = getExactlyOneTensor(inputs);
if (input.shape.length !== 5) {
throw new ValueError("Conv3DTranspose.call() expects input tensor to be rank-4, but " + ("received a tensor of rank-" + input.shape.length));
}
var inputShape = input.shape;
var batchSize = inputShape[0];
var hAxis;
var wAxis;
var dAxis;
if (_this9.dataFormat === 'channelsFirst') {
dAxis = 2;
hAxis = 3;
wAxis = 4;
} else {
dAxis = 1;
hAxis = 2;
wAxis = 3;
}
var depth = inputShape[dAxis];
var height = inputShape[hAxis];
var width = inputShape[wAxis];
var kernelD = _this9.kernelSize[0];
var kernelH = _this9.kernelSize[1];
var kernelW = _this9.kernelSize[2];
var strideD = _this9.strides[0];
var strideH = _this9.strides[1];
var strideW = _this9.strides[2]; // Infer the dynamic output shape.
var outDepth = deconvLength(depth, strideD, kernelD, _this9.padding);
var outHeight = deconvLength(height, strideH, kernelH, _this9.padding);
var outWidth = deconvLength(width, strideW, kernelW, _this9.padding); // Same as `conv2dTranspose`: we always assume channelsLast.
var outputShape = [batchSize, outDepth, outHeight, outWidth, _this9.filters];
if (_this9.dataFormat !== 'channelsLast') {
input = transpose(input, [0, 2, 3, 4, 1]);
}
var outputs = conv3dTranspose(input, _this9.kernel.read(), outputShape, _this9.strides, _this9.padding);
if (_this9.dataFormat !== 'channelsLast') {
outputs = transpose(outputs, [0, 4, 1, 2, 3]);
}
if (_this9.bias !== null) {
outputs = biasAdd(outputs, _this9.bias.read(), _this9.dataFormat);
}
if (_this9.activation !== null) {
outputs = _this9.activation.apply(outputs);
}
return outputs;
});
};
_proto6.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var outputShape = inputShape.slice();
var channelAxis;
var depthAxis;
var heightAxis;
var widthAxis;
if (this.dataFormat === 'channelsFirst') {
channelAxis = 1;
depthAxis = 2;
heightAxis = 3;
widthAxis = 4;
} else {
channelAxis = 4;
depthAxis = 1;
heightAxis = 2;
widthAxis = 3;
}
var kernelD = this.kernelSize[0];
var kernelH = this.kernelSize[1];
var kernelW = this.kernelSize[2];
var strideD = this.strides[0];
var strideH = this.strides[1];
var strideW = this.strides[2];
outputShape[channelAxis] = this.filters;
outputShape[depthAxis] = deconvLength(outputShape[depthAxis], strideD, kernelD, this.padding);
outputShape[heightAxis] = deconvLength(outputShape[heightAxis], strideH, kernelH, this.padding);
outputShape[widthAxis] = deconvLength(outputShape[widthAxis], strideW, kernelW, this.padding);
return outputShape;
};
_proto6.getConfig = function getConfig() {
var config = _Conv3D.prototype.getConfig.call(this);
delete config['dilationRate'];
return config;
};
return Conv3DTranspose;
}(Conv3D$1);
/** @nocollapse */
Conv3DTranspose.className = 'Conv3DTranspose';
registerClass(Conv3DTranspose);
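/*
 * Usage sketch for `Conv3DTranspose` (illustrative values; assumes this bundle
 * exposes a `tf.layers.conv3dTranspose()` factory for the class above).
 *
 *   const layer = tf.layers.conv3dTranspose({filters: 4, kernelSize: 2, strides: 2, padding: 'same'});
 *   const y = layer.apply(tf.zeros([1, 4, 4, 4, 8])); // y.shape: [1, 8, 8, 8, 4]
 */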
var SeparableConv = /*#__PURE__*/function (_Conv3) {
_inheritsLoose(SeparableConv, _Conv3);
function SeparableConv(rank, config) {
var _this10;
_this10 = _Conv3.call(this, rank, config) || this;
_this10.DEFAULT_DEPTHWISE_INITIALIZER = 'glorotUniform';
_this10.DEFAULT_POINTWISE_INITIALIZER = 'glorotUniform';
_this10.depthwiseKernel = null;
_this10.pointwiseKernel = null;
if (config.filters == null) {
throw new ValueError('The `filters` configuration field is required by SeparableConv, ' + 'but is unspecified.');
}
if (config.kernelInitializer != null || config.kernelRegularizer != null || config.kernelConstraint != null) {
throw new ValueError('Fields kernelInitializer, kernelRegularizer and kernelConstraint ' + 'are invalid for SeparableConv2D. Use depthwiseInitializer, ' + 'depthwiseRegularizer, depthwiseConstraint, pointwiseInitializer, ' + 'pointwiseRegularizer and pointwiseConstraint instead.');
}
if (config.padding != null && config.padding !== 'same' && config.padding !== 'valid') {
throw new ValueError("SeparableConv" + _this10.rank + "D supports only padding modes: " + ("'same' and 'valid', but received " + JSON.stringify(config.padding)));
}
_this10.depthMultiplier = config.depthMultiplier == null ? 1 : config.depthMultiplier;
_this10.depthwiseInitializer = getInitializer(config.depthwiseInitializer || _this10.DEFAULT_DEPTHWISE_INITIALIZER);
_this10.depthwiseRegularizer = getRegularizer(config.depthwiseRegularizer);
_this10.depthwiseConstraint = getConstraint(config.depthwiseConstraint);
_this10.pointwiseInitializer = getInitializer(config.pointwiseInitializer || _this10.DEFAULT_POINTWISE_INITIALIZER);
_this10.pointwiseRegularizer = getRegularizer(config.pointwiseRegularizer);
_this10.pointwiseConstraint = getConstraint(config.pointwiseConstraint);
return _this10;
}
var _proto7 = SeparableConv.prototype;
_proto7.build = function build(inputShape) {
var _axes4;
inputShape = getExactlyOneShape(inputShape);
if (inputShape.length < this.rank + 2) {
throw new ValueError("Inputs to SeparableConv" + this.rank + "D should have rank " + (this.rank + 2 + ", but received input shape: ") + ("" + JSON.stringify(inputShape)));
}
var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
throw new ValueError("The channel dimension of the inputs should be defined, " + ("but found " + JSON.stringify(inputShape[channelAxis])));
}
var inputDim = inputShape[channelAxis];
var depthwiseKernelShape = this.kernelSize.concat([inputDim, this.depthMultiplier]);
var pointwiseKernelShape = [];
for (var i = 0; i < this.rank; ++i) {
pointwiseKernelShape.push(1);
}
pointwiseKernelShape.push(inputDim * this.depthMultiplier, this.filters);
var trainable = true;
this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, 'float32', this.depthwiseInitializer, this.depthwiseRegularizer, trainable, this.depthwiseConstraint);
this.pointwiseKernel = this.addWeight('pointwise_kernel', pointwiseKernelShape, 'float32', this.pointwiseInitializer, this.pointwiseRegularizer, trainable, this.pointwiseConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.filters], 'float32', this.biasInitializer, this.biasRegularizer, trainable, this.biasConstraint);
} else {
this.bias = null;
}
this.inputSpec = [new InputSpec({
ndim: this.rank + 2,
axes: (_axes4 = {}, _axes4[channelAxis] = inputDim, _axes4)
})];
this.built = true;
};
_proto7.call = function call(inputs, kwargs) {
var _this11 = this;
return tidy(function () {
inputs = getExactlyOneTensor(inputs);
var output;
if (_this11.rank === 1) {
throw new NotImplementedError('1D separable convolution is not implemented yet.');
} else if (_this11.rank === 2) {
if (_this11.dataFormat === 'channelsFirst') {
inputs = transpose(inputs, [0, 2, 3, 1]); // NCHW -> NHWC.
}
output = separableConv2d(inputs, _this11.depthwiseKernel.read(), _this11.pointwiseKernel.read(), _this11.strides, _this11.padding, _this11.dilationRate, 'NHWC');
}
if (_this11.useBias) {
output = biasAdd(output, _this11.bias.read(), _this11.dataFormat);
}
if (_this11.activation != null) {
output = _this11.activation.apply(output);
}
if (_this11.dataFormat === 'channelsFirst') {
output = transpose(output, [0, 3, 1, 2]); // NHWC -> NCHW.
}
return output;
});
};
_proto7.getConfig = function getConfig() {
var config = _Conv3.prototype.getConfig.call(this);
delete config['rank'];
delete config['kernelInitializer'];
delete config['kernelRegularizer'];
delete config['kernelConstraint'];
config['depthwiseInitializer'] = serializeInitializer(this.depthwiseInitializer);
config['pointwiseInitializer'] = serializeInitializer(this.pointwiseInitializer);
config['depthwiseRegularizer'] = serializeRegularizer(this.depthwiseRegularizer);
config['pointwiseRegularizer'] = serializeRegularizer(this.pointwiseRegularizer);
config['depthwiseConstraint'] = serializeConstraint(this.depthwiseConstraint);
config['pointwiseConstraint'] = serializeConstraint(this.pointwiseConstraint);
return config;
};
return SeparableConv;
}(Conv);
/** @nocollapse */
SeparableConv.className = 'SeparableConv';
var SeparableConv2D = /*#__PURE__*/function (_SeparableConv) {
_inheritsLoose(SeparableConv2D, _SeparableConv);
function SeparableConv2D(args) {
return _SeparableConv.call(this, 2, args) || this;
}
return SeparableConv2D;
}(SeparableConv);
/** @nocollapse */
SeparableConv2D.className = 'SeparableConv2D';
registerClass(SeparableConv2D);
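/*
 * Usage sketch for `SeparableConv2D` (illustrative values, assuming the public
 * `tf.layers.separableConv2d()` factory). A depthwise kernel of shape
 * [kH, kW, inChannels, depthMultiplier] is followed by a 1x1 pointwise kernel.
 *
 *   const layer = tf.layers.separableConv2d({filters: 16, kernelSize: 3, depthMultiplier: 1});
 *   const y = layer.apply(tf.zeros([1, 32, 32, 3])); // y.shape: [1, 30, 30, 16]
 */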
var Conv1D = /*#__PURE__*/function (_Conv4) {
_inheritsLoose(Conv1D, _Conv4);
function Conv1D(args) {
var _this12;
_this12 = _Conv4.call(this, 1, args) || this;
Conv1D.verifyArgs(args);
_this12.inputSpec = [{
ndim: 3
}];
return _this12;
}
var _proto8 = Conv1D.prototype;
_proto8.getConfig = function getConfig() {
var config = _Conv4.prototype.getConfig.call(this);
delete config['rank'];
delete config['dataFormat'];
return config;
};
Conv1D.verifyArgs = function verifyArgs(args) {
// config.kernelSize must be a number or array of numbers.
if (typeof args.kernelSize !== 'number' && !checkArrayTypeAndLength(args.kernelSize, 'number', 1, 1)) {
throw new ValueError("Conv1D expects config.kernelSize to be number or number[] with " + ("length 1, but received " + JSON.stringify(args.kernelSize) + "."));
}
};
return Conv1D;
}(Conv);
/** @nocollapse */
Conv1D.className = 'Conv1D';
registerClass(Conv1D);
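/*
 * Usage sketch for `Conv1D` (illustrative values, assuming the public
 * `tf.layers.conv1d()` factory). Inputs are rank-3: [batch, steps, channels].
 *
 *   const layer = tf.layers.conv1d({filters: 8, kernelSize: 5});
 *   const y = layer.apply(tf.zeros([1, 100, 3])); // y.shape: [1, 96, 8]
 */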
var Cropping2D = /*#__PURE__*/function (_Layer2) {
_inheritsLoose(Cropping2D, _Layer2);
function Cropping2D(args) {
var _this13;
_this13 = _Layer2.call(this, args) || this;
if (typeof args.cropping === 'number') {
_this13.cropping = [[args.cropping, args.cropping], [args.cropping, args.cropping]];
} else if (typeof args.cropping[0] === 'number') {
_this13.cropping = [[args.cropping[0], args.cropping[0]], [args.cropping[1], args.cropping[1]]];
} else {
_this13.cropping = args.cropping;
}
_this13.dataFormat = args.dataFormat === undefined ? 'channelsLast' : args.dataFormat;
_this13.inputSpec = [{
ndim: 4
}];
return _this13;
}
var _proto9 = Cropping2D.prototype;
_proto9.computeOutputShape = function computeOutputShape(inputShape) {
if (this.dataFormat === 'channelsFirst') {
return [inputShape[0], inputShape[1], inputShape[2] - this.cropping[0][0] - this.cropping[0][1], inputShape[3] - this.cropping[1][0] - this.cropping[1][1]];
} else {
return [inputShape[0], inputShape[1] - this.cropping[0][0] - this.cropping[0][1], inputShape[2] - this.cropping[1][0] - this.cropping[1][1], inputShape[3]];
}
};
_proto9.call = function call(inputs, kwargs) {
var _this14 = this;
return tidy(function () {
inputs = getExactlyOneTensor(inputs);
if (_this14.dataFormat === 'channelsLast') {
var hSliced = sliceAlongAxis(inputs, _this14.cropping[0][0], inputs.shape[1] - _this14.cropping[0][0] - _this14.cropping[0][1], 2);
return sliceAlongAxis(hSliced, _this14.cropping[1][0], inputs.shape[2] - _this14.cropping[1][1] - _this14.cropping[1][0], 3);
} else {
var _hSliced = sliceAlongAxis(inputs, _this14.cropping[0][0], inputs.shape[2] - _this14.cropping[0][0] - _this14.cropping[0][1], 3);
return sliceAlongAxis(_hSliced, _this14.cropping[1][0], inputs.shape[3] - _this14.cropping[1][1] - _this14.cropping[1][0], 4);
}
});
};
_proto9.getConfig = function getConfig() {
var config = {
cropping: this.cropping,
dataFormat: this.dataFormat
};
var baseConfig = _Layer2.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Cropping2D;
}(Layer);
/** @nocollapse */
Cropping2D.className = 'Cropping2D';
registerClass(Cropping2D);
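/*
 * Usage sketch for `Cropping2D` (illustrative values, assuming the public
 * `tf.layers.cropping2D()` factory). `cropping` is [[top, bottom], [left, right]].
 *
 *   const layer = tf.layers.cropping2D({cropping: [[1, 1], [2, 2]]});
 *   const y = layer.apply(tf.zeros([1, 28, 28, 3])); // y.shape: [1, 26, 24, 3]
 */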
var UpSampling2D = /*#__PURE__*/function (_Layer3) {
_inheritsLoose(UpSampling2D, _Layer3);
function UpSampling2D(args) {
var _this15;
_this15 = _Layer3.call(this, args) || this;
_this15.DEFAULT_SIZE = [2, 2];
_this15.inputSpec = [{
ndim: 4
}];
_this15.size = args.size == null ? _this15.DEFAULT_SIZE : args.size;
_this15.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat;
checkDataFormat(_this15.dataFormat);
_this15.interpolation = args.interpolation == null ? 'nearest' : args.interpolation;
checkInterpolationFormat(_this15.interpolation);
return _this15;
}
var _proto10 = UpSampling2D.prototype;
_proto10.computeOutputShape = function computeOutputShape(inputShape) {
if (this.dataFormat === 'channelsFirst') {
var height = inputShape[2] == null ? null : this.size[0] * inputShape[2];
var width = inputShape[3] == null ? null : this.size[1] * inputShape[3];
return [inputShape[0], inputShape[1], height, width];
} else {
var _height = inputShape[1] == null ? null : this.size[0] * inputShape[1];
var _width = inputShape[2] == null ? null : this.size[1] * inputShape[2];
return [inputShape[0], _height, _width, inputShape[3]];
}
};
_proto10.call = function call(inputs, kwargs) {
var _this16 = this;
return tidy(function () {
var input = getExactlyOneTensor(inputs);
var inputShape = input.shape;
if (_this16.dataFormat === 'channelsFirst') {
input = transpose(input, [0, 2, 3, 1]);
var height = _this16.size[0] * inputShape[2];
var width = _this16.size[1] * inputShape[3];
var resized = _this16.interpolation === 'nearest' ? image.resizeNearestNeighbor(input, [height, width]) : image.resizeBilinear(input, [height, width]);
return transpose(resized, [0, 3, 1, 2]);
} else {
var _height2 = _this16.size[0] * inputShape[1];
var _width2 = _this16.size[1] * inputShape[2];
return _this16.interpolation === 'nearest' ? image.resizeNearestNeighbor(input, [_height2, _width2]) : image.resizeBilinear(input, [_height2, _width2]);
}
});
};
_proto10.getConfig = function getConfig() {
var config = {
size: this.size,
dataFormat: this.dataFormat
};
var baseConfig = _Layer3.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return UpSampling2D;
}(Layer);
/** @nocollapse */
UpSampling2D.className = 'UpSampling2D';
registerClass(UpSampling2D);
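/*
 * Usage sketch for `UpSampling2D` (illustrative values, assuming the public
 * `tf.layers.upSampling2d()` factory). Each spatial dimension is scaled by the
 * corresponding `size` factor, using nearest-neighbor or bilinear resizing.
 *
 *   const layer = tf.layers.upSampling2d({size: [2, 2], interpolation: 'nearest'});
 *   const y = layer.apply(tf.zeros([1, 14, 14, 3])); // y.shape: [1, 28, 28, 3]
 */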
/**
 * 2D depthwise convolution.
* @param x Input tensor.
* @param depthwiseKernel Convolution kernel for depthwise convolution.
* @param strides Strides (Array of two integers).
 * @param padding Padding mode.
* @param dataFormat Data format.
* @param dilationRate Array of two integers, dilation rates for the separable
* convolution.
* @returns Output tensor.
* @throws ValueError If depthwiseKernel is not a 4D array.
*/
function depthwiseConv2d$2(x, depthwiseKernel, strides, padding, dataFormat, dilationRate) {
if (strides === void 0) {
strides = [1, 1];
}
if (padding === void 0) {
padding = 'valid';
}
return tidy(function () {
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
checkDataFormat(dataFormat);
var y = preprocessConv2DInput(x, dataFormat);
if (x.rank !== 4) {
throw new ValueError("Input for depthwiseConv2d is required to be 4-D, but is instead " + (x.rank + "-D"));
}
if (depthwiseKernel.rank !== 4) {
throw new ValueError("depthwiseKernel is required to be 4-D, but is instead " + (depthwiseKernel.rank + "-D"));
}
y = depthwiseConv2d(y, depthwiseKernel, strides, padding === 'same' ? 'same' : 'valid', 'NHWC', dilationRate);
if (dataFormat === 'channelsFirst') {
y = transpose(y, [0, 3, 1, 2]);
}
return y;
});
}
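/*
 * Shape sketch for the `depthwiseConv2d` backend helper above (internal API;
 * illustrative values). The depthwise kernel is
 * [kernelH, kernelW, inChannels, depthMultiplier] and the output has
 * inChannels * depthMultiplier channels:
 *
 *   x:               [batch, 28, 28, 3]
 *   depthwiseKernel: [3, 3, 3, 2]
 *   depthwiseConv2d(x, depthwiseKernel, [1, 1], 'same', 'channelsLast', [1, 1])
 *     => [batch, 28, 28, 6]
 */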
var DepthwiseConv2D = /*#__PURE__*/function (_BaseConv) {
_inheritsLoose(DepthwiseConv2D, _BaseConv);
function DepthwiseConv2D(args) {
var _this;
_this = _BaseConv.call(this, 2, args) || this;
_this.depthwiseKernel = null;
_this.depthMultiplier = args.depthMultiplier == null ? 1 : args.depthMultiplier;
_this.depthwiseInitializer = getInitializer(args.depthwiseInitializer || _this.DEFAULT_KERNEL_INITIALIZER);
_this.depthwiseConstraint = getConstraint(args.depthwiseConstraint);
_this.depthwiseRegularizer = getRegularizer(args.depthwiseRegularizer);
return _this;
}
var _proto = DepthwiseConv2D.prototype;
_proto.build = function build(inputShape) {
inputShape = getExactlyOneShape(inputShape);
if (inputShape.length < 4) {
throw new ValueError("Inputs to DepthwiseConv2D should have rank 4. " + ("Received input shape: " + JSON.stringify(inputShape) + "."));
}
var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : 3;
if (inputShape[channelAxis] == null || inputShape[channelAxis] < 0) {
throw new ValueError('The channel dimension of the inputs to DepthwiseConv2D should ' + ("be defined, but is not (" + inputShape[channelAxis] + ")."));
}
var inputDim = inputShape[channelAxis];
var depthwiseKernelShape = [this.kernelSize[0], this.kernelSize[1], inputDim, this.depthMultiplier];
this.depthwiseKernel = this.addWeight('depthwise_kernel', depthwiseKernelShape, null, this.depthwiseInitializer, this.depthwiseRegularizer, true, this.depthwiseConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [inputDim * this.depthMultiplier], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
} else {
this.bias = null;
}
this.built = true;
};
_proto.call = function call(inputs, kwargs) {
var _this2 = this;
return tidy(function () {
inputs = getExactlyOneTensor(inputs);
var outputs = depthwiseConv2d$2(inputs, _this2.depthwiseKernel.read(), _this2.strides, _this2.padding, _this2.dataFormat, null); // TODO(cais): Add support for dilation.
if (_this2.useBias) {
outputs = biasAdd(outputs, _this2.bias.read(), _this2.dataFormat);
}
if (_this2.activation != null) {
outputs = _this2.activation.apply(outputs);
}
return outputs;
});
};
_proto.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
var cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
var outFilters = this.dataFormat === 'channelsFirst' ? inputShape[1] * this.depthMultiplier : inputShape[3] * this.depthMultiplier;
var outRows = convOutputLength(rows, this.kernelSize[0], this.padding, this.strides[0]);
var outCols = convOutputLength(cols, this.kernelSize[1], this.padding, this.strides[1]);
if (this.dataFormat === 'channelsFirst') {
return [inputShape[0], outFilters, outRows, outCols];
} else {
// In this case, assume 'channelsLast'.
return [inputShape[0], outRows, outCols, outFilters];
}
};
_proto.getConfig = function getConfig() {
var config = _BaseConv.prototype.getConfig.call(this);
config['depthMultiplier'] = this.depthMultiplier;
config['depthwiseInitializer'] = serializeInitializer(this.depthwiseInitializer);
config['depthwiseRegularizer'] = serializeRegularizer(this.depthwiseRegularizer);
config['depthwiseConstraint'] = serializeConstraint(this.depthwiseConstraint);
return config;
};
return DepthwiseConv2D;
}(BaseConv);
/** @nocollapse */
DepthwiseConv2D.className = 'DepthwiseConv2D';
registerClass(DepthwiseConv2D);
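/*
 * Usage sketch for the `DepthwiseConv2D` layer (illustrative values, assuming
 * the public `tf.layers.depthwiseConv2d()` factory).
 *
 *   const layer = tf.layers.depthwiseConv2d({kernelSize: 3, depthMultiplier: 2});
 *   const y = layer.apply(tf.zeros([1, 28, 28, 3])); // y.shape: [1, 26, 26, 6]
 */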
/**
* Standardize `apply()` args to a single list of tensor inputs.
*
* When running a model loaded from file, the input tensors `initialState` and
* `constants` are passed to `RNN.apply()` as part of `inputs` instead of the
* dedicated kwargs fields. `inputs` consists of
* `[inputs, initialState0, initialState1, ..., constant0, constant1]` in this
* case.
* This method makes sure that arguments are
* separated and that `initialState` and `constants` are `Array`s of tensors
 * (or `null`).
*
* @param inputs Tensor or `Array` of tensors.
* @param initialState Tensor or `Array` of tensors or `null`/`undefined`.
* @param constants Tensor or `Array` of tensors or `null`/`undefined`.
* @returns An object consisting of
* inputs: A tensor.
* initialState: `Array` of tensors or `null`.
* constants: `Array` of tensors or `null`.
* @throws ValueError, if `inputs` is an `Array` but either `initialState` or
* `constants` is provided.
*/
function standardizeArgs(inputs, initialState, constants, numConstants) {
if (Array.isArray(inputs)) {
if (initialState != null || constants != null) {
throw new ValueError('When inputs is an array, neither initialState nor constants ' + 'should be provided');
}
if (numConstants != null) {
constants = inputs.slice(inputs.length - numConstants, inputs.length);
inputs = inputs.slice(0, inputs.length - numConstants);
}
if (inputs.length > 1) {
initialState = inputs.slice(1, inputs.length);
}
inputs = inputs[0];
}
function toListOrNull(x) {
if (x == null || Array.isArray(x)) {
return x;
} else {
return [x];
}
}
initialState = toListOrNull(initialState);
constants = toListOrNull(constants);
return {
inputs: inputs,
initialState: initialState,
constants: constants
};
}
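/*
 * Behavior sketch for `standardizeArgs` (illustrative tensors x, s0, s1, c0):
 *
 *   standardizeArgs([x, s0, s1, c0], null, null, 1)
 *     => {inputs: x, initialState: [s0, s1], constants: [c0]}
 *   standardizeArgs(x, s0, null, null)
 *     => {inputs: x, initialState: [s0], constants: null}
 */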
/**
* Iterates over the time dimension of a tensor.
*
* @param stepFunction RNN step function.
* Parameters:
* inputs: tensor with shape `[samples, ...]` (no time dimension),
* representing input for the batch of samples at a certain time step.
* states: an Array of tensors.
* Returns:
* outputs: tensor with shape `[samples, outputDim]` (no time dimension).
* newStates: list of tensors, same length and shapes as `states`. The first
* state in the list must be the output tensor at the previous timestep.
* @param inputs Tensor of temporal data of shape `[samples, time, ...]` (at
* least 3D).
* @param initialStates Tensor with shape `[samples, outputDim]` (no time
* dimension), containing the initial values of the states used in the step
* function.
* @param goBackwards If `true`, do the iteration over the time dimension in
* reverse order and return the reversed sequence.
 * @param mask Binary tensor with shape `[samples, time, 1]`, with a zero for
* every element that is masked.
* @param constants An Array of constant values passed at each step.
* @param unroll Whether to unroll the RNN or to use a symbolic loop. *Not*
* applicable to this imperative deeplearn.js backend. Its value is ignored.
* @param needPerStepOutputs Whether the per-step outputs are to be
* concatenated into a single tensor and returned (as the second return
* value). Default: `false`. This arg is included so that the relatively
* expensive concatenation of the stepwise outputs can be omitted unless
* the stepwise outputs need to be kept (e.g., for an LSTM layer of which
 * `returnSequences` is `true`).
* @returns An Array: `[lastOutput, outputs, newStates]`.
 * lastOutput: the latest output of the RNN, of shape `[samples, ...]`.
* outputs: tensor with shape `[samples, time, ...]` where each entry
* `output[s, t]` is the output of the step function at time `t` for sample
* `s`. This return value is provided if and only if the
* `needPerStepOutputs` is set as `true`. If it is set as `false`, this
* return value will be `undefined`.
* newStates: Array of tensors, latest states returned by the step function,
* of shape `(samples, ...)`.
* @throws ValueError If input dimension is less than 3.
*
* TODO(nielsene): This needs to be tidy-ed.
*/
function rnn(stepFunction, inputs, initialStates, goBackwards, mask, constants, unroll, needPerStepOutputs) {
if (goBackwards === void 0) {
goBackwards = false;
}
if (unroll === void 0) {
unroll = false;
}
if (needPerStepOutputs === void 0) {
needPerStepOutputs = false;
}
return tidy(function () {
var ndim = inputs.shape.length;
if (ndim < 3) {
throw new ValueError("Input should be at least 3D, but is " + ndim + "D.");
} // Transpose to time-major, i.e., from [batch, time, ...] to [time, batch,
// ...].
var axes = [1, 0].concat(range$1(2, ndim));
inputs = transpose(inputs, axes);
if (constants != null) {
throw new NotImplementedError('The rnn() function of the deeplearn.js backend does not support ' + 'constants yet.');
} // Porting Note: the unroll option is ignored by the imperative backend.
if (unroll) {
console.warn('Backend rnn(): the unroll = true option is not applicable to the ' + 'imperative deeplearn.js backend.');
}
if (mask != null) {
mask = cast(cast(mask, 'bool'), 'float32');
if (mask.rank === ndim - 1) {
mask = expandDims(mask, -1);
}
mask = transpose(mask, axes);
}
if (goBackwards) {
inputs = reverse(inputs, 0);
if (mask != null) {
mask = reverse(mask, 0);
}
} // Porting Note: PyKeras with TensorFlow backend uses a symbolic loop
// (tf.while_loop). But for the imperative deeplearn.js backend, we just
// use the usual TypeScript control flow to iterate over the time steps in
// the inputs.
// Porting Note: PyKeras patches a "_use_learning_phase" attribute to
// outputs.
// This is not idiomatic in TypeScript. The info regarding whether we are
// in a learning (i.e., training) phase for RNN is passed in a different
// way.
var perStepOutputs = [];
var lastOutput;
var states = initialStates;
var timeSteps = inputs.shape[0];
var perStepInputs = unstack(inputs);
var perStepMasks;
if (mask != null) {
perStepMasks = unstack(mask);
}
var _loop = function _loop(t) {
var currentInput = perStepInputs[t];
var stepOutputs = tidy(function () {
return stepFunction(currentInput, states);
});
if (mask == null) {
lastOutput = stepOutputs[0];
states = stepOutputs[1];
} else {
var maskedOutputs = tidy(function () {
var stepMask = perStepMasks[t];
var negStepMask = sub(onesLike(stepMask), stepMask); // TODO(cais): Would tfc.where() be better for performance?
var output = add$1(mul(stepOutputs[0], stepMask), mul(states[0], negStepMask));
var newStates = states.map(function (state, i) {
return add$1(mul(stepOutputs[1][i], stepMask), mul(state, negStepMask));
});
return {
output: output,
newStates: newStates
};
});
lastOutput = maskedOutputs.output;
states = maskedOutputs.newStates;
}
if (needPerStepOutputs) {
perStepOutputs.push(lastOutput);
}
};
for (var t = 0; t < timeSteps; ++t) {
_loop(t);
}
var outputs;
if (needPerStepOutputs) {
var axis = 1;
outputs = stack(perStepOutputs, axis);
}
return [lastOutput, outputs, states];
});
}
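/*
 * Usage sketch for the backend `rnn()` iterator above (internal API; the
 * tensors and sizes below are illustrative assumptions). The step function
 * maps (input [batch, inputDim], states) to [output, newStates]:
 *
 *   const batch = 2, time = 5, inputDim = 3, units = 4;
 *   const kernel = tf.randomNormal([inputDim, units]);
 *   const recurrentKernel = tf.randomNormal([units, units]);
 *   const step = (x, states) => {
 *     const h = tf.tanh(tf.add(tf.matMul(x, kernel), tf.matMul(states[0], recurrentKernel)));
 *     return [h, [h]];
 *   };
 *   const [lastOutput, , newStates] =
 *       rnn(step, tf.randomNormal([batch, time, inputDim]), [tf.zeros([batch, units])]);
 *   // lastOutput.shape: [batch, units]; newStates: [[batch, units]]
 */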
var RNN = /*#__PURE__*/function (_Layer) {
_inheritsLoose(RNN, _Layer);
function RNN(args) {
var _this;
_this = _Layer.call(this, args) || this;
var cell;
if (args.cell == null) {
throw new ValueError('cell property is missing for the constructor of RNN.');
} else if (Array.isArray(args.cell)) {
cell = new StackedRNNCells({
cells: args.cell
});
} else {
cell = args.cell;
}
if (cell.stateSize == null) {
throw new ValueError('The RNN cell should have an attribute `stateSize` (tuple of ' + 'integers, one integer per RNN state).');
}
_this.cell = cell;
_this.returnSequences = args.returnSequences == null ? false : args.returnSequences;
_this.returnState = args.returnState == null ? false : args.returnState;
_this.goBackwards = args.goBackwards == null ? false : args.goBackwards;
_this._stateful = args.stateful == null ? false : args.stateful;
_this.unroll = args.unroll == null ? false : args.unroll;
_this.supportsMasking = true;
_this.inputSpec = [new InputSpec({
ndim: 3
})];
_this.stateSpec = null;
_this.states_ = null; // TODO(cais): Add constantsSpec and numConstants.
_this.numConstants = null; // TODO(cais): Look into the use of initial_state in the kwargs of the
// constructor.
_this.keptStates = [];
return _this;
} // Porting Note: This is the equivalent of `RNN.states` property getter in
// PyKeras.
var _proto = RNN.prototype;
_proto.getStates = function getStates() {
if (this.states_ == null) {
var numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
return range$1(0, numStates).map(function (x) {
return null;
});
} else {
return this.states_;
}
} // Porting Note: This is the equivalent of the `RNN.states` property setter in
// PyKeras.
;
_proto.setStates = function setStates(states) {
this.states_ = states;
};
_proto.computeOutputShape = function computeOutputShape(inputShape) {
if (isArrayOfShapes(inputShape)) {
inputShape = inputShape[0];
}
inputShape = inputShape; // TODO(cais): Remove the casting once stacked RNN cells become supported.
var stateSize = this.cell.stateSize;
if (!Array.isArray(stateSize)) {
stateSize = [stateSize];
}
var outputDim = stateSize[0];
var outputShape;
if (this.returnSequences) {
outputShape = [inputShape[0], inputShape[1], outputDim];
} else {
outputShape = [inputShape[0], outputDim];
}
if (this.returnState) {
var stateShape = [];
for (var _iterator = _createForOfIteratorHelperLoose(stateSize), _step; !(_step = _iterator()).done;) {
var dim = _step.value;
stateShape.push([inputShape[0], dim]);
}
return [outputShape].concat(stateShape);
} else {
return outputShape;
}
};
_proto.computeMask = function computeMask(inputs, mask) {
var _this2 = this;
return tidy(function () {
if (Array.isArray(mask)) {
mask = mask[0];
}
var outputMask = _this2.returnSequences ? mask : null;
if (_this2.returnState) {
var stateMask = _this2.states.map(function (s) {
return null;
});
return [outputMask].concat(stateMask);
} else {
return outputMask;
}
});
}
/**
* Get the current state tensors of the RNN.
*
* If the state hasn't been set, return an array of `null`s of the correct
* length.
*/
;
_proto.build = function build(inputShape) {
// Note inputShape will be an Array of Shapes of initial states and
// constants if these are passed in apply().
var constantShape = null;
if (this.numConstants != null) {
throw new NotImplementedError('Constants support is not implemented in RNN yet.');
}
if (isArrayOfShapes(inputShape)) {
inputShape = inputShape[0];
}
inputShape = inputShape;
var batchSize = this.stateful ? inputShape[0] : null;
var inputDim = inputShape.slice(2);
this.inputSpec[0] = new InputSpec({
shape: [batchSize, null].concat(inputDim)
}); // Allow cell (if RNNCell Layer) to build before we set or validate
// stateSpec.
var stepInputShape = [inputShape[0]].concat(inputShape.slice(2));
if (constantShape != null) {
throw new NotImplementedError('Constants support is not implemented in RNN yet.');
} else {
this.cell.build(stepInputShape);
} // Set or validate stateSpec.
var stateSize;
if (Array.isArray(this.cell.stateSize)) {
stateSize = this.cell.stateSize;
} else {
stateSize = [this.cell.stateSize];
}
if (this.stateSpec != null) {
if (!arraysEqual(this.stateSpec.map(function (spec) {
return spec.shape[spec.shape.length - 1];
}), stateSize)) {
throw new ValueError("An initialState was passed that is not compatible with " + ("cell.stateSize. Received stateSpec=" + this.stateSpec + "; ") + ("However cell.stateSize is " + this.cell.stateSize));
}
} else {
this.stateSpec = stateSize.map(function (dim) {
return new InputSpec({
shape: [null, dim]
});
});
}
if (this.stateful) {
this.resetStates();
}
}
/**
* Reset the state tensors of the RNN.
*
* If the `states` argument is `undefined` or `null`, will set the
* state tensor(s) of the RNN to all-zero tensors of the appropriate
* shape(s).
*
* If `states` is provided, will set the state tensors of the RNN to its
* value.
*
* @param states Optional externally-provided initial states.
* @param training Whether this call is done during training. For stateful
* RNNs, this affects whether the old states are kept or discarded. In
* particular, if `training` is `true`, the old states will be kept so
 * that subsequent backpropagation through time (BPTT) may work properly.
* Else, the old states will be discarded.
*/
;
_proto.resetStates = function resetStates(states, training) {
var _this3 = this;
if (training === void 0) {
training = false;
}
tidy(function () {
if (!_this3.stateful) {
throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
}
var batchSize = _this3.inputSpec[0].shape[0];
if (batchSize == null) {
throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' + 'the batch size of your input tensors: \n' + '- If using a Sequential model, specify the batch size by ' + 'passing a `batchInputShape` option to your first layer.\n' + '- If using the functional API, specify the batch size by ' + 'passing a `batchShape` option to your Input layer.');
} // Initialize state if null.
if (_this3.states_ == null) {
if (Array.isArray(_this3.cell.stateSize)) {
_this3.states_ = _this3.cell.stateSize.map(function (dim) {
return zeros([batchSize, dim]);
});
} else {
_this3.states_ = [zeros([batchSize, _this3.cell.stateSize])];
}
} else if (states == null) {
// Dispose old state tensors.
dispose(_this3.states_); // For stateful RNNs, fully dispose kept old states.
if (_this3.keptStates != null) {
dispose(_this3.keptStates);
_this3.keptStates = [];
}
if (Array.isArray(_this3.cell.stateSize)) {
_this3.states_ = _this3.cell.stateSize.map(function (dim) {
return zeros([batchSize, dim]);
});
} else {
_this3.states_[0] = zeros([batchSize, _this3.cell.stateSize]);
}
} else {
if (!Array.isArray(states)) {
states = [states];
}
if (states.length !== _this3.states_.length) {
throw new ValueError("Layer " + _this3.name + " expects " + _this3.states_.length + " state(s), " + ("but it received " + states.length + " state value(s). Input ") + ("received: " + states));
}
if (training === true) {
// Store old state tensors for complete disposal later, i.e., during
// the next no-arg call to this method. We do not dispose the old
// states immediately because BPTT (among other things) requires
// them.
_this3.keptStates.push(_this3.states_.slice());
} else {
dispose(_this3.states_);
}
for (var index = 0; index < _this3.states_.length; ++index) {
var value = states[index];
var dim = Array.isArray(_this3.cell.stateSize) ? _this3.cell.stateSize[index] : _this3.cell.stateSize;
var expectedShape = [batchSize, dim];
if (!arraysEqual(value.shape, expectedShape)) {
throw new ValueError("State " + index + " is incompatible with layer " + _this3.name + ": " + ("expected shape=" + expectedShape + ", received shape=" + value.shape));
}
_this3.states_[index] = value;
}
}
_this3.states_ = _this3.states_.map(function (state) {
return keep(state.clone());
});
});
};
_proto.apply = function apply(inputs, kwargs) {
// TODO(cais): Figure out whether initialState is in kwargs or inputs.
var initialState = kwargs == null ? null : kwargs['initialState'];
var constants = kwargs == null ? null : kwargs['constants'];
if (kwargs == null) {
kwargs = {};
}
var standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
inputs = standardized.inputs;
initialState = standardized.initialState;
constants = standardized.constants; // If any of `initial_state` or `constants` are specified and are
// `tf.SymbolicTensor`s, then add them to the inputs and temporarily modify
// the input_spec to include them.
var additionalInputs = [];
var additionalSpecs = [];
if (initialState != null) {
kwargs['initialState'] = initialState;
additionalInputs = additionalInputs.concat(initialState);
this.stateSpec = [];
for (var _iterator2 = _createForOfIteratorHelperLoose(initialState), _step2; !(_step2 = _iterator2()).done;) {
var state = _step2.value;
this.stateSpec.push(new InputSpec({
shape: state.shape
}));
} // TODO(cais): Use the following instead.
// this.stateSpec = initialState.map(state => new InputSpec({shape:
// state.shape}));
additionalSpecs = additionalSpecs.concat(this.stateSpec);
}
if (constants != null) {
kwargs['constants'] = constants;
additionalInputs = additionalInputs.concat(constants); // TODO(cais): Add this.constantsSpec.
this.numConstants = constants.length;
}
var isTensor = additionalInputs[0] instanceof SymbolicTensor;
if (isTensor) {
// Compute full input spec, including state and constants.
var fullInput = [inputs].concat(additionalInputs);
var fullInputSpec = this.inputSpec.concat(additionalSpecs); // Perform the call with temporarily replaced inputSpec.
var originalInputSpec = this.inputSpec;
this.inputSpec = fullInputSpec;
var output = _Layer.prototype.apply.call(this, fullInput, kwargs);
this.inputSpec = originalInputSpec;
return output;
} else {
return _Layer.prototype.apply.call(this, inputs, kwargs);
}
} // tslint:disable-next-line:no-any
;
_proto.call = function call(inputs, kwargs) {
var _this4 = this;
// Input shape: `[samples, time (padded with zeros), input_dim]`.
// Note that the .build() method of subclasses **must** define
// this.inputSpec and this.stateSpec with complete input shapes.
return tidy(function () {
var mask = kwargs == null ? null : kwargs['mask'];
var training = kwargs == null ? null : kwargs['training'];
var initialState = kwargs == null ? null : kwargs['initialState'];
inputs = getExactlyOneTensor(inputs);
if (initialState == null) {
if (_this4.stateful) {
initialState = _this4.states_;
} else {
initialState = _this4.getInitialState(inputs);
}
}
var numStates = Array.isArray(_this4.cell.stateSize) ? _this4.cell.stateSize.length : 1;
if (initialState.length !== numStates) {
throw new ValueError("RNN Layer has " + numStates + " state(s) but was passed " + (initialState.length + " initial state(s)."));
}
if (_this4.unroll) {
console.warn('Ignoring unroll = true for RNN layer, due to imperative backend.');
}
var cellCallKwargs = {
training: training
}; // TODO(cais): Add support for constants.
var step = function step(inputs, states) {
// `inputs` and `states` are concatenated to form a single `Array` of
// `tf.Tensor`s as the input to `cell.call()`.
var outputs = _this4.cell.call([inputs].concat(states), cellCallKwargs); // Marshal the return value into output and new states.
return [outputs[0], outputs.slice(1)];
}; // TODO(cais): Add support for constants.
var rnnOutputs = rnn(step, inputs, initialState, _this4.goBackwards, mask, null, _this4.unroll, _this4.returnSequences);
var lastOutput = rnnOutputs[0];
var outputs = rnnOutputs[1];
var states = rnnOutputs[2];
if (_this4.stateful) {
_this4.resetStates(states, training);
}
var output = _this4.returnSequences ? outputs : lastOutput; // TODO(cais): Properly set learning phase flag.
if (_this4.returnState) {
return [output].concat(states);
} else {
return output;
}
});
};
_proto.getInitialState = function getInitialState(inputs) {
var _this5 = this;
return tidy(function () {
// Build an all-zero tensor of shape [samples, outputDim].
// [Samples, timeSteps, inputDim].
var initialState = zeros(inputs.shape); // [Samples].
initialState = sum$1(initialState, [1, 2]);
initialState = expandDims$1(initialState); // [Samples, 1].
if (Array.isArray(_this5.cell.stateSize)) {
return _this5.cell.stateSize.map(function (dim) {
return dim > 1 ? tile$1(initialState, [1, dim]) : initialState;
});
} else {
return _this5.cell.stateSize > 1 ? [tile$1(initialState, [1, _this5.cell.stateSize])] : [initialState];
}
});
};
_proto.setFastWeightInitDuringBuild = function setFastWeightInitDuringBuild(value) {
_Layer.prototype.setFastWeightInitDuringBuild.call(this, value);
if (this.cell != null) {
this.cell.setFastWeightInitDuringBuild(value);
}
};
_proto.getConfig = function getConfig() {
var baseConfig = _Layer.prototype.getConfig.call(this);
var config = {
returnSequences: this.returnSequences,
returnState: this.returnState,
goBackwards: this.goBackwards,
stateful: this.stateful,
unroll: this.unroll
};
if (this.numConstants != null) {
config['numConstants'] = this.numConstants;
}
var cellConfig = this.cell.getConfig();
if (this.getClassName() === RNN.className) {
config['cell'] = {
'className': this.cell.getClassName(),
'config': cellConfig
};
} // this order is necessary, to prevent cell name from replacing layer name
return Object.assign({}, cellConfig, baseConfig, config);
}
/** @nocollapse */
;
RNN.fromConfig = function fromConfig(cls, config, customObjects) {
if (customObjects === void 0) {
customObjects = {};
}
var cellConfig = config['cell'];
var cell = deserialize$1(cellConfig, customObjects);
return new cls(Object.assign(config, {
cell: cell
}));
};
_createClass(RNN, [{
key: "states",
get: function get() {
if (this.states_ == null) {
var numStates = Array.isArray(this.cell.stateSize) ? this.cell.stateSize.length : 1;
var output = [];
for (var i = 0; i < numStates; ++i) {
output.push(null);
}
return output;
} else {
return this.states_;
}
},
set: function set(s) {
this.states_ = s;
}
}, {
key: "trainableWeights",
get: function get() {
if (!this.trainable) {
return [];
} // Porting Note: In TypeScript, `this` is always an instance of `Layer`.
return this.cell.trainableWeights;
}
}, {
key: "nonTrainableWeights",
get: function get() {
// Porting Note: In TypeScript, `this` is always an instance of `Layer`.
if (!this.trainable) {
return this.cell.weights;
}
return this.cell.nonTrainableWeights;
}
}]);
return RNN;
}(Layer);
/** @nocollapse */
RNN.className = 'RNN';
registerClass(RNN); // Porting Note: This is a common parent class for RNN cells. There is no
// equivalent of this in PyKeras. Having a common parent class forgoes the
// need for `has_attr(cell, ...)` checks or its TypeScript equivalent.
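/*
 * Usage sketch for the `RNN` layer above (illustrative values, assuming the
 * public `tf.layers.rnn()` and `tf.layers.simpleRNNCell()` factories).
 *
 *   const cell = tf.layers.simpleRNNCell({units: 8});
 *   const layer = tf.layers.rnn({cell, returnSequences: true});
 *   const y = layer.apply(tf.zeros([2, 10, 4])); // y.shape: [2, 10, 8]
 */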
/**
* An RNNCell layer.
*
* @doc {heading: 'Layers', subheading: 'Classes'}
*/
var RNNCell = /*#__PURE__*/function (_Layer2) {
_inheritsLoose(RNNCell, _Layer2);
function RNNCell() {
return _Layer2.apply(this, arguments) || this;
}
return RNNCell;
}(Layer);
var SimpleRNNCell = /*#__PURE__*/function (_RNNCell) {
_inheritsLoose(SimpleRNNCell, _RNNCell);
function SimpleRNNCell(args) {
var _this6;
_this6 = _RNNCell.call(this, args) || this;
_this6.DEFAULT_ACTIVATION = 'tanh';
_this6.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
_this6.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
_this6.DEFAULT_BIAS_INITIALIZER = 'zeros';
_this6.units = args.units;
assertPositiveInteger(_this6.units, "units");
_this6.activation = getActivation(args.activation == null ? _this6.DEFAULT_ACTIVATION : args.activation);
_this6.useBias = args.useBias == null ? true : args.useBias;
_this6.kernelInitializer = getInitializer(args.kernelInitializer || _this6.DEFAULT_KERNEL_INITIALIZER);
_this6.recurrentInitializer = getInitializer(args.recurrentInitializer || _this6.DEFAULT_RECURRENT_INITIALIZER);
_this6.biasInitializer = getInitializer(args.biasInitializer || _this6.DEFAULT_BIAS_INITIALIZER);
_this6.kernelRegularizer = getRegularizer(args.kernelRegularizer);
_this6.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
_this6.biasRegularizer = getRegularizer(args.biasRegularizer);
_this6.kernelConstraint = getConstraint(args.kernelConstraint);
_this6.recurrentConstraint = getConstraint(args.recurrentConstraint);
_this6.biasConstraint = getConstraint(args.biasConstraint);
_this6.dropout = min$a([1, max$6([0, args.dropout == null ? 0 : args.dropout])]);
_this6.recurrentDropout = min$a([1, max$6([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])]);
_this6.stateSize = _this6.units;
_this6.dropoutMask = null;
_this6.recurrentDropoutMask = null;
return _this6;
}
var _proto2 = SimpleRNNCell.prototype;
_proto2.build = function build(inputShape) {
inputShape = getExactlyOneShape(inputShape); // TODO(cais): Use regularizer.
this.kernel = this.addWeight('kernel', [inputShape[inputShape.length - 1], this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
} else {
this.bias = null;
}
this.built = true;
} // Porting Note: PyKeras' equivalent of this method takes two tensor inputs:
// `inputs` and `states`. Here, the two tensors are combined into an
// `Tensor[]` Array as the first input argument.
// Similarly, PyKeras' equivalent of this method returns two values:
// `output` and `[output]`. Here the two are combined into one length-2
// `Tensor[]`, consisting of `output` repeated.
;
_proto2.call = function call(inputs, kwargs) {
var _this7 = this;
return tidy(function () {
inputs = inputs;
if (inputs.length !== 2) {
throw new ValueError("SimpleRNNCell expects 2 input Tensors, got " + inputs.length + ".");
}
var prevOutput = inputs[1];
inputs = inputs[0];
var training = kwargs['training'] == null ? false : kwargs['training'];
if (0 < _this7.dropout && _this7.dropout < 1 && _this7.dropoutMask == null) {
_this7.dropoutMask = generateDropoutMask({
ones: function ones() {
return onesLike(inputs);
},
rate: _this7.dropout,
training: training
});
}
if (0 < _this7.recurrentDropout && _this7.recurrentDropout < 1 && _this7.recurrentDropoutMask == null) {
_this7.recurrentDropoutMask = generateDropoutMask({
ones: function ones() {
return onesLike(prevOutput);
},
rate: _this7.recurrentDropout,
training: training
});
}
var h;
var dpMask = _this7.dropoutMask;
var recDpMask = _this7.recurrentDropoutMask;
if (dpMask != null) {
h = dot$1(mul(inputs, dpMask), _this7.kernel.read());
} else {
h = dot$1(inputs, _this7.kernel.read());
}
if (_this7.bias != null) {
h = biasAdd(h, _this7.bias.read());
}
if (recDpMask != null) {
prevOutput = mul(prevOutput, recDpMask);
}
var output = add$1(h, dot$1(prevOutput, _this7.recurrentKernel.read()));
if (_this7.activation != null) {
output = _this7.activation.apply(output);
} // TODO(cais): Properly set learning phase on output tensor?
return [output, output];
});
};
_proto2.getConfig = function getConfig() {
var baseConfig = _RNNCell.prototype.getConfig.call(this);
var config = {
units: this.units,
activation: serializeActivation(this.activation),
useBias: this.useBias,
kernelInitializer: serializeInitializer(this.kernelInitializer),
recurrentInitializer: serializeInitializer(this.recurrentInitializer),
biasInitializer: serializeInitializer(this.biasInitializer),
kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
biasRegularizer: serializeRegularizer(this.biasRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
kernelConstraint: serializeConstraint(this.kernelConstraint),
recurrentConstraint: serializeConstraint(this.recurrentConstraint),
biasConstraint: serializeConstraint(this.biasConstraint),
dropout: this.dropout,
recurrentDropout: this.recurrentDropout
};
return Object.assign({}, baseConfig, config);
};
return SimpleRNNCell;
}(RNNCell);
/** @nocollapse */
SimpleRNNCell.className = 'SimpleRNNCell';
registerClass(SimpleRNNCell);
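// For reference, the call() above implements the classic Elman recurrence
//   h_t = activation(x_t . W + b + h_{t-1} . U)
// with W = `kernel`, U = `recurrent_kernel`, b = `bias`; the same h_t is
// returned both as the cell output and as its single recurrent state.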
var SimpleRNN = /*#__PURE__*/function (_RNN) {
_inheritsLoose(SimpleRNN, _RNN);
function SimpleRNN(args) {
args.cell = new SimpleRNNCell(args);
return _RNN.call(this, args) || this; // TODO(cais): Add activityRegularizer.
}
var _proto3 = SimpleRNN.prototype;
_proto3.call = function call(inputs, kwargs) {
var _this8 = this;
return tidy(function () {
if (_this8.cell.dropoutMask != null) {
dispose(_this8.cell.dropoutMask);
_this8.cell.dropoutMask = null;
}
if (_this8.cell.recurrentDropoutMask != null) {
dispose(_this8.cell.recurrentDropoutMask);
_this8.cell.recurrentDropoutMask = null;
}
var mask = kwargs == null ? null : kwargs['mask'];
var training = kwargs == null ? null : kwargs['training'];
var initialState = kwargs == null ? null : kwargs['initialState'];
return _RNN.prototype.call.call(_this8, inputs, {
mask: mask,
training: training,
initialState: initialState
});
});
}
/** @nocollapse */
;
SimpleRNN.fromConfig = function fromConfig(cls, config) {
return new cls(config);
};
return SimpleRNN;
}(RNN);
/** @nocollapse */
SimpleRNN.className = 'SimpleRNN';
registerClass(SimpleRNN);
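// Illustrative usage sketch (assumes the `tf.layers.simpleRNN` factory that
// this bundle exports elsewhere; shapes are examples only):
//   const layer = tf.layers.simpleRNN({units: 8, returnSequences: true});
//   const y = layer.apply(tf.zeros([2, 10, 4])); // shape [2, 10, 8]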
var GRUCell = /*#__PURE__*/function (_RNNCell2) {
_inheritsLoose(GRUCell, _RNNCell2);
function GRUCell(args) {
var _this9;
_this9 = _RNNCell2.call(this, args) || this;
_this9.DEFAULT_ACTIVATION = 'tanh';
_this9.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
_this9.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
_this9.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
_this9.DEFAULT_BIAS_INITIALIZER = 'zeros';
if (args.resetAfter) {
throw new ValueError("GRUCell does not support reset_after parameter set to true.");
}
_this9.units = args.units;
assertPositiveInteger(_this9.units, 'units');
_this9.activation = getActivation(args.activation === undefined ? _this9.DEFAULT_ACTIVATION : args.activation);
_this9.recurrentActivation = getActivation(args.recurrentActivation === undefined ? _this9.DEFAULT_RECURRENT_ACTIVATION : args.recurrentActivation);
_this9.useBias = args.useBias == null ? true : args.useBias;
_this9.kernelInitializer = getInitializer(args.kernelInitializer || _this9.DEFAULT_KERNEL_INITIALIZER);
_this9.recurrentInitializer = getInitializer(args.recurrentInitializer || _this9.DEFAULT_RECURRENT_INITIALIZER);
_this9.biasInitializer = getInitializer(args.biasInitializer || _this9.DEFAULT_BIAS_INITIALIZER);
_this9.kernelRegularizer = getRegularizer(args.kernelRegularizer);
_this9.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
_this9.biasRegularizer = getRegularizer(args.biasRegularizer);
_this9.kernelConstraint = getConstraint(args.kernelConstraint);
_this9.recurrentConstraint = getConstraint(args.recurrentConstraint);
_this9.biasConstraint = getConstraint(args.biasConstraint);
_this9.dropout = min$a([1, max$6([0, args.dropout == null ? 0 : args.dropout])]);
_this9.recurrentDropout = min$a([1, max$6([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])]);
_this9.implementation = args.implementation;
_this9.stateSize = _this9.units;
_this9.dropoutMask = null;
_this9.recurrentDropoutMask = null;
return _this9;
}
var _proto4 = GRUCell.prototype;
_proto4.build = function build(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var inputDim = inputShape[inputShape.length - 1];
this.kernel = this.addWeight('kernel', [inputDim, this.units * 3], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 3], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.units * 3], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
} else {
this.bias = null;
} // Porting Notes: Unlike the PyKeras implementation, we perform slicing
// of the weights and bias in the call() method, at execution time.
this.built = true;
};
_proto4.call = function call(inputs, kwargs) {
var _this10 = this;
return tidy(function () {
inputs = inputs;
if (inputs.length !== 2) {
throw new ValueError("GRUCell expects 2 input Tensors (inputs, h, c), got " + (inputs.length + "."));
}
var training = kwargs['training'] == null ? false : kwargs['training'];
var hTMinus1 = inputs[1]; // Previous memory state.
inputs = inputs[0]; // Note: For superior performance, TensorFlow.js always uses
// implementation 2, regardless of the actual value of
// config.implementation.
if (0 < _this10.dropout && _this10.dropout < 1 && _this10.dropoutMask == null) {
_this10.dropoutMask = generateDropoutMask({
ones: function ones() {
return onesLike(inputs);
},
rate: _this10.dropout,
training: training,
count: 3
});
}
if (0 < _this10.recurrentDropout && _this10.recurrentDropout < 1 && _this10.recurrentDropoutMask == null) {
_this10.recurrentDropoutMask = generateDropoutMask({
ones: function ones() {
return onesLike(hTMinus1);
},
rate: _this10.recurrentDropout,
training: training,
count: 3
});
}
var dpMask = _this10.dropoutMask;
var recDpMask = _this10.recurrentDropoutMask;
var z;
var r;
var hh;
if (0 < _this10.dropout && _this10.dropout < 1) {
inputs = mul(inputs, dpMask[0]);
}
var matrixX = dot$1(inputs, _this10.kernel.read());
if (_this10.useBias) {
matrixX = biasAdd(matrixX, _this10.bias.read());
}
if (0 < _this10.recurrentDropout && _this10.recurrentDropout < 1) {
hTMinus1 = mul(hTMinus1, recDpMask[0]);
}
var recurrentKernelValue = _this10.recurrentKernel.read();
var _tfc$split = split$1(recurrentKernelValue, [2 * _this10.units, _this10.units], recurrentKernelValue.rank - 1),
rk1 = _tfc$split[0],
rk2 = _tfc$split[1];
var matrixInner = dot$1(hTMinus1, rk1);
var _tfc$split2 = split$1(matrixX, 3, matrixX.rank - 1),
xZ = _tfc$split2[0],
xR = _tfc$split2[1],
xH = _tfc$split2[2];
var _tfc$split3 = split$1(matrixInner, 2, matrixInner.rank - 1),
recurrentZ = _tfc$split3[0],
recurrentR = _tfc$split3[1];
z = _this10.recurrentActivation.apply(add$1(xZ, recurrentZ));
r = _this10.recurrentActivation.apply(add$1(xR, recurrentR));
var recurrentH = dot$1(mul(r, hTMinus1), rk2);
hh = _this10.activation.apply(add$1(xH, recurrentH));
var h = add$1(mul(z, hTMinus1), mul(add$1(1, neg(z)), hh)); // TODO(cais): Add use_learning_phase flag properly.
return [h, h];
});
};
_proto4.getConfig = function getConfig() {
var baseConfig = _RNNCell2.prototype.getConfig.call(this);
var config = {
units: this.units,
activation: serializeActivation(this.activation),
recurrentActivation: serializeActivation(this.recurrentActivation),
useBias: this.useBias,
kernelInitializer: serializeInitializer(this.kernelInitializer),
recurrentInitializer: serializeInitializer(this.recurrentInitializer),
biasInitializer: serializeInitializer(this.biasInitializer),
kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
biasRegularizer: serializeRegularizer(this.biasRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
kernelConstraint: serializeConstraint(this.kernelConstraint),
recurrentConstraint: serializeConstraint(this.recurrentConstraint),
biasConstraint: serializeConstraint(this.biasConstraint),
dropout: this.dropout,
recurrentDropout: this.recurrentDropout,
implementation: this.implementation,
resetAfter: false
};
return Object.assign({}, baseConfig, config);
};
return GRUCell;
}(RNNCell);
/** @nocollapse */
GRUCell.className = 'GRUCell';
registerClass(GRUCell);
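// For reference, the call() above computes (matmuls written as `.`, elementwise
// products as `*`, with W/U/b the slices of kernel/recurrentKernel/bias):
//   z_t = recurrentActivation(x_t . W_z + b_z + h_{t-1} . U_z)   // update gate
//   r_t = recurrentActivation(x_t . W_r + b_r + h_{t-1} . U_r)   // reset gate
//   hh  = activation(x_t . W_h + b_h + (r_t * h_{t-1}) . U_h)    // candidate state
//   h_t = z_t * h_{t-1} + (1 - z_t) * hh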
var GRU = /*#__PURE__*/function (_RNN2) {
_inheritsLoose(GRU, _RNN2);
function GRU(args) {
if (args.implementation === 0) {
console.warn('`implementation=0` has been deprecated, and now defaults to ' + '`implementation=1`. Please update your layer call.');
}
args.cell = new GRUCell(args);
return _RNN2.call(this, args) || this; // TODO(cais): Add activityRegularizer.
}
var _proto5 = GRU.prototype;
_proto5.call = function call(inputs, kwargs) {
var _this11 = this;
return tidy(function () {
if (_this11.cell.dropoutMask != null) {
dispose(_this11.cell.dropoutMask);
_this11.cell.dropoutMask = null;
}
if (_this11.cell.recurrentDropoutMask != null) {
dispose(_this11.cell.recurrentDropoutMask);
_this11.cell.recurrentDropoutMask = null;
}
var mask = kwargs == null ? null : kwargs['mask'];
var training = kwargs == null ? null : kwargs['training'];
var initialState = kwargs == null ? null : kwargs['initialState'];
return _RNN2.prototype.call.call(_this11, inputs, {
mask: mask,
training: training,
initialState: initialState
});
});
}
/** @nocollapse */
;
GRU.fromConfig = function fromConfig(cls, config) {
if (config['implementation'] === 0) {
config['implementation'] = 1;
}
return new cls(config);
};
return GRU;
}(RNN);
/** @nocollapse */
GRU.className = 'GRU';
registerClass(GRU);
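// Illustrative usage sketch (assumes the `tf.layers.gru` factory that this
// bundle exports elsewhere):
//   const gru = tf.layers.gru({units: 16});
//   const y = gru.apply(tf.zeros([2, 10, 4])); // shape [2, 16]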
var LSTMCell = /*#__PURE__*/function (_RNNCell3) {
_inheritsLoose(LSTMCell, _RNNCell3);
function LSTMCell(args) {
var _this12;
_this12 = _RNNCell3.call(this, args) || this;
_this12.DEFAULT_ACTIVATION = 'tanh';
_this12.DEFAULT_RECURRENT_ACTIVATION = 'hardSigmoid';
_this12.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
_this12.DEFAULT_RECURRENT_INITIALIZER = 'orthogonal';
_this12.DEFAULT_BIAS_INITIALIZER = 'zeros';
_this12.units = args.units;
assertPositiveInteger(_this12.units, 'units');
_this12.activation = getActivation(args.activation === undefined ? _this12.DEFAULT_ACTIVATION : args.activation);
_this12.recurrentActivation = getActivation(args.recurrentActivation === undefined ? _this12.DEFAULT_RECURRENT_ACTIVATION : args.recurrentActivation);
_this12.useBias = args.useBias == null ? true : args.useBias;
_this12.kernelInitializer = getInitializer(args.kernelInitializer || _this12.DEFAULT_KERNEL_INITIALIZER);
_this12.recurrentInitializer = getInitializer(args.recurrentInitializer || _this12.DEFAULT_RECURRENT_INITIALIZER);
_this12.biasInitializer = getInitializer(args.biasInitializer || _this12.DEFAULT_BIAS_INITIALIZER);
_this12.unitForgetBias = args.unitForgetBias;
_this12.kernelRegularizer = getRegularizer(args.kernelRegularizer);
_this12.recurrentRegularizer = getRegularizer(args.recurrentRegularizer);
_this12.biasRegularizer = getRegularizer(args.biasRegularizer);
_this12.kernelConstraint = getConstraint(args.kernelConstraint);
_this12.recurrentConstraint = getConstraint(args.recurrentConstraint);
_this12.biasConstraint = getConstraint(args.biasConstraint);
_this12.dropout = min$a([1, max$6([0, args.dropout == null ? 0 : args.dropout])]);
_this12.recurrentDropout = min$a([1, max$6([0, args.recurrentDropout == null ? 0 : args.recurrentDropout])]);
_this12.implementation = args.implementation;
_this12.stateSize = [_this12.units, _this12.units];
_this12.dropoutMask = null;
_this12.recurrentDropoutMask = null;
return _this12;
}
var _proto6 = LSTMCell.prototype;
_proto6.build = function build(inputShape) {
var _a;
inputShape = getExactlyOneShape(inputShape);
var inputDim = inputShape[inputShape.length - 1];
this.kernel = this.addWeight('kernel', [inputDim, this.units * 4], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
this.recurrentKernel = this.addWeight('recurrent_kernel', [this.units, this.units * 4], null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
var biasInitializer;
if (this.useBias) {
if (this.unitForgetBias) {
var capturedBiasInit = this.biasInitializer;
var capturedUnits = this.units;
biasInitializer = new (_a = /*#__PURE__*/function (_Initializer) {
_inheritsLoose(CustomInit, _Initializer);
function CustomInit() {
return _Initializer.apply(this, arguments) || this;
}
var _proto7 = CustomInit.prototype;
_proto7.apply = function apply(shape, dtype) {
// TODO(cais): More informative variable names?
var bI = capturedBiasInit.apply([capturedUnits]);
var bF = new Ones().apply([capturedUnits]);
var bCAndH = capturedBiasInit.apply([capturedUnits * 2]);
return concatAlongFirstAxis(concatAlongFirstAxis(bI, bF), bCAndH);
};
return CustomInit;
}(Initializer),
/** @nocollapse */
_a.className = 'CustomInit', _a)();
} else {
biasInitializer = this.biasInitializer;
}
this.bias = this.addWeight('bias', [this.units * 4], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
} else {
this.bias = null;
} // Porting Notes: Unlike the PyKeras implementation, we perform slicing
// of the weights and bias in the call() method, at execution time.
this.built = true;
};
_proto6.call = function call(inputs, kwargs) {
var _this13 = this;
return tidy(function () {
var training = kwargs['training'] == null ? false : kwargs['training'];
inputs = inputs;
if (inputs.length !== 3) {
throw new ValueError("LSTMCell expects 3 input Tensors (inputs, h, c), got " + (inputs.length + "."));
}
var hTMinus1 = inputs[1]; // Previous memory state.
var cTMinus1 = inputs[2]; // Previous carry state.
inputs = inputs[0];
if (0 < _this13.dropout && _this13.dropout < 1 && _this13.dropoutMask == null) {
_this13.dropoutMask = generateDropoutMask({
ones: function ones() {
return onesLike(inputs);
},
rate: _this13.dropout,
training: training,
count: 4
});
}
if (0 < _this13.recurrentDropout && _this13.recurrentDropout < 1 && _this13.recurrentDropoutMask == null) {
_this13.recurrentDropoutMask = generateDropoutMask({
ones: function ones() {
return onesLike(hTMinus1);
},
rate: _this13.recurrentDropout,
training: training,
count: 4
});
}
var dpMask = _this13.dropoutMask;
var recDpMask = _this13.recurrentDropoutMask; // Note: For superior performance, TensorFlow.js always uses
// implementation 2 regardless of the actual value of
// config.implementation.
var i;
var f;
var c;
var o;
if (0 < _this13.dropout && _this13.dropout < 1) {
inputs = mul(inputs, dpMask[0]);
}
var z = dot$1(inputs, _this13.kernel.read());
if (0 < _this13.recurrentDropout && _this13.recurrentDropout < 1) {
hTMinus1 = mul(hTMinus1, recDpMask[0]);
}
z = add$1(z, dot$1(hTMinus1, _this13.recurrentKernel.read()));
if (_this13.useBias) {
z = biasAdd(z, _this13.bias.read());
}
var _tfc$split4 = split$1(z, 4, z.rank - 1),
z0 = _tfc$split4[0],
z1 = _tfc$split4[1],
z2 = _tfc$split4[2],
z3 = _tfc$split4[3];
i = _this13.recurrentActivation.apply(z0);
f = _this13.recurrentActivation.apply(z1);
c = add$1(mul(f, cTMinus1), mul(i, _this13.activation.apply(z2)));
o = _this13.recurrentActivation.apply(z3);
var h = mul(o, _this13.activation.apply(c)); // TODO(cais): Add use_learning_phase flag properly.
return [h, h, c];
});
};
_proto6.getConfig = function getConfig() {
var baseConfig = _RNNCell3.prototype.getConfig.call(this);
var config = {
units: this.units,
activation: serializeActivation(this.activation),
recurrentActivation: serializeActivation(this.recurrentActivation),
useBias: this.useBias,
kernelInitializer: serializeInitializer(this.kernelInitializer),
recurrentInitializer: serializeInitializer(this.recurrentInitializer),
biasInitializer: serializeInitializer(this.biasInitializer),
unitForgetBias: this.unitForgetBias,
kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
recurrentRegularizer: serializeRegularizer(this.recurrentRegularizer),
biasRegularizer: serializeRegularizer(this.biasRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
kernelConstraint: serializeConstraint(this.kernelConstraint),
recurrentConstraint: serializeConstraint(this.recurrentConstraint),
biasConstraint: serializeConstraint(this.biasConstraint),
dropout: this.dropout,
recurrentDropout: this.recurrentDropout,
implementation: this.implementation
};
return Object.assign({}, baseConfig, config);
};
return LSTMCell;
}(RNNCell);
/** @nocollapse */
LSTMCell.className = 'LSTMCell';
registerClass(LSTMCell);
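// For reference, the call() above computes z = x_t . W + h_{t-1} . U + b,
// splits z into four equal chunks z0..z3, and then:
//   i_t = recurrentActivation(z0)                // input gate
//   f_t = recurrentActivation(z1)                // forget gate
//   c_t = f_t * c_{t-1} + i_t * activation(z2)   // new carry state
//   o_t = recurrentActivation(z3)                // output gate
//   h_t = o_t * activation(c_t)                  // new hidden state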
var LSTM = /*#__PURE__*/function (_RNN3) {
_inheritsLoose(LSTM, _RNN3);
function LSTM(args) {
if (args.implementation === 0) {
console.warn('`implementation=0` has been deprecated, and now defaults to ' + '`implementation=1`. Please update your layer call.');
}
args.cell = new LSTMCell(args);
return _RNN3.call(this, args) || this; // TODO(cais): Add activityRegularizer.
}
var _proto8 = LSTM.prototype;
_proto8.call = function call(inputs, kwargs) {
var _this14 = this;
return tidy(function () {
if (_this14.cell.dropoutMask != null) {
dispose(_this14.cell.dropoutMask);
_this14.cell.dropoutMask = null;
}
if (_this14.cell.recurrentDropoutMask != null) {
dispose(_this14.cell.recurrentDropoutMask);
_this14.cell.recurrentDropoutMask = null;
}
var mask = kwargs == null ? null : kwargs['mask'];
var training = kwargs == null ? null : kwargs['training'];
var initialState = kwargs == null ? null : kwargs['initialState'];
return _RNN3.prototype.call.call(_this14, inputs, {
mask: mask,
training: training,
initialState: initialState
});
});
}
/** @nocollapse */
;
LSTM.fromConfig = function fromConfig(cls, config) {
if (config['implementation'] === 0) {
config['implementation'] = 1;
}
return new cls(config);
};
return LSTM;
}(RNN);
/** @nocollapse */
LSTM.className = 'LSTM';
registerClass(LSTM);
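// Illustrative usage sketch (assumes the `tf.sequential`, `tf.layers.lstm` and
// `tf.layers.dense` factories that this bundle exports elsewhere):
//   const model = tf.sequential();
//   model.add(tf.layers.lstm({units: 32, inputShape: [10, 4]}));
//   model.add(tf.layers.dense({units: 1}));
//   const y = model.predict(tf.zeros([2, 10, 4])); // shape [2, 1]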
var StackedRNNCells = /*#__PURE__*/function (_RNNCell4) {
_inheritsLoose(StackedRNNCells, _RNNCell4);
function StackedRNNCells(args) {
var _this15;
_this15 = _RNNCell4.call(this, args) || this;
_this15.cells = args.cells;
return _this15;
}
var _proto9 = StackedRNNCells.prototype;
_proto9.call = function call(inputs, kwargs) {
var _this16 = this;
return tidy(function () {
inputs = inputs;
var states = inputs.slice(1); // Recover per-cell states.
var nestedStates = [];
for (var _iterator3 = _createForOfIteratorHelperLoose(_this16.cells.slice().reverse()), _step3; !(_step3 = _iterator3()).done;) {
var _cell = _step3.value;
if (Array.isArray(_cell.stateSize)) {
nestedStates.push(states.splice(0, _cell.stateSize.length));
} else {
nestedStates.push(states.splice(0, 1));
}
}
nestedStates.reverse(); // Call the cells in order and store the returned states.
var newNestedStates = [];
var callInputs;
for (var i = 0; i < _this16.cells.length; ++i) {
var cell = _this16.cells[i];
states = nestedStates[i]; // TODO(cais): Take care of constants.
if (i === 0) {
callInputs = [inputs[0]].concat(states);
} else {
callInputs = [callInputs[0]].concat(states);
}
callInputs = cell.call(callInputs, kwargs);
newNestedStates.push(callInputs.slice(1));
} // Format the new states as a flat list in reverse cell order.
states = [];
for (var _iterator4 = _createForOfIteratorHelperLoose(newNestedStates.slice().reverse()), _step4; !(_step4 = _iterator4()).done;) {
var _states;
var cellStates = _step4.value;
(_states = states).push.apply(_states, cellStates);
}
return [callInputs[0]].concat(states);
});
};
_proto9.build = function build(inputShape) {
if (isArrayOfShapes(inputShape)) {
// TODO(cais): Take care of input constants.
// const constantShape = inputShape.slice(1);
inputShape = inputShape[0];
}
inputShape = inputShape;
var outputDim;
this.cells.forEach(function (cell, i) {
nameScope("RNNCell_" + i, function () {
// TODO(cais): Take care of input constants.
cell.build(inputShape);
if (Array.isArray(cell.stateSize)) {
outputDim = cell.stateSize[0];
} else {
outputDim = cell.stateSize;
}
inputShape = [inputShape[0], outputDim];
});
});
this.built = true;
};
_proto9.getConfig = function getConfig() {
var baseConfig = _RNNCell4.prototype.getConfig.call(this);
var getCellConfig = function getCellConfig(cell) {
return {
'className': cell.getClassName(),
'config': cell.getConfig()
};
};
var cellConfigs = this.cells.map(getCellConfig);
var config = {
'cells': cellConfigs
};
return Object.assign({}, baseConfig, config);
}
/** @nocollapse */
;
StackedRNNCells.fromConfig = function fromConfig(cls, config, customObjects) {
if (customObjects === void 0) {
customObjects = {};
}
var cells = [];
for (var _iterator5 = _createForOfIteratorHelperLoose(config['cells']), _step5; !(_step5 = _iterator5()).done;) {
var cellConfig = _step5.value;
cells.push(deserialize$1(cellConfig, customObjects));
}
return new cls({
cells: cells
});
};
/**
* Retrieve the weights of the model.
*
* @returns A flat `Array` of `tf.Tensor`s.
*/
_proto9.getWeights = function getWeights() {
var weights = [];
for (var _iterator6 = _createForOfIteratorHelperLoose(this.cells), _step6; !(_step6 = _iterator6()).done;) {
var cell = _step6.value;
weights.push.apply(weights, cell.weights);
}
return batchGetValue(weights);
}
/**
* Set the weights of the model.
*
* @param weights An `Array` of `tf.Tensor`s with shapes and types matching
* the output of `getWeights()`.
*/
;
_proto9.setWeights = function setWeights(weights) {
var tuples = [];
for (var _iterator7 = _createForOfIteratorHelperLoose(this.cells), _step7; !(_step7 = _iterator7()).done;) {
var cell = _step7.value;
var numParams = cell.weights.length;
var inputWeights = weights.splice(numParams);
for (var i = 0; i < cell.weights.length; ++i) {
tuples.push([cell.weights[i], inputWeights[i]]);
}
}
batchSetValue(tuples);
};
_createClass(StackedRNNCells, [{
key: "stateSize",
get: function get() {
// States are a flat list in reverse order of the cell stack.
// This allows preserving the requirement `stack.stateSize[0] ===
// outputDim`. E.g., states of a 2-layer LSTM would be `[h2, c2, h1, c1]`,
// assuming one LSTM has states `[h, c]`.
var stateSize = [];
for (var _iterator8 = _createForOfIteratorHelperLoose(this.cells.slice().reverse()), _step8; !(_step8 = _iterator8()).done;) {
var cell = _step8.value;
if (Array.isArray(cell.stateSize)) {
stateSize.push.apply(stateSize, cell.stateSize);
} else {
stateSize.push(cell.stateSize);
}
}
return stateSize;
}
}, {
key: "trainableWeights",
get: function get() {
if (!this.trainable) {
return [];
}
var weights = [];
for (var _iterator9 = _createForOfIteratorHelperLoose(this.cells), _step9; !(_step9 = _iterator9()).done;) {
var cell = _step9.value;
weights.push.apply(weights, cell.trainableWeights);
}
return weights;
}
}, {
key: "nonTrainableWeights",
get: function get() {
var weights = [];
for (var _iterator10 = _createForOfIteratorHelperLoose(this.cells), _step10; !(_step10 = _iterator10()).done;) {
var _cell2 = _step10.value;
weights.push.apply(weights, _cell2.nonTrainableWeights);
}
if (!this.trainable) {
var trainableWeights = [];
for (var _iterator11 = _createForOfIteratorHelperLoose(this.cells), _step11; !(_step11 = _iterator11()).done;) {
var cell = _step11.value;
trainableWeights.push.apply(trainableWeights, cell.trainableWeights);
}
return trainableWeights.concat(weights);
}
return weights;
}
}]);
return StackedRNNCells;
}(RNNCell);
/** @nocollapse */
StackedRNNCells.className = 'StackedRNNCells';
registerClass(StackedRNNCells);
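// Illustrative usage sketch (assumes the `tf.layers.rnn` and `tf.layers.lstmCell`
// factories that this bundle exports elsewhere): passing an array of cells to
// the RNN wrapper stacks them through StackedRNNCells:
//   const rnn = tf.layers.rnn({
//     cell: [tf.layers.lstmCell({units: 4}), tf.layers.lstmCell({units: 8})],
//     returnSequences: true
//   });
//   const y = rnn.apply(tf.zeros([2, 10, 3])); // shape [2, 10, 8]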
function generateDropoutMask(args) {
var ones = args.ones,
rate = args.rate,
_args$training = args.training,
training = _args$training === void 0 ? false : _args$training,
_args$count = args.count,
count = _args$count === void 0 ? 1 : _args$count;
var droppedInputs = function droppedInputs() {
return dropout$1(ones(), rate);
};
var createMask = function createMask() {
return inTrainPhase(droppedInputs, ones, training);
}; // just in case count is provided with null or undefined
if (!count || count <= 1) {
return keep(createMask().clone());
}
var masks = Array(count).fill(undefined).map(createMask);
return masks.map(function (m) {
return keep(m.clone());
});
}
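// Note on generateDropoutMask(): the mask is created once per forward pass and
// retained with keep(), so the same units stay dropped at every time step, as
// in Keras recurrent dropout. With `count > 1` it returns one mask per gate
// (3 for GRUCell, 4 for LSTMCell/ConvLSTM2DCell). When `training` is false,
// inTrainPhase() falls back to ones(), i.e. no dropout is applied.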
/**
* @license
* Copyright 2020 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var __rest = undefined && undefined.__rest || function (s, e) {
var t = {};
for (var p in s) {
if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0) t[p] = s[p];
}
if (s != null && typeof Object.getOwnPropertySymbols === "function") for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {
if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i])) t[p[i]] = s[p[i]];
}
return t;
};
var ConvRNN2DCell = /*#__PURE__*/function (_RNNCell) {
_inheritsLoose(ConvRNN2DCell, _RNNCell);
function ConvRNN2DCell() {
return _RNNCell.apply(this, arguments) || this;
}
return ConvRNN2DCell;
}(RNNCell);
/**
* Base class for convolutional-recurrent layers.
*/
var ConvRNN2D = /*#__PURE__*/function (_RNN) {
_inheritsLoose(ConvRNN2D, _RNN);
function ConvRNN2D(args) {
var _this;
if (args.unroll) {
throw new NotImplementedError('Unrolling is not possible with convolutional RNNs.');
}
if (Array.isArray(args.cell)) {
throw new NotImplementedError('It is not possible at the moment to stack convolutional cells.');
}
_this = _RNN.call(this, args) || this;
_this.inputSpec = [new InputSpec({
ndim: 5
})];
return _this;
}
var _proto = ConvRNN2D.prototype;
_proto.call = function call(inputs, kwargs) {
var _this2 = this;
return tidy(function () {
if (_this2.cell.dropoutMask != null) {
dispose(_this2.cell.dropoutMask);
_this2.cell.dropoutMask = null;
}
if (_this2.cell.recurrentDropoutMask != null) {
dispose(_this2.cell.recurrentDropoutMask);
_this2.cell.recurrentDropoutMask = null;
}
if (kwargs && kwargs['constants']) {
throw new ValueError('ConvRNN2D cell does not support constants');
}
var mask = kwargs == null ? null : kwargs['mask'];
var training = kwargs == null ? null : kwargs['training'];
var initialState = kwargs == null ? null : kwargs['initialState'];
return _RNN.prototype.call.call(_this2, inputs, {
mask: mask,
training: training,
initialState: initialState
});
});
};
_proto.computeOutputShape = function computeOutputShape(inputShape) {
var outShape = this.computeSingleOutputShape(inputShape);
if (!this.returnSequences) {
outShape = [outShape[0]].concat(outShape.slice(2));
}
if (this.returnState) {
outShape = [outShape].concat(Array(2).fill([inputShape[0]].concat(outShape.slice(-3))));
}
return outShape;
};
_proto.getInitialState = function getInitialState(inputs) {
var _this3 = this;
return tidy(function () {
var stateSize = _this3.cell.stateSize;
var inputShape = inputs.shape;
var outputShape = _this3.computeSingleOutputShape(inputShape);
var stateShape = [outputShape[0]].concat(outputShape.slice(2));
var initialState = zeros(stateShape);
if (Array.isArray(stateSize)) {
return Array(stateSize.length).fill(initialState);
}
return [initialState];
});
};
_proto.resetStates = function resetStates(states, training) {
var _this4 = this;
if (training === void 0) {
training = false;
}
tidy(function () {
if (!_this4.stateful) {
throw new AttributeError('Cannot call resetStates() on an RNN Layer that is not stateful.');
}
var inputShape = _this4.inputSpec[0].shape;
var outputShape = _this4.computeSingleOutputShape(inputShape);
var stateShape = [outputShape[0]].concat(outputShape.slice(2));
var batchSize = inputShape[0];
if (batchSize == null) {
throw new ValueError('If an RNN is stateful, it needs to know its batch size. Specify ' + 'the batch size of your input tensors: \n' + '- If using a Sequential model, specify the batch size by ' + 'passing a `batchInputShape` option to your first layer.\n' + '- If using the functional API, specify the batch size by ' + 'passing a `batchShape` option to your Input layer.');
} // Initialize state if null.
if (_this4.getStates() == null) {
if (Array.isArray(_this4.cell.stateSize)) {
_this4.states_ = _this4.cell.stateSize.map(function () {
return zeros(stateShape);
});
} else {
_this4.states_ = [zeros(stateShape)];
}
} else if (states == null) {
// Dispose old state tensors.
dispose(_this4.states_); // For stateful RNNs, fully dispose kept old states.
if (_this4.keptStates != null) {
dispose(_this4.keptStates);
_this4.keptStates = [];
}
if (Array.isArray(_this4.cell.stateSize)) {
_this4.states_ = _this4.cell.stateSize.map(function () {
return zeros(stateShape);
});
} else {
_this4.states_[0] = zeros(stateShape);
}
} else {
if (!Array.isArray(states)) {
states = [states];
}
if (states.length !== _this4.states_.length) {
throw new ValueError("Layer " + _this4.name + " expects " + _this4.states_.length + " state(s), " + ("but it received " + states.length + " state value(s). Input ") + ("received: " + states));
}
if (training) {
// Store old state tensors for complete disposal later, i.e., during
// the next no-arg call to this method. We do not dispose the old
// states immediately because BPTT (among other things) requires
// them.
_this4.keptStates.push(_this4.states_.slice());
} else {
dispose(_this4.states_);
}
for (var index = 0; index < _this4.states_.length; ++index) {
var value = states[index];
var expectedShape = stateShape;
if (!arraysEqual(value.shape, expectedShape)) {
throw new ValueError("State " + index + " is incompatible with layer " + _this4.name + ": " + ("expected shape=" + expectedShape + ", received shape=" + value.shape));
}
_this4.states_[index] = value;
}
}
_this4.states_ = _this4.states_.map(function (state) {
return keep(state.clone());
});
});
};
_proto.computeSingleOutputShape = function computeSingleOutputShape(inputShape) {
var _this$cell = this.cell,
dataFormat = _this$cell.dataFormat,
filters = _this$cell.filters,
kernelSize = _this$cell.kernelSize,
padding = _this$cell.padding,
strides = _this$cell.strides,
dilationRate = _this$cell.dilationRate;
var isChannelsFirst = dataFormat === 'channelsFirst';
var h = inputShape[isChannelsFirst ? 3 : 2];
var w = inputShape[isChannelsFirst ? 4 : 3];
var hOut = convOutputLength(h, kernelSize[0], padding, strides[0], dilationRate[0]);
var wOut = convOutputLength(w, kernelSize[1], padding, strides[1], dilationRate[1]);
var outShape = [].concat(inputShape.slice(0, 2), isChannelsFirst ? [filters, hOut, wOut] : [hOut, wOut, filters]);
return outShape;
};
return ConvRNN2D;
}(RNN);
/** @nocollapse */
ConvRNN2D.className = 'ConvRNN2D';
var ConvLSTM2DCell = /*#__PURE__*/function (_LSTMCell) {
_inheritsLoose(ConvLSTM2DCell, _LSTMCell);
function ConvLSTM2DCell(args) {
var _this5;
var filters = args.filters,
kernelSize = args.kernelSize,
strides = args.strides,
padding = args.padding,
dataFormat = args.dataFormat,
dilationRate = args.dilationRate;
_this5 = _LSTMCell.call(this, Object.assign({}, args, {
units: filters
})) || this;
_this5.filters = filters;
assertPositiveInteger(_this5.filters, 'filters');
_this5.kernelSize = normalizeArray(kernelSize, 2, 'kernelSize');
_this5.kernelSize.forEach(function (size) {
return assertPositiveInteger(size, 'kernelSize');
});
_this5.strides = normalizeArray(strides || 1, 2, 'strides');
_this5.strides.forEach(function (stride) {
return assertPositiveInteger(stride, 'strides');
});
_this5.padding = padding || 'valid';
checkPaddingMode(_this5.padding);
_this5.dataFormat = dataFormat || 'channelsLast';
checkDataFormat(_this5.dataFormat);
_this5.dilationRate = normalizeArray(dilationRate || 1, 2, 'dilationRate');
_this5.dilationRate.forEach(function (rate) {
return assertPositiveInteger(rate, 'dilationRate');
});
return _this5;
}
var _proto2 = ConvLSTM2DCell.prototype;
_proto2.build = function build(inputShape) {
var _a;
inputShape = getExactlyOneShape(inputShape);
var channelAxis = this.dataFormat === 'channelsFirst' ? 1 : inputShape.length - 1;
if (inputShape[channelAxis] == null) {
throw new ValueError("The channel dimension of the input should be defined. " + ("Found " + inputShape[channelAxis]));
}
var inputDim = inputShape[channelAxis];
var numOfKernels = 4;
var kernelShape = this.kernelSize.concat([inputDim, this.filters * numOfKernels]);
this.kernel = this.addWeight('kernel', kernelShape, null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
var recurrentKernelShape = this.kernelSize.concat([this.filters, this.filters * numOfKernels]);
this.recurrentKernel = this.addWeight('recurrent_kernel', recurrentKernelShape, null, this.recurrentInitializer, this.recurrentRegularizer, true, this.recurrentConstraint);
if (this.useBias) {
var biasInitializer;
if (this.unitForgetBias) {
var init = this.biasInitializer;
var filters = this.filters;
biasInitializer = new (_a = /*#__PURE__*/function (_Initializer) {
_inheritsLoose(CustomInit, _Initializer);
function CustomInit() {
return _Initializer.apply(this, arguments) || this;
}
var _proto3 = CustomInit.prototype;
_proto3.apply = function apply(shape, dtype) {
var biasI = init.apply([filters]);
var biasF = ones$1([filters]);
var biasCAndO = init.apply([filters * 2]);
return concatenate([biasI, biasF, biasCAndO]);
};
return CustomInit;
}(Initializer),
/** @nocollapse */
_a.className = 'CustomInit', _a)();
} else {
biasInitializer = this.biasInitializer;
}
this.bias = this.addWeight('bias', [this.filters * numOfKernels], null, biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
this.built = true;
};
_proto2.call = function call(inputs, kwargs) {
var _this6 = this;
return tidy(function () {
if (inputs.length !== 3) {
throw new ValueError("ConvLSTM2DCell expects 3 input Tensors (inputs, h, c), got " + (inputs.length + "."));
}
var training = kwargs['training'] || false;
var x = inputs[0]; // Current input
var hTMinus1 = inputs[1]; // Previous memory state.
var cTMinus1 = inputs[2]; // Previous carry state.
var numOfKernels = 4;
if (0 < _this6.dropout && _this6.dropout < 1 && _this6.dropoutMask == null) {
_this6.dropoutMask = generateDropoutMask({
ones: function ones() {
return onesLike(x);
},
rate: _this6.dropout,
training: training,
count: numOfKernels
});
}
var dropoutMask = _this6.dropoutMask;
var applyDropout = function applyDropout(x, mask, index) {
if (!mask || !mask[index]) {
return x;
}
return mul(mask[index], x);
};
var xI = applyDropout(x, dropoutMask, 0);
var xF = applyDropout(x, dropoutMask, 1);
var xC = applyDropout(x, dropoutMask, 2);
var xO = applyDropout(x, dropoutMask, 3);
if (0 < _this6.recurrentDropout && _this6.recurrentDropout < 1 && _this6.recurrentDropoutMask == null) {
_this6.recurrentDropoutMask = generateDropoutMask({
ones: function ones() {
return onesLike(hTMinus1);
},
rate: _this6.recurrentDropout,
training: training,
count: numOfKernels
});
}
var recDropoutMask = _this6.recurrentDropoutMask;
var hI = applyDropout(hTMinus1, recDropoutMask, 0);
var hF = applyDropout(hTMinus1, recDropoutMask, 1);
var hC = applyDropout(hTMinus1, recDropoutMask, 2);
var hO = applyDropout(hTMinus1, recDropoutMask, 3);
var kernelChannelAxis = 3;
var _tfc$split = split$1(_this6.kernel.read(), numOfKernels, kernelChannelAxis),
kernelI = _tfc$split[0],
kernelF = _tfc$split[1],
kernelC = _tfc$split[2],
kernelO = _tfc$split[3];
var _ref = _this6.useBias ? split$1(_this6.bias.read(), numOfKernels) : [null, null, null, null],
biasI = _ref[0],
biasF = _ref[1],
biasC = _ref[2],
biasO = _ref[3];
xI = _this6.inputConv(xI, kernelI, biasI, _this6.padding);
xF = _this6.inputConv(xF, kernelF, biasF, _this6.padding);
xC = _this6.inputConv(xC, kernelC, biasC, _this6.padding);
xO = _this6.inputConv(xO, kernelO, biasO, _this6.padding);
var _tfc$split2 = split$1(_this6.recurrentKernel.read(), numOfKernels, kernelChannelAxis),
recKernelI = _tfc$split2[0],
recKernelF = _tfc$split2[1],
recKernelC = _tfc$split2[2],
recKernelO = _tfc$split2[3];
hI = _this6.recurrentConv(hI, recKernelI);
hF = _this6.recurrentConv(hF, recKernelF);
hC = _this6.recurrentConv(hC, recKernelC);
hO = _this6.recurrentConv(hO, recKernelO);
var i = _this6.recurrentActivation.apply(add$1(xI, hI));
var f = _this6.recurrentActivation.apply(add$1(xF, hF));
var c = add$1(mul(f, cTMinus1), mul(i, _this6.activation.apply(add$1(xC, hC))));
var h = mul(_this6.recurrentActivation.apply(add$1(xO, hO)), _this6.activation.apply(c));
return [h, h, c];
});
};
_proto2.getConfig = function getConfig() {
var _a = _LSTMCell.prototype.getConfig.call(this),
_ = _a['units'],
baseConfig = __rest(_a, ['units']);
var config = {
filters: this.filters,
kernelSize: this.kernelSize,
padding: this.padding,
dataFormat: this.dataFormat,
dilationRate: this.dilationRate,
strides: this.strides
};
return Object.assign({}, baseConfig, config);
};
_proto2.inputConv = function inputConv(x, w, b, padding) {
var out = conv2d(x, w, this.strides, padding || 'valid', this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC', this.dilationRate);
if (b) {
return biasAdd(out, b, this.dataFormat);
}
return out;
};
_proto2.recurrentConv = function recurrentConv(x, w) {
var strides = 1;
return conv2d(x, w, strides, 'same', this.dataFormat === 'channelsFirst' ? 'NCHW' : 'NHWC');
};
return ConvLSTM2DCell;
}(LSTMCell);
/** @nocollapse */
ConvLSTM2DCell.className = 'ConvLSTM2DCell';
registerClass(ConvLSTM2DCell);
var ConvLSTM2D = /*#__PURE__*/function (_ConvRNN2D) {
_inheritsLoose(ConvLSTM2D, _ConvRNN2D);
function ConvLSTM2D(args) {
var cell = new ConvLSTM2DCell(args);
return _ConvRNN2D.call(this, Object.assign({}, args, {
cell: cell
})) || this;
}
/** @nocollapse */
ConvLSTM2D.fromConfig = function fromConfig(cls, config) {
return new cls(config);
};
return ConvLSTM2D;
}(ConvRNN2D);
/** @nocollapse */
ConvLSTM2D.className = 'ConvLSTM2D';
registerClass(ConvLSTM2D);
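// Illustrative usage sketch (assumes the `tf.layers.convLstm2d` factory that
// this bundle exports elsewhere; input is rank 5: [batch, time, rows, cols, channels]):
//   const layer = tf.layers.convLstm2d({filters: 8, kernelSize: 3, padding: 'same'});
//   const y = layer.apply(tf.zeros([1, 5, 16, 16, 3])); // shape [1, 16, 16, 8]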
var Dropout = /*#__PURE__*/function (_Layer) {
_inheritsLoose(Dropout, _Layer);
function Dropout(args) {
var _this;
_this = _Layer.call(this, args) || this;
_this.rate = Math.max(Math.min(args.rate, 1), 0); // So that the scalar doesn't get tidied up between executions.
_this.noiseShape = args.noiseShape;
_this.seed = args.seed;
_this.supportsMasking = true;
return _this;
}
var _proto = Dropout.prototype;
_proto.getNoiseShape = function getNoiseShape(input) {
if (this.noiseShape == null) {
return this.noiseShape;
}
var inputShape = input.shape;
var noiseShape = [];
for (var i = 0; i < this.noiseShape.length; ++i) {
noiseShape.push(this.noiseShape[i] == null ? inputShape[i] : this.noiseShape[i]);
}
return noiseShape;
};
_proto.call = function call(inputs, kwargs) {
var _this2 = this;
return tidy(function () {
_this2.invokeCallHook(inputs, kwargs);
var input = getExactlyOneTensor(inputs);
if (0 < _this2.rate && _this2.rate < 1) {
var training = kwargs['training'] == null ? false : kwargs['training'];
var noiseShape = _this2.getNoiseShape(input);
var output = inTrainPhase(function () {
return dropout$1(input, _this2.rate, noiseShape, _this2.seed);
}, function () {
return input;
}, training);
return output;
}
return inputs;
});
};
_proto.getConfig = function getConfig() {
var config = {
rate: this.rate,
noiseShape: this.noiseShape,
seed: this.seed
};
var baseConfig = _Layer.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
_proto.dispose = function dispose() {
return _Layer.prototype.dispose.call(this);
};
return Dropout;
}(Layer);
/** @nocollapse */
Dropout.className = 'Dropout';
registerClass(Dropout);
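// Illustrative usage sketch (assumes the `tf.layers.dropout` factory that this
// bundle exports elsewhere): dropout is active only when `training: true` is
// passed via kwargs; otherwise the input passes through unchanged:
//   const drop = tf.layers.dropout({rate: 0.5});
//   const y = drop.apply(tf.ones([2, 4]), {training: true});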
var SpatialDropout1D = /*#__PURE__*/function (_Dropout) {
_inheritsLoose(SpatialDropout1D, _Dropout);
function SpatialDropout1D(args) {
var _this3;
_this3 = _Dropout.call(this, args) || this;
_this3.inputSpec = [{
ndim: 3
}];
return _this3;
}
var _proto2 = SpatialDropout1D.prototype;
_proto2.getNoiseShape = function getNoiseShape(input) {
var inputShape = input.shape;
return [inputShape[0], 1, inputShape[2]];
};
return SpatialDropout1D;
}(Dropout);
/** @nocollapse */
SpatialDropout1D.className = 'SpatialDropout1D';
registerClass(SpatialDropout1D);
var Dense = /*#__PURE__*/function (_Layer2) {
_inheritsLoose(Dense, _Layer2);
function Dense(args) {
var _this4;
_this4 = _Layer2.call(this, args) || this; // Default activation: Linear (none).
_this4.activation = null;
_this4.useBias = true;
_this4.kernel = null;
_this4.bias = null;
_this4.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
_this4.DEFAULT_BIAS_INITIALIZER = 'zeros';
if (args.batchInputShape == null && args.inputShape == null && args.inputDim != null) {
// This logic is copied from Layer's constructor, since we can't
// do exactly what the Python constructor does for Dense().
var batchSize = null;
if (args.batchSize != null) {
batchSize = args.batchSize;
}
_this4.batchInputShape = [batchSize, args.inputDim];
}
_this4.units = args.units;
assertPositiveInteger(_this4.units, 'units');
_this4.activation = getActivation(args.activation);
if (args.useBias != null) {
_this4.useBias = args.useBias;
}
_this4.kernelInitializer = getInitializer(args.kernelInitializer || _this4.DEFAULT_KERNEL_INITIALIZER);
_this4.biasInitializer = getInitializer(args.biasInitializer || _this4.DEFAULT_BIAS_INITIALIZER);
_this4.kernelConstraint = getConstraint(args.kernelConstraint);
_this4.biasConstraint = getConstraint(args.biasConstraint);
_this4.kernelRegularizer = getRegularizer(args.kernelRegularizer);
_this4.biasRegularizer = getRegularizer(args.biasRegularizer);
_this4.activityRegularizer = getRegularizer(args.activityRegularizer);
_this4.supportsMasking = true;
_this4.inputSpec = [{
minNDim: 2
}];
return _this4;
}
var _proto3 = Dense.prototype;
_proto3.build = function build(inputShape) {
var _axes;
inputShape = getExactlyOneShape(inputShape);
var inputLastDim = inputShape[inputShape.length - 1];
if (this.kernel == null) {
this.kernel = this.addWeight('kernel', [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
}
this.inputSpec = [{
minNDim: 2,
axes: (_axes = {}, _axes[-1] = inputLastDim, _axes)
}];
this.built = true;
};
_proto3.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var outputShape = inputShape.slice();
outputShape[outputShape.length - 1] = this.units;
return outputShape;
};
_proto3.call = function call(inputs, kwargs) {
var _this5 = this;
return tidy(function () {
_this5.invokeCallHook(inputs, kwargs); // Dense layer accepts only a single input.
var input = getExactlyOneTensor(inputs);
var fusedActivationName = mapActivationToFusedKernel(_this5.activation.getClassName());
var output;
if (fusedActivationName != null) {
output = dot$1(input, _this5.kernel.read(), fusedActivationName, _this5.bias ? _this5.bias.read() : null);
} else {
output = dot$1(input, _this5.kernel.read());
if (_this5.bias != null) {
output = biasAdd(output, _this5.bias.read());
}
if (_this5.activation != null) {
output = _this5.activation.apply(output);
}
}
return output;
});
};
_proto3.getConfig = function getConfig() {
var config = {
units: this.units,
activation: serializeActivation(this.activation),
useBias: this.useBias,
kernelInitializer: serializeInitializer(this.kernelInitializer),
biasInitializer: serializeInitializer(this.biasInitializer),
kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
biasRegularizer: serializeRegularizer(this.biasRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
kernelConstraint: serializeConstraint(this.kernelConstraint),
biasConstraint: serializeConstraint(this.biasConstraint)
};
var baseConfig = _Layer2.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Dense;
}(Layer);
/** @nocollapse */
Dense.className = 'Dense';
registerClass(Dense);
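// Illustrative usage sketch (assumes the `tf.layers.dense` factory that this
// bundle exports elsewhere): the layer computes activation(x . kernel + bias):
//   const dense = tf.layers.dense({units: 3, activation: 'relu', inputShape: [4]});
//   const y = dense.apply(tf.zeros([2, 4])); // shape [2, 3]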
var Flatten = /*#__PURE__*/function (_Layer3) {
_inheritsLoose(Flatten, _Layer3);
function Flatten(args) {
var _this6;
args = args || {};
_this6 = _Layer3.call(this, args) || this;
_this6.inputSpec = [{
minNDim: 3
}];
_this6.dataFormat = args.dataFormat;
return _this6;
}
var _proto4 = Flatten.prototype;
_proto4.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
for (var _iterator = _createForOfIteratorHelperLoose(inputShape.slice(1)), _step; !(_step = _iterator()).done;) {
var dim = _step.value;
if (dim == null) {
throw new ValueError("The shape of the input to \"Flatten\" is not fully defined " + ("(got " + inputShape.slice(1) + "). Make sure to pass a complete ") + "\"input_shape\" or \"batch_input_shape\" argument to the first " + "layer in your model.");
}
}
return [inputShape[0], arrayProd(inputShape, 1)];
};
_proto4.call = function call(inputs, kwargs) {
var _this7 = this;
return tidy(function () {
_this7.invokeCallHook(inputs, kwargs);
var input = getExactlyOneTensor(inputs);
if (_this7.dataFormat === 'channelsFirst' && input.rank > 1) {
var permutation = [0];
for (var i = 2; i < input.rank; ++i) {
permutation.push(i);
}
permutation.push(1);
input = transpose(input, permutation);
}
return batchFlatten(input);
});
};
_proto4.getConfig = function getConfig() {
var config = {};
if (this.dataFormat != null) {
config['dataFormat'] = this.dataFormat;
}
var baseConfig = _Layer3.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Flatten;
}(Layer);
/** @nocollapse */
Flatten.className = 'Flatten';
registerClass(Flatten);
var Activation$1 = /*#__PURE__*/function (_Layer4) {
_inheritsLoose(Activation, _Layer4);
function Activation(args) {
var _this8;
_this8 = _Layer4.call(this, args) || this;
_this8.supportsMasking = true;
_this8.activation = getActivation(args.activation);
return _this8;
}
var _proto5 = Activation.prototype;
_proto5.call = function call(inputs, kwargs) {
var _this9 = this;
return tidy(function () {
_this9.invokeCallHook(inputs, kwargs);
var input = getExactlyOneTensor(inputs);
return _this9.activation.apply(input);
});
};
_proto5.getConfig = function getConfig() {
var config = {
activation: serializeActivation(this.activation)
};
var baseConfig = _Layer4.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Activation;
}(Layer);
/** @nocollapse */
Activation$1.className = 'Activation';
registerClass(Activation$1);
var RepeatVector = /*#__PURE__*/function (_Layer5) {
_inheritsLoose(RepeatVector, _Layer5);
function RepeatVector(args) {
var _this10;
_this10 = _Layer5.call(this, args) || this;
_this10.n = args.n;
_this10.inputSpec = [{
ndim: 2
}];
return _this10;
}
var _proto6 = RepeatVector.prototype;
_proto6.computeOutputShape = function computeOutputShape(inputShape) {
return [inputShape[0], this.n, inputShape[1]];
};
_proto6.call = function call(inputs, kwargs) {
var _this11 = this;
return tidy(function () {
inputs = getExactlyOneTensor(inputs);
return repeat(inputs, _this11.n);
});
};
_proto6.getConfig = function getConfig() {
var config = {
n: this.n
};
var baseConfig = _Layer5.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return RepeatVector;
}(Layer);
/** @nocollapse */
RepeatVector.className = 'RepeatVector';
registerClass(RepeatVector);
var Reshape$1 = /*#__PURE__*/function (_Layer6) {
_inheritsLoose(Reshape, _Layer6);
function Reshape(args) {
var _this12;
_this12 = _Layer6.call(this, args) || this;
_this12.targetShape = args.targetShape; // Make sure that all unknown dimensions are represented as `null`.
for (var i = 0; i < _this12.targetShape.length; ++i) {
if (_this12.isUnknown(_this12.targetShape[i])) {
_this12.targetShape[i] = null;
}
}
return _this12;
}
var _proto7 = Reshape.prototype;
_proto7.isUnknown = function isUnknown(dim) {
return dim < 0 || dim == null;
}
/**
* Finds and replaces a missing dimension in output shape.
*
* This is a near direct port of the internal Numpy function
* `_fix_unknown_dimension` in `numpy/core/src/multiarray/shape.c`.
*
* @param inputShape: Original shape of the array being reshaped.
* @param outputShape: Target shape of the array, with at most a single
* `null` or negative number, which indicates an underdetermined dimension
* that should be derived from `inputShape` and the known dimensions of
* `outputShape`.
* @returns: The output shape with `null` replaced with its computed value.
* @throws: ValueError: If `inputShape` and `outputShape` do not match.
*/
;
_proto7.fixUnknownDimension = function fixUnknownDimension(inputShape, outputShape) {
var errorMsg = 'Total size of new array must be unchanged.';
var finalShape = outputShape.slice();
var known = 1;
var unknown = null;
for (var i = 0; i < finalShape.length; ++i) {
var dim = finalShape[i];
if (this.isUnknown(dim)) {
if (unknown === null) {
unknown = i;
} else {
throw new ValueError('Can only specify one unknown dimension.');
}
} else {
known *= dim;
}
}
var originalSize = arrayProd(inputShape);
if (unknown !== null) {
if (known === 0 || originalSize % known !== 0) {
throw new ValueError(errorMsg);
}
finalShape[unknown] = originalSize / known;
} else if (originalSize !== known) {
throw new ValueError(errorMsg);
}
return finalShape;
};
_proto7.computeOutputShape = function computeOutputShape(inputShape) {
var anyUnknownDims = false;
for (var i = 0; i < inputShape.length; ++i) {
if (this.isUnknown(inputShape[i])) {
anyUnknownDims = true;
break;
}
}
if (anyUnknownDims) {
return inputShape.slice(0, 1).concat(this.targetShape);
} else {
return inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
}
};
_proto7.call = function call(inputs, kwargs) {
var _this13 = this;
return tidy(function () {
_this13.invokeCallHook(inputs, kwargs);
var input = getExactlyOneTensor(inputs);
var inputShape = input.shape;
var outputShape = inputShape.slice(0, 1).concat(_this13.fixUnknownDimension(inputShape.slice(1), _this13.targetShape));
return reshape(input, outputShape);
});
};
_proto7.getConfig = function getConfig() {
var config = {
targetShape: this.targetShape
};
var baseConfig = _Layer6.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Reshape;
}(Layer);
/** @nocollapse */
Reshape$1.className = 'Reshape';
registerClass(Reshape$1);
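// Worked example of fixUnknownDimension() above: with targetShape [2, -1] and
// an input of shape [batch, 12], the single unknown dimension is filled in as
// 12 / 2 = 6, so computeOutputShape() yields [batch, 2, 6]; a target such as
// [5, -1] on 12 elements throws the "Total size of new array" ValueError above.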
var Permute = /*#__PURE__*/function (_Layer7) {
_inheritsLoose(Permute, _Layer7);
function Permute(args) {
var _this14;
_this14 = _Layer7.call(this, args) || this;
if (args.dims == null) {
throw new Error('Required configuration field `dims` is missing during Permute ' + 'constructor call.');
}
if (!Array.isArray(args.dims)) {
throw new Error('Permute constructor requires `dims` to be an Array, but received ' + (args.dims + " instead."));
} // Check the validity of the permutation indices.
var expectedSortedIndices = range$1(1, args.dims.length + 1);
if (!arraysEqual(args.dims.slice().sort(), expectedSortedIndices)) {
throw new Error('Invalid permutation `dims`: ' + JSON.stringify(args.dims) + ' `dims` must contain consecutive integers starting from 1.');
}
_this14.dims = args.dims;
_this14.dimsIncludingBatch = [0].concat(_this14.dims);
_this14.inputSpec = [new InputSpec({
ndim: _this14.dims.length + 1
})];
return _this14;
}
var _proto8 = Permute.prototype;
_proto8.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var outputShape = inputShape.slice();
this.dims.forEach(function (dim, i) {
outputShape[i + 1] = inputShape[dim];
});
return outputShape;
};
_proto8.call = function call(inputs, kwargs) {
return transpose(getExactlyOneTensor(inputs), this.dimsIncludingBatch);
};
_proto8.getConfig = function getConfig() {
var config = {
dims: this.dims
};
var baseConfig = _Layer7.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Permute;
}(Layer);
/** @nocollapse */
Permute.className = 'Permute';
registerClass(Permute);
var Masking = /*#__PURE__*/function (_Layer8) {
_inheritsLoose(Masking, _Layer8);
function Masking(args) {
var _this15;
_this15 = _Layer8.call(this, args == null ? {} : args) || this;
_this15.supportsMasking = true;
if (args != null) {
_this15.maskValue = args.maskValue == null ? 0 : args.maskValue;
} else {
_this15.maskValue = 0;
}
return _this15;
}
var _proto9 = Masking.prototype;
_proto9.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
};
_proto9.getConfig = function getConfig() {
var baseConfig = _Layer8.prototype.getConfig.call(this);
var config = {
maskValue: this.maskValue
};
Object.assign(config, baseConfig);
return config;
};
_proto9.computeMask = function computeMask(inputs, mask) {
var input = getExactlyOneTensor(inputs);
var axis = -1;
return any(notEqual(input, this.maskValue), axis);
};
_proto9.call = function call(inputs, kwargs) {
var _this16 = this;
return tidy(function () {
_this16.invokeCallHook(inputs, kwargs);
var input = getExactlyOneTensor(inputs);
var axis = -1;
var keepDims = true;
var booleanMask = any(notEqual(input, _this16.maskValue), axis, keepDims);
var output = mul(input, cast(booleanMask, input.dtype));
return output;
});
};
return Masking;
}(Layer);
/** @nocollapse */
Masking.className = 'Masking';
registerClass(Masking);
var Embedding = /*#__PURE__*/function (_Layer) {
_inheritsLoose(Embedding, _Layer);
function Embedding(args) {
var _this;
_this = _Layer.call(this, args) || this;
_this.embeddings = null;
_this.DEFAULT_EMBEDDINGS_INITIALIZER = 'randomUniform';
if (args.batchInputShape == null && args.inputShape == null) {
// Porting Note: This logic is copied from Layer's constructor, since we
// can't do exactly what the Python constructor does for Embedding().
// Specifically, the super constructor cannot be called after the
// mutation of the `config` argument.
var batchSize = null;
if (args.batchSize != null) {
batchSize = args.batchSize;
}
if (args.inputLength == null) {
// Fix super-constructor to what it would have done if
// 'config.inputShape' were (None, )
_this.batchInputShape = [batchSize, null];
} else {
// Fix super-constructor to what it would have done if
// 'config.inputShape' were (config.inputLength, )
_this.batchInputShape = [batchSize].concat(toList(args.inputLength));
}
}
_this.inputDim = args.inputDim;
assertPositiveInteger(_this.inputDim, 'inputDim');
_this.outputDim = args.outputDim;
assertPositiveInteger(_this.outputDim, 'outputDim');
_this.embeddingsInitializer = getInitializer(args.embeddingsInitializer || _this.DEFAULT_EMBEDDINGS_INITIALIZER);
_this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
_this.activityRegularizer = getRegularizer(args.activityRegularizer);
_this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);
_this.maskZero = args.maskZero;
_this.supportsMasking = args.maskZero;
_this.inputLength = args.inputLength;
return _this;
}
var _proto = Embedding.prototype;
_proto.build = function build(inputShape) {
this.embeddings = this.addWeight('embeddings', [this.inputDim, this.outputDim], this.dtype, this.embeddingsInitializer, this.embeddingsRegularizer, true, this.embeddingsConstraint);
this.built = true;
} // Override warnOnIncompatibleInputShape because an embedding layer allows
// the input to have varying ranks.
;
_proto.warnOnIncompatibleInputShape = function warnOnIncompatibleInputShape(inputShape) {};
_proto.computeMask = function computeMask(inputs, mask) {
var _this2 = this;
return tidy(function () {
if (!_this2.maskZero) {
return null;
} else {
inputs = getExactlyOneTensor(inputs);
return notEqual(inputs, zerosLike(inputs));
}
});
};
_proto.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
if (this.inputLength == null) {
return [].concat(inputShape, [this.outputDim]);
} // inputLength can be an array if input is 3D or higher.
var inLens = toList(this.inputLength);
if (inLens.length !== inputShape.length - 1) {
throw new ValueError("\"inputLength\" is " + this.inputLength + ", but received " + ("input shape has shape " + inputShape));
} else {
var i = 0;
for (var k = 0; k < inLens.length; ++k) {
var s1 = inLens[k];
var s2 = inputShape[k + 1];
if (s1 != null && s2 != null && s1 !== s2) {
throw new ValueError("\"inputLength\" is " + this.inputLength + ", but received " + ("input shape has shape " + inputShape));
} else if (s1 == null) {
inLens[i] = s2;
}
i++;
}
}
return [inputShape[0]].concat(inLens, [this.outputDim]);
};
_proto.call = function call(inputs, kwargs) {
var _this3 = this;
return tidy(function () {
_this3.invokeCallHook(inputs, kwargs); // Embedding layer accepts only a single input.
var input = getExactlyOneTensor(inputs);
if (input.dtype !== 'int32') {
input = cast$1(input, 'int32');
}
var output = gather$1(_this3.embeddings.read(), reshape(input, [input.size]));
return reshape(output, getExactlyOneShape(_this3.computeOutputShape(input.shape)));
});
};
_proto.getConfig = function getConfig() {
var config = {
inputDim: this.inputDim,
outputDim: this.outputDim,
embeddingsInitializer: serializeInitializer(this.embeddingsInitializer),
embeddingsRegularizer: serializeRegularizer(this.embeddingsRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
embeddingsConstraint: serializeConstraint(this.embeddingsConstraint),
maskZero: this.maskZero,
inputLength: this.inputLength
};
var baseConfig = _Layer.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Embedding;
}(Layer);
/** @nocollapse */
Embedding.className = 'Embedding';
registerClass(Embedding);
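/**
 * Usage sketch for the `Embedding` layer (via the public
 * `tf.layers.embedding` factory): maps int32 indices to dense vectors.
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.embedding({inputDim: 1000, outputDim: 8, inputLength: 4}));
 * const out = model.predict(tf.tensor2d([[3, 10, 7, 0]], [1, 4], 'int32'));
 * console.log(out.shape); // [1, 4, 8]
 * ```
 */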
/**
* Generic Merge layer for element-wise merge functions.
*
* Used to implement `Sum`, `Average`, `Concatenate`, etc.
*/
var Merge = /*#__PURE__*/function (_Layer) {
_inheritsLoose(Merge, _Layer);
function Merge(args) {
var _this;
_this = _Layer.call(this, args || {}) || this;
_this.supportsMasking = true;
return _this;
}
/**
* Logic for merging multiple tensors, to be overridden by subclasses.
* @param inputs
*/
var _proto = Merge.prototype;
_proto.mergeFunction = function mergeFunction(inputs) {
throw new NotImplementedError();
}
/**
* Computes the shape of the result of an elementwise operation.
*
* @param shape1: Shape of the first tensor.
* @param shape2: Shape of the second tensor.
* @returns Expected output shape when an elementwise operation is carried
* out on 2 tensors with shapes `shape1` and `shape2`.
* @throws ValueError: If `shape1` and `shape2` are not compatible for
* element-wise operations.
*/
;
_proto.computeElementwiseOpOutputShape = function computeElementwiseOpOutputShape(shape1, shape2) {
if (shape1 == null || shape2 == null) {
return null;
} else if (shape1.length < shape2.length) {
return this.computeElementwiseOpOutputShape(shape2, shape1);
} else if (shape2.length === 0) {
return shape1;
}
var outputShape = shape1.slice(0, shape1.length - shape2.length);
for (var k = 0; k < shape2.length; ++k) {
var i = shape1[shape1.length - shape2.length + k];
var j = shape2[k];
if (i == null || j == null || i < 0 || j < 0) {
outputShape.push(null);
} else if (i === 1) {
outputShape.push(j);
} else if (j === 1) {
outputShape.push(i);
} else {
if (i !== j) {
throw new ValueError('Operands could not be broadcast together with shapes ' + JSON.stringify(shape1) + ' ' + JSON.stringify(shape2));
}
outputShape.push(i);
}
}
return outputShape;
};
_proto.build = function build(inputShape) {
// Used purely for shape validation.
if (Array.isArray(inputShape) && !Array.isArray(inputShape[0])) {
// Make sure that inputShape is an Array of shape.
inputShape = [getExactlyOneShape(inputShape)];
}
inputShape = inputShape;
if (inputShape.length < 2) {
throw new ValueError('A merge layer should be called on an Array of at least 2 inputs.' + (" Got " + inputShape.length + " input(s)."));
} // Make sure that there is at most one unique batch size among the input
// shapes.
var batchSizes = [];
for (var _iterator = _createForOfIteratorHelperLoose(inputShape), _step; !(_step = _iterator()).done;) {
var _shape = _step.value;
if (_shape != null && _shape[0] !== null) {
batchSizes.push(_shape[0]);
}
}
batchSizes = unique$1(batchSizes);
if (batchSizes.length > 1) {
throw new ValueError("Can not merge tensors with different batch sizes. " + ("Got tensors with shapes: " + JSON.stringify(inputShape) + "."));
}
var outputShape = inputShape[0] == null ? null : inputShape[0].slice(1);
for (var i = 1; i < inputShape.length; ++i) {
var shape = inputShape[i] == null ? null : inputShape[i].slice(1);
outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
} // If the inputs have different ranks, we have to reshape them to make them
// broadcastable.
var allRanks = inputShape.map(function (shape) {
return shape.length;
});
if (inputShape.indexOf(null) === -1 && unique$1(allRanks).length === 1) {
this.reshapeRequired = false;
} else {
this.reshapeRequired = true;
}
};
_proto.call = function call(inputs, kwargs) {
var _this2 = this;
return tidy(function () {
inputs = inputs;
if (_this2.reshapeRequired) {
var reshapedInputs = [];
var inputDims = inputs.map(function (input) {
return input.rank;
});
if (inputDims.indexOf(null) === -1) {
// If ranks of all inputs are available, we simply expand each of them
// at axis=1 until all of them have the same rank.
var maxNDim = max$6(inputDims);
for (var _iterator2 = _createForOfIteratorHelperLoose(inputs), _step2; !(_step2 = _iterator2()).done;) {
var x = _step2.value;
var xNDim = x.rank;
for (var k = 0; k < maxNDim - xNDim; ++k) {
x = expandDims$1(x, 1);
}
reshapedInputs.push(x);
}
return _this2.mergeFunction(reshapedInputs);
} else {
// Transpose all inputs so that batch size is the last dimension.
// [batchSize, dim1, dim2, ...] -> [dim1, dim2, ..., batchSize]
var transposed = false;
for (var _iterator3 = _createForOfIteratorHelperLoose(inputs), _step3; !(_step3 = _iterator3()).done;) {
var _x = _step3.value;
var _xNDim = _x.rank;
if (_xNDim == null) {
var xShape = _x.shape;
var _batchSize = xShape[0];
var _newShape = xShape.slice(1).concat([_batchSize]);
var xTransposed = reshape(_x, [_batchSize].concat(arrayProd(xShape.slice(1))));
xTransposed = transpose(xTransposed, [1, 0]);
xTransposed = reshape(xTransposed, _newShape);
reshapedInputs.push(xTransposed);
transposed = true;
} else if (_xNDim > 1) {
var _dims = range$1(1, _xNDim).concat([0]);
reshapedInputs.push(transpose(_x, _dims));
transposed = true;
} else {
// We don't transpose inputs if they are 1D vectors or scalars.
reshapedInputs.push(_x);
}
}
var y = _this2.mergeFunction(reshapedInputs);
var yNDim = y.rank;
if (transposed) {
// If inputs have been transposed, we have to transpose the output
// too.
if (yNDim == null) {
var yShape = y.shape;
var _yNDim = yShape.length;
var batchSize = yShape[_yNDim - 1];
var newShape = [batchSize].concat(yShape.slice(0, yShape.length - 1));
y = reshape(transpose(reshape(y, [-1, batchSize]), [1, 0]), newShape);
} else if (yNDim > 1) {
var dims = [yNDim - 1].concat(range$1(0, yNDim - 1));
y = transpose(y, dims);
}
}
return y;
}
} else {
return _this2.mergeFunction(inputs);
}
});
};
_proto.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = inputShape;
var outputShape;
if (inputShape[0] == null) {
outputShape = null;
} else {
outputShape = inputShape[0].slice(1);
}
for (var i = 1; i < inputShape.length; ++i) {
var shape = inputShape[i] == null ? null : inputShape[i].slice(1);
outputShape = this.computeElementwiseOpOutputShape(outputShape, shape);
}
var batchSizes = [];
for (var _iterator4 = _createForOfIteratorHelperLoose(inputShape), _step4; !(_step4 = _iterator4()).done;) {
var _shape2 = _step4.value;
if (_shape2 != null && _shape2[0] !== null) {
batchSizes.push(_shape2[0]);
}
}
batchSizes = unique$1(batchSizes);
if (batchSizes.length === 1) {
outputShape = batchSizes.concat(outputShape);
} else {
outputShape = [null].concat(outputShape);
}
return outputShape;
};
_proto.computeMask = function computeMask(inputs, mask) {
return tidy(function () {
if (mask == null) {
return null;
}
if (!Array.isArray(mask)) {
throw new ValueError('`mask` should be an Array');
}
if (!Array.isArray(inputs)) {
throw new ValueError('`inputs` should be an Array');
}
if (mask.length !== inputs.length) {
throw new ValueError("The Array 'inputs' and 'mask' are expected to have the same " + "length, but have different lengths " + ("(" + inputs.length + " vs " + mask.length + ")"));
}
if (mask.every(function (m) {
return m == null;
})) {
return null;
}
mask = mask.map(function (m) {
return m == null ? m : expandDims(m, 0);
});
var output = mask[0];
for (var i = 1; i < mask.length - 1; ++i) {
output = logicalAnd(output, mask[i]);
}
return output;
});
};
return Merge;
}(Layer);
var Add$1 = /*#__PURE__*/function (_Merge) {
_inheritsLoose(Add, _Merge);
function Add(args) {
return _Merge.call(this, args) || this;
}
var _proto2 = Add.prototype;
_proto2.mergeFunction = function mergeFunction(inputs) {
return tidy(function () {
var output = inputs[0].clone();
for (var i = 1; i < inputs.length; ++i) {
output = add$1(output, inputs[i]);
}
return output;
});
};
return Add;
}(Merge);
/** @nocollapse */
Add$1.className = 'Add';
registerClass(Add$1);
/**
* Calculate the element-wise sum of inputs, which all have the same shape.
*
* This function can be invoked in three ways.
*
* 1. Construct an instance of `Add` layer, by using no input argument
* or a single configuration argument. The resultant `Add` layer can then
* be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
*
* ```js
* const addLayer = tf.layers.add();
*
* // The layer can be applied to inputs.
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = addLayer.apply([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.SymbolicTensor`. For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = tf.layers.add([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.Tensor` as the result of the computation. For
* example:
*
* ```js
* const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
* const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
* tf.layers.add([input1, input2]).print();
* // Gives [[11, 22], [33, 44]].
 * ```
 *
*/
function add$2(config) {
if (Array.isArray(config)) {
var layer = new Add$1({});
return layer.apply(config);
} else {
return new Add$1(config);
}
}
var Multiply$1 = /*#__PURE__*/function (_Merge2) {
_inheritsLoose(Multiply, _Merge2);
function Multiply(args) {
return _Merge2.call(this, args) || this;
}
var _proto3 = Multiply.prototype;
_proto3.mergeFunction = function mergeFunction(inputs) {
return tidy(function () {
var output = inputs[0].clone();
for (var i = 1; i < inputs.length; ++i) {
output = mul(output, inputs[i]);
}
return output;
});
};
return Multiply;
}(Merge);
/** @nocollapse */
Multiply$1.className = 'Multiply';
registerClass(Multiply$1);
/**
* Calculate the element-wise product of inputs, which all have the same shape.
*
* This function can be invoked in three ways.
*
* 1. Construct an instance of `Multiply` layer, by using no input argument
* or a single configuration argument. The resultant `Multiply` layer can
* then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
*
* ```js
* const multiplyLayer = tf.layers.multiply();
*
* // The layer can be applied to inputs.
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = multiplyLayer.apply([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.SymbolicTensor`. For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = tf.layers.multiply([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.Tensor` as the result of the computation. For
* example:
*
* ```js
* const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
* const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
* tf.layers.multiply([input1, input2]).print();
* // Gives [[10, 40], [90, 160]].
 * ```
 *
*/
function multiply$1(config) {
if (Array.isArray(config)) {
var layer = new Multiply$1({});
return layer.apply(config);
} else {
return new Multiply$1(config);
}
}
var Average = /*#__PURE__*/function (_Merge3) {
_inheritsLoose(Average, _Merge3);
function Average(args) {
return _Merge3.call(this, args) || this;
}
var _proto4 = Average.prototype;
_proto4.mergeFunction = function mergeFunction(inputs) {
return tidy(function () {
var output = inputs[0].clone();
for (var i = 1; i < inputs.length; ++i) {
output = add$1(output, inputs[i]);
}
return mul(1 / inputs.length, output);
});
};
return Average;
}(Merge);
/** @nocollapse */
Average.className = 'Average';
registerClass(Average);
/**
* Calculate the element-wise arithmetic mean of inputs, which all have the same
* shape.
*
* This function can be invoked in three ways.
*
* 1. Construct an instance of `Average` layer, by using no input argument
* or a single configuration argument. The resultant `Average` layer can then
* be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
*
* ```js
* const averageLayer = tf.layers.average();
*
* // The layer can be applied to inputs.
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = averageLayer.apply([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.SymbolicTensor`. For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = tf.layers.average([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.Tensor` as the result of the computation. For
* example:
*
* ```js
* const input1 = tf.tensor2d([1, 2, 3, 4], [2, 2]);
* const input2 = tf.tensor2d([10, 20, 30, 40], [2, 2]);
* tf.layers.average([input1, input2]).print();
* // Gives [[5.5, 11], [16.5, 22]].
 * ```
 *
*/
function average(config) {
if (Array.isArray(config)) {
var layer = new Average({});
return layer.apply(config);
} else {
return new Average(config);
}
}
var Maximum$1 = /*#__PURE__*/function (_Merge4) {
_inheritsLoose(Maximum, _Merge4);
function Maximum(args) {
return _Merge4.call(this, args) || this;
}
var _proto5 = Maximum.prototype;
_proto5.mergeFunction = function mergeFunction(inputs) {
return tidy(function () {
var output = inputs[0];
for (var i = 1; i < inputs.length; ++i) {
output = maximum(output, inputs[i]);
}
return output;
});
};
return Maximum;
}(Merge);
/** @nocollapse */
Maximum$1.className = 'Maximum';
registerClass(Maximum$1);
/**
* Calculate the element-wise maximum of inputs, which all have the same shape.
*
* This function can be invoked in three ways.
*
* 1. Construct an instance of `Maximum` layer, by using no input argument
* or a single configuration argument. The resultant `Maximum` layer can then
* be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
*
* ```js
* const maximumLayer = tf.layers.maximum();
*
* // The layer can be applied to inputs.
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = maximumLayer.apply([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.SymbolicTensor`. For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = tf.layers.maximum([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.Tensor` as the result of the computation. For
* example:
*
* ```js
* const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
* const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
* tf.layers.maximum([input1, input2]).print();
* // Gives [[10, 20], [30, 40]].
 * ```
 *
*/
function maximum$1(config) {
if (Array.isArray(config)) {
var layer = new Maximum$1({});
return layer.apply(config);
} else {
return new Maximum$1(config);
}
}
var Minimum$1 = /*#__PURE__*/function (_Merge5) {
_inheritsLoose(Minimum, _Merge5);
function Minimum(args) {
return _Merge5.call(this, args) || this;
}
var _proto6 = Minimum.prototype;
_proto6.mergeFunction = function mergeFunction(inputs) {
return tidy(function () {
var output = inputs[0];
for (var i = 1; i < inputs.length; ++i) {
output = minimum(output, inputs[i]);
}
return output;
});
};
return Minimum;
}(Merge);
/** @nocollapse */
Minimum$1.className = 'Minimum';
registerClass(Minimum$1);
/**
* Calculate the element-wise minimum of inputs, which all have the same shape.
*
* This function can be invoked in three ways.
*
* 1. Construct an instance of `Minimum` layer, by using no input argument
* or a single configuration argument. The resultant `Minimum` layer can then
* be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
*
* ```js
* const minimumLayer = tf.layers.minimum();
*
* // The layer can be applied to inputs.
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = minimumLayer.apply([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.SymbolicTensor`. For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const output = tf.layers.minimum([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.Tensor` as the result of the computation. For
* example:
*
* ```js
* const input1 = tf.tensor2d([1, 20, 3, 40], [2, 2]);
* const input2 = tf.tensor2d([10, 2, 30, 4], [2, 2]);
* tf.layers.minimum([input1, input2]).print();
* // Gives [[1, 2], [3, 4]].
 * ```
 *
*/
function minimum$1(config) {
if (Array.isArray(config)) {
var layer = new Minimum$1({});
return layer.apply(config);
} else {
return new Minimum$1(config);
}
}
var Concatenate = /*#__PURE__*/function (_Merge6) {
_inheritsLoose(Concatenate, _Merge6);
function Concatenate(args) {
var _this3;
_this3 = _Merge6.call(this, args) || this;
_this3.DEFAULT_AXIS = -1;
if (args == null) {
args = {};
}
_this3.axis = args.axis == null ? _this3.DEFAULT_AXIS : args.axis;
_this3.supportsMasking = true;
_this3.reshapeRequired = false;
return _this3;
}
var _proto7 = Concatenate.prototype;
_proto7.build = function build(inputShape) {
// Used purely for shape validation.
if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0])) || inputShape.length === 1) {
throw new ValueError('A `Concatenate` layer should be called on a list of at least 2 ' + 'inputs');
}
inputShape = inputShape;
var allNoneShape = true;
for (var _iterator5 = _createForOfIteratorHelperLoose(inputShape), _step5; !(_step5 = _iterator5()).done;) {
var _shape3 = _step5.value;
if (_shape3 != null) {
allNoneShape = false;
break;
}
}
if (allNoneShape) {
return;
}
var shapeSet = [];
for (var i = 0; i < inputShape.length; ++i) {
var shapeWithoutConcatAxis = inputShape[i].slice();
shapeWithoutConcatAxis.splice(this.axis, 1);
var exists = false;
for (var _iterator6 = _createForOfIteratorHelperLoose(shapeSet), _step6; !(_step6 = _iterator6()).done;) {
var shape = _step6.value;
if (arraysEqual(shape, shapeWithoutConcatAxis)) {
exists = true;
break;
}
}
if (!exists) {
shapeSet.push(shapeWithoutConcatAxis);
}
}
if (shapeSet.length > 1) {
throw new ValueError('A `Concatenate` layer requires inputs with matching shapes ' + 'except for the concat axis. Got input shapes: ' + JSON.stringify(inputShape));
}
};
_proto7.mergeFunction = function mergeFunction(inputs) {
var _this4 = this;
return tidy(function () {
return concatenate(inputs, _this4.axis);
});
};
_proto7.computeOutputShape = function computeOutputShape(inputShape) {
if (!(Array.isArray(inputShape) && Array.isArray(inputShape[0]))) {
throw new ValueError('A `Concatenate` layer should be called on a list of inputs.');
}
var inputShapes = inputShape;
var outputShape = inputShapes[0].slice();
var axis = this.axis < 0 ? outputShape.length + this.axis : this.axis; // Porting Note: the line above is because TypeScript doesn't support
// negative indices.
for (var _iterator7 = _createForOfIteratorHelperLoose(inputShapes.slice(1)), _step7; !(_step7 = _iterator7()).done;) {
var shape = _step7.value;
if (outputShape[axis] == null || shape[axis] == null) {
outputShape[axis] = null;
break;
}
outputShape[axis] += shape[axis];
}
return outputShape;
};
_proto7.computeMask = function computeMask(inputs, mask) {
var _this5 = this;
if (mask == null) {
return null;
}
if (!Array.isArray(mask)) {
throw new ValueError('`mask` should be an array for Concatenate');
}
if (!Array.isArray(inputs)) {
throw new ValueError('`inputs` should be an array for Concatenate');
}
if (mask.length !== inputs.length) {
throw new ValueError("Mismatch in the length of mask (" + mask.length + ") " + ("and the legnth of inputs (" + inputs.length + ")"));
}
return tidy(function () {
var allNullMasks = true;
mask.forEach(function (m) {
if (m != null) {
allNullMasks = false;
return;
}
});
if (allNullMasks) {
return null;
}
var outputMasks = [];
for (var i = 0; i < inputs.length; ++i) {
if (mask[i] == null) {
// Input is unmasked. Append all 1's to masks.
outputMasks.push(cast(onesLike(inputs[i]), 'bool'));
} else if (mask[i].rank < inputs[i].rank) {
// Mask is smaller than the input, expand it.
outputMasks.push(expandDims(mask[i], -1));
} else {
outputMasks.push(mask[i]);
}
}
var concatenatedMasks = concat(outputMasks, _this5.axis);
return all(concatenatedMasks, -1, false);
});
};
_proto7.getConfig = function getConfig() {
var config = {
'axis': this.axis
};
var baseConfig = _Merge6.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Concatenate;
}(Merge);
/** @nocollapse */
Concatenate.className = 'Concatenate';
registerClass(Concatenate);
/**
* Concatenate an `Array` of inputs.
*
* This function can be invoked in three ways.
*
* 1. Construct an instance of `Concatenate` layer, by using no input argument
* or a single configuration argument. The resultant `Concatenate` layer can
* then be used on `tf.SymbolicTensor`s or `tf.Tensor`s. For example:
*
* ```js
* const concatLayer = tf.layers.concatenate();
*
* // The layer can be applied to inputs.
* const input1 = tf.input({shape: [2, 3]});
* const input2 = tf.input({shape: [2, 4]});
* const output = concatLayer.apply([input1, input2]);
* console.log(output.shape);
* // You get [null, 2, 7], with the first dimension as the undetermined batch
* // dimension and the last dimension as the result of concatenating the
* // last dimensions of the two inputs.
* ```
*
* 2. Invoke directly on an `Array` of `tf.SymbolicTensor`s. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.SymbolicTensor`. For example:
*
* ```js
* const input1 = tf.input({shape: [2, 3]});
* const input2 = tf.input({shape: [2, 4]});
* const output = tf.layers.concatenate([input1, input2]);
* console.log(output.shape);
 * // You get [null, 2, 7], with the first dimension as the undetermined batch
* // dimension and the last dimension as the result of concatenating the
* // last dimensions of the two inputs.
* ```
*
* 3. Invoke directly on `tf.Tensor`s, i.e., concrete values. This constructs
 * a `Layer` object internally and calls its `apply` method on the inputs,
* generating a new `tf.Tensor` as the result of the computation. For
* example:
*
* ```js
* const input1 = tf.tensor2d([[1, 2], [3, 4]], [2, 2]);
* const input2 = tf.tensor2d([[10, 20], [30, 40]], [2, 2]);
* tf.layers.concatenate([input1, input2]).print();
* // Gives [[1, 2, 10, 20], [3, 4, 30, 40]].
 * ```
 *
*/
function concatenate$1(config) {
if (Array.isArray(config)) {
var layer = new Concatenate({});
return layer.apply(config);
} else {
return new Concatenate(config);
}
}
/**
 * Interpret a potentially negative axis index.
*
* For example, given axis = -1, and dim = 3, this function will return 2.
*
* @param axis The axis index, may be a positive, zero or negative integer.
* @param dim Total number of dimensions, a positive integer.
* @returns A non-negative axis index equivalent to the input `axis`.
*/
function interpretAxis(axis, dim) {
while (axis < 0) {
axis += dim;
}
return axis;
}
function batchDot(x, y, axes) {
if (x.shape.length > 3 || y.shape.length > 3) {
throw new NotImplementedError('batchDot is not implemented for tensors of 4D or higher rank yet');
}
assert(x.shape.length >= 2, function () {
return "batchDot requires the rank of x to be >= 2, " + ("but got " + x.shape.length);
});
assert(y.shape.length >= 2, function () {
return "batchDot requires the rank of y to be >= 2, " + ("but got " + y.shape.length);
});
if (typeof axes === 'number') {
axes = [axes, axes];
}
if (x.dtype === 'complex64' || y.dtype === 'complex64') {
throw new NotImplementedError('batchDot is not implemented for complex64-type Tensors yet.');
}
var xNDim = x.shape.length;
var yNDim = y.shape.length;
if (axes == null) {
// Behave like batchMatmul by default.
axes = [xNDim - 1, yNDim - 2];
}
var axesArray = axes;
return tidy(function () {
var diff;
if (xNDim > yNDim) {
diff = xNDim - yNDim;
var diffShape = [];
for (var i = 0; i < diff; ++i) {
diffShape.push(1);
}
y = reshape(y, y.shape.concat(diffShape));
} else if (yNDim > xNDim) {
diff = yNDim - xNDim;
var _diffShape = [];
for (var _i = 0; _i < diff; ++_i) {
_diffShape.push(1);
}
x = reshape(x, x.shape.concat(_diffShape));
} else {
diff = 0;
}
var out;
if (x.shape.length === 2 && y.shape.length === 2) {
if (axesArray[0] === axesArray[1]) {
out = sum$1(mul(x, y), axesArray[0]);
} else {
out = sum$1(mul(transpose(x, [1, 0]), y), axesArray[1]);
}
} else {
var adjX = axesArray[0] !== x.shape.length - 1;
var adjY = axesArray[1] === y.shape.length - 1;
out = matMul(x, y, adjX, adjY);
}
if (diff > 0) {
var idx;
if (xNDim > yNDim) {
idx = xNDim + yNDim - 3;
} else {
idx = xNDim - 1;
}
var squeezeAxes = [];
for (var _i2 = idx; _i2 < idx + diff; ++_i2) {
squeezeAxes.push(_i2);
}
out = squeeze(out, squeezeAxes);
}
if (out.shape.length === 1) {
out = expandDims(out, 1);
}
return out;
});
}
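/**
 * Worked example for the module-internal `batchDot` helper above (it is not
 * exported on the public `tf` namespace; shown here only to illustrate the
 * axes handling):
 *
 * ```js
 * // Two rank-2 inputs of shape [batchSize, 3], dotted along axis 1.
 * const x = tf.ones([16, 3]);
 * const y = tf.ones([16, 3]);
 * batchDot(x, y, [1, 1]).print(); // shape [16, 1], every entry equals 3
 * ```
 */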
var Dot = /*#__PURE__*/function (_Merge7) {
_inheritsLoose(Dot, _Merge7);
function Dot(args) {
var _this6;
_this6 = _Merge7.call(this, args) || this;
_this6.axes = args.axes;
_this6.normalize = args.normalize == null ? false : args.normalize;
_this6.supportsMasking = true;
_this6.reshapeRequired = false;
return _this6;
}
var _proto8 = Dot.prototype;
_proto8.build = function build(inputShape) {
assert(Array.isArray(inputShape) && inputShape.length === 2 && Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), function () {
return 'A `Dot` layer should be called on a list of exactly 2 inputs.';
});
var shape1 = inputShape[0];
var shape2 = inputShape[1];
if (shape1.length > 3 || shape2.length > 3) {
throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
}
var axes = this.interpretAxes(shape1, shape2);
if (shape1[axes[0]] !== shape2[axes[1]]) {
throw new ValueError("Dimension incompatibility: " + (shape1[axes[0]] + " !== " + shape2[axes[1]]));
}
};
_proto8.mergeFunction = function mergeFunction(inputs) {
if (inputs.length !== 2) {
throw new ValueError('A `Dot` layer must be called on exactly 2 inputs, ' + ("but received " + inputs.length + " input(s)."));
}
var x1 = inputs[0];
var x2 = inputs[1];
var axes;
if (!Array.isArray(this.axes)) {
axes = [interpretAxis(this.axes, x1.shape.length), interpretAxis(this.axes, x2.shape.length)];
} else {
axes = this.axes.map(function (axis, i) {
return interpretAxis(axis, inputs[i].shape.length);
});
}
if (this.normalize) {
x1 = l2Normalize(x1, axes[0]);
x2 = l2Normalize(x2, axes[1]);
}
return batchDot(x1, x2, axes);
};
_proto8.interpretAxes = function interpretAxes(shape1, shape2) {
var axes;
if (!Array.isArray(this.axes)) {
// `this.axes` is a single integer.
axes = [interpretAxis(this.axes, shape1.length), interpretAxis(this.axes, shape2.length)];
} else {
// `this.axes` is an Array of integers.
axes = this.axes;
}
return axes;
};
_proto8.computeOutputShape = function computeOutputShape(inputShape) {
assert(Array.isArray(inputShape) && inputShape.length === 2 && Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), function () {
return 'A `Dot` layer should be called on a list of exactly 2 inputs.';
});
var shape1 = inputShape[0].slice();
var shape2 = inputShape[1].slice();
if (shape1.length > 3 || shape2.length > 3) {
throw new NotImplementedError('Dot layer does not support tensors of 4D or higher rank yet.');
}
var axes = this.interpretAxes(shape1, shape2);
shape1.splice(axes[0], 1);
shape2.splice(axes[1], 1);
shape2.splice(0, 1);
var outputShape = shape1.concat(shape2);
if (outputShape.length === 1) {
outputShape.push(1);
}
return outputShape;
};
_proto8.computeMask = function computeMask(inputs, mask) {
return null;
};
_proto8.getConfig = function getConfig() {
var config = {
'axes': this.axes,
'normalize': this.normalize
};
var baseConfig = _Merge7.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Dot;
}(Merge);
/** @nocollapse */
Dot.className = 'Dot';
registerClass(Dot); // TODO(cais): Add functional interfaces for the merge layers.
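/**
 * Usage sketch for the `Dot` layer (via the public `tf.layers.dot` factory):
 * computes a sample-wise dot product along the given axes.
 *
 * ```js
 * const input1 = tf.input({shape: [4]});
 * const input2 = tf.input({shape: [4]});
 * const dotLayer = tf.layers.dot({axes: -1});
 * const output = dotLayer.apply([input1, input2]);
 * console.log(output.shape); // [null, 1]
 * ```
 */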
var GaussianNoise = /*#__PURE__*/function (_Layer) {
_inheritsLoose(GaussianNoise, _Layer);
function GaussianNoise(args) {
var _this;
_this = _Layer.call(this, args) || this;
_this.supportsMasking = true;
_this.stddev = args.stddev;
return _this;
}
var _proto = GaussianNoise.prototype;
_proto.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
};
_proto.getConfig = function getConfig() {
var baseConfig = _Layer.prototype.getConfig.call(this);
var config = {
stddev: this.stddev
};
Object.assign(config, baseConfig);
return config;
};
_proto.call = function call(inputs, kwargs) {
var _this2 = this;
return tidy(function () {
_this2.invokeCallHook(inputs, kwargs);
var input = getExactlyOneTensor(inputs);
var noised = function noised() {
return add$1(randomNormal$1(input.shape, 0, _this2.stddev), input);
};
var output = inTrainPhase(noised, function () {
return input;
}, kwargs['training'] || false);
return output;
});
};
return GaussianNoise;
}(Layer);
/** @nocollapse */
GaussianNoise.className = 'GaussianNoise';
registerClass(GaussianNoise);
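/**
 * Usage sketch for the `GaussianNoise` layer (via the public
 * `tf.layers.gaussianNoise` factory): additive zero-mean noise that is only
 * applied when `training` is true.
 *
 * ```js
 * const noiseLayer = tf.layers.gaussianNoise({stddev: 0.1});
 * const x = tf.ones([2, 3]);
 * noiseLayer.apply(x, {training: true}).print(); // ones plus noise with stddev 0.1
 * noiseLayer.apply(x).print();                   // unchanged at inference
 * ```
 */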
var GaussianDropout = /*#__PURE__*/function (_Layer2) {
_inheritsLoose(GaussianDropout, _Layer2);
function GaussianDropout(args) {
var _this3;
_this3 = _Layer2.call(this, args) || this;
_this3.supportsMasking = true;
_this3.rate = args.rate;
return _this3;
}
var _proto2 = GaussianDropout.prototype;
_proto2.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
};
_proto2.getConfig = function getConfig() {
var baseConfig = _Layer2.prototype.getConfig.call(this);
var config = {
rate: this.rate
};
Object.assign(config, baseConfig);
return config;
};
_proto2.call = function call(inputs, kwargs) {
var _this4 = this;
return tidy(function () {
_this4.invokeCallHook(inputs, kwargs);
var input = getExactlyOneTensor(inputs);
if (_this4.rate > 0 && _this4.rate < 1) {
var noised = function noised() {
var stddev = Math.sqrt(_this4.rate / (1 - _this4.rate));
return mul(input, randomNormal$1(input.shape, 1, stddev));
};
return inTrainPhase(noised, function () {
return input;
}, kwargs['training'] || false);
}
return input;
});
};
return GaussianDropout;
}(Layer);
/** @nocollapse */
GaussianDropout.className = 'GaussianDropout';
registerClass(GaussianDropout);
/**
* Applies Alpha Dropout to the input.
*
* As it is a regularization layer, it is only active at training time.
*
* Alpha Dropout is a `Dropout` that keeps mean and variance of inputs
* to their original values, in order to ensure the self-normalizing property
* even after this dropout.
* Alpha Dropout fits well to Scaled Exponential Linear Units
* by randomly setting activations to the negative saturation value.
*
* Arguments:
* - `rate`: float, drop probability (as with `Dropout`).
* The multiplicative noise will have
* standard deviation `sqrt(rate / (1 - rate))`.
* - `noise_shape`: A 1-D `Tensor` of type `int32`, representing the
* shape for randomly generated keep/drop flags.
*
* Input shape:
* Arbitrary. Use the keyword argument `inputShape`
* (tuple of integers, does not include the samples axis)
* when using this layer as the first layer in a model.
*
* Output shape:
* Same shape as input.
*
* References:
* - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
*/
var AlphaDropout = /*#__PURE__*/function (_Layer3) {
_inheritsLoose(AlphaDropout, _Layer3);
function AlphaDropout(args) {
var _this5;
_this5 = _Layer3.call(this, args) || this;
_this5.supportsMasking = true;
_this5.rate = args.rate;
_this5.noiseShape = args.noiseShape;
return _this5;
}
var _proto3 = AlphaDropout.prototype;
_proto3._getNoiseShape = function _getNoiseShape(inputs) {
return this.noiseShape || getExactlyOneTensor(inputs).shape;
};
_proto3.computeOutputShape = function computeOutputShape(inputShape) {
return inputShape;
};
_proto3.getConfig = function getConfig() {
var baseConfig = _Layer3.prototype.getConfig.call(this);
var config = {
rate: this.rate
};
Object.assign(config, baseConfig);
return config;
};
_proto3.call = function call(inputs, kwargs) {
var _this6 = this;
return tidy(function () {
if (_this6.rate < 1 && _this6.rate > 0) {
var noiseShape = _this6._getNoiseShape(inputs);
var droppedInputs = function droppedInputs() {
var input = getExactlyOneTensor(inputs);
var alpha = 1.6732632423543772848170429916717;
var scale = 1.0507009873554804934193349852946;
var alphaP = -alpha * scale;
var keptIdx = greaterEqual(randomUniform(noiseShape), _this6.rate);
keptIdx = cast$1(keptIdx, 'float32'); // get default dtype.
// Get affine transformation params.
var a = Math.pow((1 - _this6.rate) * (1 + _this6.rate * Math.pow(alphaP, 2)), -0.5);
var b = -a * alphaP * _this6.rate; // Apply mask.
var x = add$1(mul(input, keptIdx), mul(add$1(keptIdx, -1), alphaP));
return add$1(mul(x, a), b);
};
return inTrainPhase(droppedInputs, function () {
return getExactlyOneTensor(inputs);
}, kwargs['training'] || false);
}
return inputs;
});
};
return AlphaDropout;
}(Layer);
/** @nocollapse */
AlphaDropout.className = 'AlphaDropout';
registerClass(AlphaDropout);
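/**
 * Usage sketch for the `AlphaDropout` layer (via the public
 * `tf.layers.alphaDropout` factory), typically paired with 'selu'
 * activations; like other dropout variants it is active only in training.
 *
 * ```js
 * const dropoutLayer = tf.layers.alphaDropout({rate: 0.2});
 * const x = tf.randomNormal([4, 8]);
 * dropoutLayer.apply(x, {training: true}).print(); // roughly 20% of units dropped
 * ```
 */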
/**
* Applies batch normalization on x given mean, var, beta and gamma.
*
* I.e. returns:
 * `output = (x - mean) / sqrt(var + epsilon) * gamma + beta`
*
* @param x Input tensor.
* @param mean Mean of batch.
* @param variance Variance of batch.
* @param beta Tensor with which to center the input.
* @param gamma Tensor by which to scale the input.
* @param epsilon Fuzz factor.
* @returns The result of the batch normalization.
*/
function batchNormalization(x, mean, variance, beta, gamma, epsilon) {
if (epsilon === void 0) {
epsilon = 1e-3;
}
var out;
if (x.rank === 2) {
out = batchNorm2d(x, mean, variance, beta, gamma, epsilon);
} else if (x.rank === 3) {
// TODO(cais): Check rank; give proper error message.
out = batchNorm3d(x, mean, variance, beta, gamma, epsilon);
} else if (x.rank === 4) {
out = batchNorm4d(x, mean, variance, beta, gamma, epsilon);
} else {
throw new NotImplementedError("batchNormalization is not implemented for array of rank " + x.rank + " " + "yet");
}
return out;
}
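/**
 * Worked example for the internal `batchNormalization` helper above (it
 * dispatches to `tf.batchNorm2d` / `tf.batchNorm3d` / `tf.batchNorm4d` by
 * rank; `beta` and `gamma` may be omitted):
 *
 * ```js
 * const x = tf.tensor2d([[1, 2], [3, 4]]);
 * const mean = tf.tensor1d([2, 3]);
 * const variance = tf.tensor1d([1, 1]);
 * batchNormalization(x, mean, variance).print();
 * // Approximately [[-1, -1], [1, 1]] (epsilon defaults to 1e-3).
 * ```
 */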
/**
* Non-broadcasting batch normalization for use in training (not inference).
*
* The input is normalized to zero mean and unit variance along the
* `reductionAxes`, followed by scaling with `gamma` and shifted by `beta`.
* The result of that is returned as the first element
* of the returned `Array`. The other two elements are the mean and variance,
* respectively.
*
* @param x Input tensor to be normalized.
* @param gamma Tensor by which to scale the input.
* @param beta Tensor by which to center the input.
* @param reductionAxes Axes over which to normalize.
* @param epsilon Fuzz factor.
* @returns An `Array` of three `Tensors`:
* [normalized tensor, mean of input, variance of input].
*/
function regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon) {
if (epsilon === void 0) {
epsilon = 1e-3;
}
return tidy(function () {
var meanAndVariance = moments(x, reductionAxes);
var mean = meanAndVariance.mean;
var variance = meanAndVariance.variance;
var normed = batchNormalization(x, mean, variance, beta, gamma, epsilon);
return [normed, mean, variance];
});
}
/**
* Broadcasting batch normalization for use in training (not inference).
*
* The input is normalized to zero mean and unit variance along the
* `reductionAxes`, followed by scaling with `gamma` and shifted by `beta`.
* The result of that is returned as the first element
* of the returned `Array`. The other two elements are the mean and variance,
* respectively.
*
* @param x Input tensor to be normalized.
* @param gamma Tensor by which to scale the input.
* @param beta Tensor by which to center the input.
* @param reductionAxes Axes over which to normalize.
* @param epsilon Fuzz factor.
* @returns An `Array` of three `Tensors`:
* [normalized tensor, mean of input, variance of input].
*/
function broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon) {
if (epsilon === void 0) {
epsilon = 1e-3;
}
return tidy(function () {
var meanAndVariance = moments(x, reductionAxes);
var mean = meanAndVariance.mean;
var variance = meanAndVariance.variance;
var targetShape = [];
for (var _iterator = _createForOfIteratorHelperLoose(range$1(0, x.rank)), _step; !(_step = _iterator()).done;) {
var axis = _step.value;
if (reductionAxes.indexOf(axis) !== -1) {
targetShape.push(1);
} else {
targetShape.push(x.shape[axis]);
}
}
var broadcastMean = reshape(mean, targetShape);
var broadcastVariance = reshape(variance, targetShape);
var broadcastGamma = gamma == null ? null : reshape(gamma, targetShape);
var broadcastBeta = beta == null ? null : reshape(beta, targetShape);
var normed = batchNormalization(x, broadcastMean, broadcastVariance, broadcastBeta, broadcastGamma, epsilon);
return [normed, mean, variance];
});
}
/**
* Batch normalization for use in training (not inference).
*
* @param x Input tensor to be normalized.
* @param gamma Tensor by which to scale the input.
* @param beta Tensor by which to center the input.
* @param reductionAxes Axes over which to normalize.
* @param epsilon Fuzz factor.
* @returns An `Array` of three `Tensors`:
* [normalized tensor, mean of input, variance of input].
*/
function normalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon) {
if (epsilon === void 0) {
epsilon = 1e-3;
}
if (arraysEqual(reductionAxes.slice().sort(), range$1(0, x.rank - 1))) {
return regularNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon);
} else {
return broadcastNormalizeBatchInTraining(x, gamma, beta, reductionAxes, epsilon);
}
}
var BatchNormalization = /*#__PURE__*/function (_Layer) {
_inheritsLoose(BatchNormalization, _Layer);
function BatchNormalization(args) {
var _this;
if (args == null) {
args = {};
}
_this = _Layer.call(this, args) || this;
_this.supportsMasking = true;
_this.axis = args.axis == null ? -1 : args.axis;
_this.momentum = args.momentum == null ? 0.99 : args.momentum;
_this.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
_this.center = args.center == null ? true : args.center;
_this.scale = args.scale == null ? true : args.scale;
_this.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
_this.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
_this.movingMeanInitializer = getInitializer(args.movingMeanInitializer || 'zeros');
_this.movingVarianceInitializer = getInitializer(args.movingVarianceInitializer || 'ones');
_this.betaConstraint = getConstraint(args.betaConstraint);
_this.gammaConstraint = getConstraint(args.gammaConstraint);
_this.betaRegularizer = getRegularizer(args.betaRegularizer);
_this.gammaRegularizer = getRegularizer(args.gammaRegularizer);
return _this;
}
var _proto = BatchNormalization.prototype;
_proto.build = function build(inputShape) {
var _axes;
inputShape = getExactlyOneShape(inputShape);
var axis = this.axis >= 0 ? this.axis : this.axis + inputShape.length;
var dim = inputShape[axis];
if (dim == null) {
throw new ValueError("Axis " + axis + " of input tensor should have a defined dimension but " + "the layer received an input with shape " + (JSON.stringify(inputShape) + "."));
}
this.inputSpec = [new InputSpec({
ndim: inputShape.length,
axes: (_axes = {}, _axes[axis] = dim, _axes)
})];
var shape = [dim];
if (this.scale) {
this.gamma = this.addWeight('gamma', shape, null, this.gammaInitializer, this.gammaRegularizer, true, this.gammaConstraint);
}
if (this.center) {
this.beta = this.addWeight('beta', shape, null, this.betaInitializer, this.betaRegularizer, true, this.betaConstraint);
}
this.movingMean = this.addWeight('moving_mean', shape, null, this.movingMeanInitializer, null, false);
this.movingVariance = this.addWeight('moving_variance', shape, null, this.movingVarianceInitializer, null, false);
this.built = true;
};
_proto.call = function call(inputs, kwargs) {
var _this2 = this;
return tidy(function () {
var training = kwargs['training'] == null ? false : kwargs['training'];
var input = getExactlyOneTensor(inputs);
var inputShape = input.shape;
var ndim = inputShape.length;
var reductionAxes = range$1(0, ndim);
var axis = _this2.axis >= 0 ? _this2.axis : _this2.axis + ndim;
reductionAxes.splice(axis, 1);
var broadcastShape = pyListRepeat(1, ndim);
broadcastShape[axis] = inputShape[axis];
var sortedReductionAxes = reductionAxes.slice();
sortedReductionAxes.sort();
var needsBroadcasting = !arraysEqual(sortedReductionAxes, range$1(0, ndim).slice(0, ndim - 1));
var normalizeInference = function normalizeInference() {
if (needsBroadcasting) {
var broadcastMovingMean = reshape(_this2.movingMean.read(), broadcastShape);
var broadcastMovingVariance = reshape(_this2.movingVariance.read(), broadcastShape);
var broadcastBeta = _this2.center ? reshape(_this2.beta.read(), broadcastShape) : null;
var broadcastGamma = _this2.scale ? reshape(_this2.gamma.read(), broadcastShape) : null;
return batchNormalization(input, broadcastMovingMean, broadcastMovingVariance, broadcastBeta, broadcastGamma, _this2.epsilon);
} else {
return batchNormalization(input, _this2.movingMean.read(), _this2.movingVariance.read(), _this2.beta == null ? null : _this2.beta.read(), _this2.gamma == null ? null : _this2.gamma.read(), _this2.epsilon);
}
};
if (!training) {
return normalizeInference();
}
var _normalizeBatchInTrai = normalizeBatchInTraining(input, _this2.gamma.read(), _this2.beta.read(), reductionAxes, _this2.epsilon),
normedTraining = _normalizeBatchInTrai[0],
mean = _normalizeBatchInTrai[1],
variance = _normalizeBatchInTrai[2];
var doMovingAverage = function doMovingAverage(variable, value, momentum) {
tidy(function () {
var decay = 1 - momentum;
var origValue = variable.read();
var updateDelta = mul(sub(origValue, value), decay);
variable.write(sub(origValue, updateDelta));
});
}; // Perform updates to moving mean and moving variance for training.
// Porting Note: In PyKeras, these updates to `movingMean` and
// `movingVariance` are done as a deferred Graph, added to the `Layer`'s
// `update`s using the `add_update()` method. Here we do it imperatively
// and encapsulate the updates in a function that is invoked
// immediately.
var updateMovingMeanAndVariance = function updateMovingMeanAndVariance() {
doMovingAverage(_this2.movingMean, mean, _this2.momentum);
doMovingAverage(_this2.movingVariance, variance, _this2.momentum);
};
updateMovingMeanAndVariance();
return normedTraining;
});
};
_proto.getConfig = function getConfig() {
var config = {
axis: this.axis,
momentum: this.momentum,
epsilon: this.epsilon,
center: this.center,
scale: this.scale,
betaInitializer: serializeInitializer(this.betaInitializer),
gammaInitializer: serializeInitializer(this.gammaInitializer),
movingMeanInitializer: serializeInitializer(this.movingMeanInitializer),
movingVarianceInitializer: serializeInitializer(this.movingVarianceInitializer),
betaRegularizer: serializeRegularizer(this.betaRegularizer),
gammaRegularizer: serializeRegularizer(this.gammaRegularizer),
betaConstraint: serializeConstraint(this.betaConstraint),
gammaConstraint: serializeConstraint(this.gammaConstraint)
};
var baseConfig = _Layer.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return BatchNormalization;
}(Layer);
/** @nocollapse */
BatchNormalization.className = 'BatchNormalization';
registerClass(BatchNormalization);
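/**
 * Usage sketch for the `BatchNormalization` layer (via the public
 * `tf.layers.batchNormalization` factory), typically inserted between a
 * linear transform and its activation:
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({units: 16, inputShape: [8]}));
 * model.add(tf.layers.batchNormalization());
 * model.add(tf.layers.activation({activation: 'relu'}));
 * model.summary();
 * ```
 */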
var LayerNormalization = /*#__PURE__*/function (_Layer2) {
_inheritsLoose(LayerNormalization, _Layer2);
function LayerNormalization(args) {
var _this3;
if (args == null) {
args = {};
}
_this3 = _Layer2.call(this, args) || this;
_this3.axis = args.axis == null ? -1 : args.axis;
if (typeof _this3.axis === 'number') {
if (!Number.isInteger(_this3.axis)) {
throw new Error("Expected axis to be an integer, but received " + _this3.axis);
}
} else if (Array.isArray(_this3.axis)) {
for (var _iterator2 = _createForOfIteratorHelperLoose(_this3.axis), _step2; !(_step2 = _iterator2()).done;) {
var axis = _step2.value;
if (!Number.isInteger(axis)) {
throw new Error("Expected axis to be an array of integers, " + ("but received " + JSON.stringify(_this3.axis)));
}
}
} else {
throw new Error("Expected axis to be an integer or an array of integers, " + ("but received " + JSON.stringify(_this3.axis)));
}
_this3.epsilon = args.epsilon == null ? 1e-3 : args.epsilon;
_this3.center = args.center == null ? true : args.center;
_this3.scale = args.scale == null ? true : args.scale;
_this3.betaInitializer = getInitializer(args.betaInitializer || 'zeros');
_this3.gammaInitializer = getInitializer(args.gammaInitializer || 'ones');
_this3.betaRegularizer = getRegularizer(args.betaRegularizer);
_this3.gammaRegularizer = getRegularizer(args.gammaRegularizer);
_this3.supportsMasking = true;
return _this3;
}
var _proto2 = LayerNormalization.prototype;
_proto2.build = function build(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var nDims = inputShape.length; // Convert axis to array and resolve negatives.
if (typeof this.axis === 'number') {
this.axis = [this.axis];
}
for (var i = 0; i < this.axis.length; ++i) {
if (this.axis[i] < 0) {
this.axis[i] += nDims;
}
} // Further validate axes.
for (var _iterator3 = _createForOfIteratorHelperLoose(this.axis), _step3; !(_step3 = _iterator3()).done;) {
var axis = _step3.value;
if (axis < 0 || axis >= nDims) {
throw new Error("Invalid axis: " + axis);
}
}
if (this.axis.length !== unique$1(this.axis).length) {
throw new Error("Found duplicate axes in: " + this.axis);
}
var paramShape = this.axis.map(function (axis) {
return inputShape[axis];
});
var trainable = true;
if (this.scale) {
this.gamma = this.addWeight('gamma', paramShape, 'float32', this.gammaInitializer, this.gammaRegularizer, trainable);
} else {
this.gamma = null;
}
if (this.center) {
this.beta = this.addWeight('beta', paramShape, 'float32', this.betaInitializer, this.betaRegularizer, trainable);
} else {
this.beta = null;
}
this.built = true;
};
_proto2.call = function call(inputs, kwargs) {
var _this4 = this;
var input = getExactlyOneTensor(inputs);
var inputShape = input.shape;
var nDims = inputShape.length;
return tidy(function () {
var keepDims = true;
var _moments = moments(input, _this4.axis, keepDims),
mean = _moments.mean,
variance = _moments.variance;
var broadcastShape = pyListRepeat(1, nDims);
for (var _iterator4 = _createForOfIteratorHelperLoose(_this4.axis), _step4; !(_step4 = _iterator4()).done;) {
var dim = _step4.value;
broadcastShape[dim] = inputShape[dim];
}
var broadcast = function broadcast(v) {
if (v != null && v.shape.length !== nDims && _this4.axis !== [nDims - 1]) {
return reshape(v, broadcastShape);
} else {
return v;
}
};
var scale = broadcast(_this4.gamma.read());
var offset = broadcast(_this4.beta.read()); // TODO(https://github.com/tensorflow/tfjs/issues/2120): The tiling below
// is a workaround for the limitation that the gradients of core's
// batchNormalization ops don't support broadcasting. In addition, the
// tiling is necessary to ensure correctness on the browser CPU backend
// regardless of forward or backward computation. Remove this workaround
// once the limitation is addressed (see the issue linked above).
var momentsTiling = [];
var scaleOffsetTiling = [];
for (var i = 0; i < nDims; ++i) {
if (_this4.axis.indexOf(i) !== -1) {
momentsTiling.push(inputShape[i]);
scaleOffsetTiling.push(1);
} else {
momentsTiling.push(1);
scaleOffsetTiling.push(inputShape[i]);
}
}
mean = tile(mean, momentsTiling);
variance = tile(variance, momentsTiling);
scale = tile(scale, scaleOffsetTiling);
offset = tile(offset, scaleOffsetTiling);
return batchNormalization(input, mean, variance, offset, scale, _this4.epsilon);
});
};
_proto2.getConfig = function getConfig() {
var config = {
axis: this.axis,
epsilon: this.epsilon,
center: this.center,
scale: this.scale,
betaInitializer: serializeInitializer(this.betaInitializer),
gammaInitializer: serializeInitializer(this.gammaInitializer),
betaRegularizer: serializeRegularizer(this.betaRegularizer),
gammaRegularizer: serializeRegularizer(this.gammaRegularizer)
};
var baseConfig = _Layer2.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return LayerNormalization;
}(Layer);
/** @nocollapse */
LayerNormalization.className = 'LayerNormalization';
registerClass(LayerNormalization);
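/**
 * Usage sketch for the `LayerNormalization` layer (via the public
 * `tf.layers.layerNormalization` factory): each sample is normalized over
 * the given axes rather than over the batch.
 *
 * ```js
 * const layerNorm = tf.layers.layerNormalization({axis: -1});
 * const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
 * layerNorm.apply(x).print(); // each row has roughly zero mean and unit variance
 * ```
 */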
/**
* Pads the middle dimension of a 3D tensor.
*
* @param x Input `tf.Tensor` to be padded.
* @param padding `Array` of 2 integers, how many zeros to add at the start and
* end of the middle dimension (i.e., dimension 1).
* @return A padded 3D `tf.Tensor`.
*/
function temporalPadding(x, padding) {
return tidy(function () {
if (x.rank !== 3) {
throw new ValueError("temporalPadding expects input tensor to be 3-D, but received a " + (x.rank + "-D tensor."));
}
if (padding == null) {
padding = [1, 1];
}
if (padding.length !== 2) {
throw new ValueError("temporalPadding expects input padding pattern to be a length-2 " + ("array, but received a length-" + padding.length + " array."));
}
var pattern = [[0, 0], padding, [0, 0]];
return pad(x, pattern);
});
}
/**
* Pads the 2nd and 3rd dimensions of a 4D tensor.
*
* @param x Input `tf.Tensor` to be padded.
* @param padding `Array` of two `Array`s, each of which is an `Array` of two
* integers. The amount of padding at the beginning and end of the 2nd and 3rd
* dimensions, respectively.
* @param dataFormat 'channelsLast' (default) or 'channelsFirst'.
* @return Padded 4D `tf.Tensor`.
*/
function spatial2dPadding(x, padding, dataFormat) {
return tidy(function () {
if (x.rank !== 4) {
throw new ValueError("temporalPadding expects input tensor to be 4-D, but received a " + (x.rank + "-D tensor."));
}
if (padding == null) {
padding = [[1, 1], [1, 1]];
}
if (padding.length !== 2 || padding[0].length !== 2 || padding[1].length !== 2) {
throw new ValueError('spatial2dPadding expects `padding` to be an Array of two Arrays, ' + 'each of which is an Array of two integers.');
}
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
if (dataFormat !== 'channelsLast' && dataFormat !== 'channelsFirst') {
throw new ValueError("Unknown data format: " + dataFormat + ". " + "Supported data formats are 'channelsLast' and 'channelsFirst.");
}
var pattern;
if (dataFormat === 'channelsFirst') {
pattern = [[0, 0], [0, 0], padding[0], padding[1]];
} else {
pattern = [[0, 0], padding[0], padding[1], [0, 0]];
}
return pad(x, pattern);
});
}
var ZeroPadding2D = /*#__PURE__*/function (_Layer) {
_inheritsLoose(ZeroPadding2D, _Layer);
function ZeroPadding2D(args) {
var _this;
if (args == null) {
args = {};
}
_this = _Layer.call(this, args) || this;
_this.dataFormat = args.dataFormat == null ? imageDataFormat() : args.dataFormat; // TODO(cais): Maybe refactor the following logic surrounding `padding`
// into a helper method.
if (args.padding == null) {
_this.padding = [[1, 1], [1, 1]];
} else if (typeof args.padding === 'number') {
_this.padding = [[args.padding, args.padding], [args.padding, args.padding]];
} else {
args.padding = args.padding;
if (args.padding.length !== 2) {
throw new ValueError("ZeroPadding2D expects padding to be a length-2 array, but " + ("received a length-" + args.padding.length + " array."));
}
var heightPadding;
var widthPadding;
if (typeof args.padding[0] === 'number') {
heightPadding = [args.padding[0], args.padding[0]];
widthPadding = [args.padding[1], args.padding[1]];
} else {
args.padding = args.padding;
if (args.padding[0].length !== 2) {
throw new ValueError("ZeroPadding2D expects height padding to be a length-2 array, " + ("but received a length-" + args.padding[0].length + " array."));
}
heightPadding = args.padding[0];
if (args.padding[1].length !== 2) {
throw new ValueError("ZeroPadding2D expects width padding to be a length-2 array, " + ("but received a length-" + args.padding[1].length + " array."));
}
widthPadding = args.padding[1];
}
_this.padding = [heightPadding, widthPadding];
}
_this.inputSpec = [new InputSpec({
ndim: 4
})];
return _this;
}
var _proto = ZeroPadding2D.prototype;
_proto.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var rows;
var cols;
if (this.dataFormat === 'channelsFirst') {
if (inputShape[2] != null && inputShape[2] >= 0) {
rows = inputShape[2] + this.padding[0][0] + this.padding[0][1];
} else {
rows = null;
}
if (inputShape[3] != null && inputShape[3] >= 0) {
cols = inputShape[3] + this.padding[1][0] + this.padding[1][1];
} else {
cols = null;
}
return [inputShape[0], inputShape[1], rows, cols];
} else {
if (inputShape[1] != null && inputShape[1] >= 0) {
rows = inputShape[1] + this.padding[0][0] + this.padding[0][1];
} else {
rows = null;
}
if (inputShape[2] != null && inputShape[2] >= 0) {
cols = inputShape[2] + this.padding[1][0] + this.padding[1][1];
} else {
cols = null;
}
return [inputShape[0], rows, cols, inputShape[3]];
}
};
_proto.call = function call(inputs, kwargs) {
var _this2 = this;
return tidy(function () {
return spatial2dPadding(getExactlyOneTensor(inputs), _this2.padding, _this2.dataFormat);
});
};
_proto.getConfig = function getConfig() {
var config = {
padding: this.padding,
dataFormat: this.dataFormat
};
var baseConfig = _Layer.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return ZeroPadding2D;
}(Layer);
/** @nocollapse */
ZeroPadding2D.className = 'ZeroPadding2D';
registerClass(ZeroPadding2D);
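  /*
   * Usage sketch, assuming the `tf.layers.zeroPadding2d` factory that this
   * bundle exports for the class above; shapes assume the default
   * 'channelsLast' data format.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.zeroPadding2d(
   *     {padding: [[1, 1], [2, 2]], inputShape: [28, 28, 3]}));
   * console.log(model.outputShape);  // [null, 30, 32, 3]
   * ```
   */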
/**
* 2D pooling.
* @param x
* @param poolSize
   * @param strides strides. Defaults to [1, 1].
* @param padding padding. Defaults to 'valid'.
* @param dataFormat data format. Defaults to 'channelsLast'.
* @param poolMode Mode of pooling. Defaults to 'max'.
* @returns Result of the 2D pooling.
*/
function pool2d(x, poolSize, strides, padding, dataFormat, poolMode) {
return tidy(function () {
checkDataFormat(dataFormat);
checkPoolMode(poolMode);
checkPaddingMode(padding);
if (strides == null) {
strides = [1, 1];
}
if (padding == null) {
padding = 'valid';
}
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
if (poolMode == null) {
poolMode = 'max';
} // TODO(cais): Remove the preprocessing step once deeplearn.js supports
// dataFormat as an input argument.
x = preprocessConv2DInput(x, dataFormat); // x is NHWC after preprocessing.
var y;
var paddingString = padding === 'same' ? 'same' : 'valid';
if (poolMode === 'max') {
// TODO(cais): Rank check?
y = maxPool(x, poolSize, strides, paddingString);
} else {
// 'avg'
// TODO(cais): Check the dtype and rank of x and give clear error message
// if those are incorrect.
y = avgPool( // TODO(cais): Rank check?
x, poolSize, strides, paddingString);
}
if (dataFormat === 'channelsFirst') {
y = transpose(y, [0, 3, 1, 2]); // NHWC -> NCHW.
}
return y;
});
}
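  /*
   * Sketch of the underlying ops: with `dataFormat` 'channelsLast' (NHWC),
   * `pool2d` reduces to a direct call to the public `tf.maxPool` or
   * `tf.avgPool` kernels (assumed to be available on the exported `tf`
   * namespace).
   *
   * ```js
   * const x = tf.randomNormal([1, 4, 4, 3]);           // NHWC
   * const y = tf.maxPool(x, [2, 2], [2, 2], 'valid');  // poolSize 2x2, strides 2x2
   * console.log(y.shape);                              // [1, 2, 2, 3]
   * ```
   */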
/**
* 3D pooling.
* @param x
   * @param poolSize Pool size. Defaults to [1, 1, 1].
* @param strides strides. Defaults to [1, 1, 1].
* @param padding padding. Defaults to 'valid'.
* @param dataFormat data format. Defaults to 'channelsLast'.
* @param poolMode Mode of pooling. Defaults to 'max'.
* @returns Result of the 3D pooling.
*/
function pool3d(x, poolSize, strides, padding, dataFormat, poolMode) {
return tidy(function () {
checkDataFormat(dataFormat);
checkPoolMode(poolMode);
checkPaddingMode(padding);
if (strides == null) {
strides = [1, 1, 1];
}
if (padding == null) {
padding = 'valid';
}
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
if (poolMode == null) {
poolMode = 'max';
} // x is NDHWC after preprocessing.
x = preprocessConv3DInput(x, dataFormat);
var y;
var paddingString = padding === 'same' ? 'same' : 'valid';
if (poolMode === 'max') {
y = maxPool3d(x, poolSize, strides, paddingString);
} else {
// 'avg'
y = avgPool3d(x, poolSize, strides, paddingString);
}
if (dataFormat === 'channelsFirst') {
y = transpose(y, [0, 4, 1, 2, 3]); // NDHWC -> NCDHW.
}
return y;
});
}
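  /*
   * Sketch mirroring `pool2d` for the 3-D case: with 'channelsLast' (NDHWC)
   * the work is delegated to `tf.maxPool3d` / `tf.avgPool3d` (assumed to be
   * available on the exported `tf` namespace).
   *
   * ```js
   * const x = tf.randomNormal([1, 4, 4, 4, 2]);               // NDHWC
   * const y = tf.maxPool3d(x, [2, 2, 2], [2, 2, 2], 'valid');
   * console.log(y.shape);                                     // [1, 2, 2, 2, 2]
   * ```
   */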
/**
* Abstract class for different pooling 1D layers.
*/
var Pooling1D = /*#__PURE__*/function (_Layer) {
_inheritsLoose(Pooling1D, _Layer);
/**
*
* @param args Parameters for the Pooling layer.
*
* config.poolSize defaults to 2.
*/
function Pooling1D(args) {
var _this;
if (args.poolSize == null) {
args.poolSize = 2;
}
_this = _Layer.call(this, args) || this;
if (typeof args.poolSize === 'number') {
_this.poolSize = [args.poolSize];
} else if (Array.isArray(args.poolSize) && args.poolSize.length === 1 && typeof args.poolSize[0] === 'number') {
_this.poolSize = args.poolSize;
} else {
throw new ValueError("poolSize for 1D convolutional layer must be a number or an " + "Array of a single number, but received " + ("" + JSON.stringify(args.poolSize)));
}
assertPositiveInteger(_this.poolSize, 'poolSize');
if (args.strides == null) {
_this.strides = _this.poolSize;
} else {
if (typeof args.strides === 'number') {
_this.strides = [args.strides];
} else if (Array.isArray(args.strides) && args.strides.length === 1 && typeof args.strides[0] === 'number') {
_this.strides = args.strides;
} else {
throw new ValueError("strides for 1D convolutional layer must be a number or an " + "Array of a single number, but received " + ("" + JSON.stringify(args.strides)));
}
}
assertPositiveInteger(_this.strides, 'strides');
_this.padding = args.padding == null ? 'valid' : args.padding;
checkPaddingMode(_this.padding);
_this.inputSpec = [new InputSpec({
ndim: 3
})];
return _this;
}
var _proto = Pooling1D.prototype;
_proto.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var length = convOutputLength(inputShape[1], this.poolSize[0], this.padding, this.strides[0]);
return [inputShape[0], length, inputShape[2]];
};
_proto.call = function call(inputs, kwargs) {
var _this2 = this;
return tidy(function () {
_this2.invokeCallHook(inputs, kwargs); // Add dummy last dimension.
inputs = expandDims$1(getExactlyOneTensor(inputs), 2);
var output = _this2.poolingFunction(getExactlyOneTensor(inputs), [_this2.poolSize[0], 1], [_this2.strides[0], 1], _this2.padding, 'channelsLast'); // Remove dummy last dimension.
return squeeze(output, [2]);
});
};
_proto.getConfig = function getConfig() {
var config = {
poolSize: this.poolSize,
padding: this.padding,
strides: this.strides
};
var baseConfig = _Layer.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Pooling1D;
}(Layer);
var MaxPooling1D = /*#__PURE__*/function (_Pooling1D) {
_inheritsLoose(MaxPooling1D, _Pooling1D);
function MaxPooling1D(args) {
return _Pooling1D.call(this, args) || this;
}
var _proto2 = MaxPooling1D.prototype;
_proto2.poolingFunction = function poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
checkDataFormat(dataFormat);
checkPaddingMode(padding);
return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
};
return MaxPooling1D;
}(Pooling1D);
/** @nocollapse */
MaxPooling1D.className = 'MaxPooling1D';
registerClass(MaxPooling1D);
var AveragePooling1D = /*#__PURE__*/function (_Pooling1D2) {
_inheritsLoose(AveragePooling1D, _Pooling1D2);
function AveragePooling1D(args) {
return _Pooling1D2.call(this, args) || this;
}
var _proto3 = AveragePooling1D.prototype;
_proto3.poolingFunction = function poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
checkDataFormat(dataFormat);
checkPaddingMode(padding);
return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
};
return AveragePooling1D;
}(Pooling1D);
/** @nocollapse */
AveragePooling1D.className = 'AveragePooling1D';
registerClass(AveragePooling1D);
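  /*
   * Usage sketch, assuming the `tf.layers.maxPooling1d` /
   * `tf.layers.averagePooling1d` factories exported elsewhere in this bundle.
   * With the default strides (equal to poolSize) and 'valid' padding, a
   * poolSize of 2 halves the temporal dimension.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.maxPooling1d({poolSize: 2, inputShape: [10, 8]}));
   * console.log(model.outputShape);  // [null, 5, 8]
   * ```
   */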
/**
* Abstract class for different pooling 2D layers.
*/
var Pooling2D = /*#__PURE__*/function (_Layer2) {
_inheritsLoose(Pooling2D, _Layer2);
function Pooling2D(args) {
var _this3;
if (args.poolSize == null) {
args.poolSize = [2, 2];
}
_this3 = _Layer2.call(this, args) || this;
_this3.poolSize = Array.isArray(args.poolSize) ? args.poolSize : [args.poolSize, args.poolSize];
if (args.strides == null) {
_this3.strides = _this3.poolSize;
} else if (Array.isArray(args.strides)) {
if (args.strides.length !== 2) {
throw new ValueError("If the strides property of a 2D pooling layer is an Array, " + "it is expected to have a length of 2, but received length " + (args.strides.length + "."));
}
_this3.strides = args.strides;
} else {
// `config.strides` is a number.
_this3.strides = [args.strides, args.strides];
}
assertPositiveInteger(_this3.poolSize, 'poolSize');
assertPositiveInteger(_this3.strides, 'strides');
_this3.padding = args.padding == null ? 'valid' : args.padding;
_this3.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat;
checkDataFormat(_this3.dataFormat);
checkPaddingMode(_this3.padding);
_this3.inputSpec = [new InputSpec({
ndim: 4
})];
return _this3;
}
var _proto4 = Pooling2D.prototype;
_proto4.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var rows = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
var cols = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
rows = convOutputLength(rows, this.poolSize[0], this.padding, this.strides[0]);
cols = convOutputLength(cols, this.poolSize[1], this.padding, this.strides[1]);
if (this.dataFormat === 'channelsFirst') {
return [inputShape[0], inputShape[1], rows, cols];
} else {
return [inputShape[0], rows, cols, inputShape[3]];
}
};
_proto4.call = function call(inputs, kwargs) {
var _this4 = this;
return tidy(function () {
_this4.invokeCallHook(inputs, kwargs);
return _this4.poolingFunction(getExactlyOneTensor(inputs), _this4.poolSize, _this4.strides, _this4.padding, _this4.dataFormat);
});
};
_proto4.getConfig = function getConfig() {
var config = {
poolSize: this.poolSize,
padding: this.padding,
strides: this.strides,
dataFormat: this.dataFormat
};
var baseConfig = _Layer2.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Pooling2D;
}(Layer);
var MaxPooling2D = /*#__PURE__*/function (_Pooling2D) {
_inheritsLoose(MaxPooling2D, _Pooling2D);
function MaxPooling2D(args) {
return _Pooling2D.call(this, args) || this;
}
var _proto5 = MaxPooling2D.prototype;
_proto5.poolingFunction = function poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
checkDataFormat(dataFormat);
checkPaddingMode(padding);
return pool2d(inputs, poolSize, strides, padding, dataFormat, 'max');
};
return MaxPooling2D;
}(Pooling2D);
/** @nocollapse */
MaxPooling2D.className = 'MaxPooling2D';
registerClass(MaxPooling2D);
var AveragePooling2D = /*#__PURE__*/function (_Pooling2D2) {
_inheritsLoose(AveragePooling2D, _Pooling2D2);
function AveragePooling2D(args) {
return _Pooling2D2.call(this, args) || this;
}
var _proto6 = AveragePooling2D.prototype;
_proto6.poolingFunction = function poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
checkDataFormat(dataFormat);
checkPaddingMode(padding);
return pool2d(inputs, poolSize, strides, padding, dataFormat, 'avg');
};
return AveragePooling2D;
}(Pooling2D);
/** @nocollapse */
AveragePooling2D.className = 'AveragePooling2D';
registerClass(AveragePooling2D);
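  /*
   * Usage sketch, assuming the `tf.layers.maxPooling2d` /
   * `tf.layers.averagePooling2d` factories exported elsewhere in this bundle;
   * shapes assume default strides (equal to poolSize), 'valid' padding and the
   * 'channelsLast' data format.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.averagePooling2d(
   *     {poolSize: [2, 2], inputShape: [28, 28, 3]}));
   * console.log(model.outputShape);  // [null, 14, 14, 3]
   * ```
   */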
/**
* Abstract class for different pooling 3D layers.
*/
var Pooling3D = /*#__PURE__*/function (_Layer3) {
_inheritsLoose(Pooling3D, _Layer3);
function Pooling3D(args) {
var _this5;
if (args.poolSize == null) {
args.poolSize = [2, 2, 2];
}
_this5 = _Layer3.call(this, args) || this;
_this5.poolSize = Array.isArray(args.poolSize) ? args.poolSize : [args.poolSize, args.poolSize, args.poolSize];
if (args.strides == null) {
_this5.strides = _this5.poolSize;
} else if (Array.isArray(args.strides)) {
if (args.strides.length !== 3) {
throw new ValueError("If the strides property of a 3D pooling layer is an Array, " + "it is expected to have a length of 3, but received length " + (args.strides.length + "."));
}
_this5.strides = args.strides;
} else {
// `config.strides` is a number.
_this5.strides = [args.strides, args.strides, args.strides];
}
assertPositiveInteger(_this5.poolSize, 'poolSize');
assertPositiveInteger(_this5.strides, 'strides');
_this5.padding = args.padding == null ? 'valid' : args.padding;
_this5.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat;
checkDataFormat(_this5.dataFormat);
checkPaddingMode(_this5.padding);
_this5.inputSpec = [new InputSpec({
ndim: 5
})];
return _this5;
}
var _proto7 = Pooling3D.prototype;
_proto7.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var depths = this.dataFormat === 'channelsFirst' ? inputShape[2] : inputShape[1];
var rows = this.dataFormat === 'channelsFirst' ? inputShape[3] : inputShape[2];
var cols = this.dataFormat === 'channelsFirst' ? inputShape[4] : inputShape[3];
depths = convOutputLength(depths, this.poolSize[0], this.padding, this.strides[0]);
rows = convOutputLength(rows, this.poolSize[1], this.padding, this.strides[1]);
cols = convOutputLength(cols, this.poolSize[2], this.padding, this.strides[2]);
if (this.dataFormat === 'channelsFirst') {
return [inputShape[0], inputShape[1], depths, rows, cols];
} else {
return [inputShape[0], depths, rows, cols, inputShape[4]];
}
};
_proto7.call = function call(inputs, kwargs) {
var _this6 = this;
return tidy(function () {
_this6.invokeCallHook(inputs, kwargs);
return _this6.poolingFunction(getExactlyOneTensor(inputs), _this6.poolSize, _this6.strides, _this6.padding, _this6.dataFormat);
});
};
_proto7.getConfig = function getConfig() {
var config = {
poolSize: this.poolSize,
padding: this.padding,
strides: this.strides,
dataFormat: this.dataFormat
};
var baseConfig = _Layer3.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return Pooling3D;
}(Layer);
var MaxPooling3D = /*#__PURE__*/function (_Pooling3D) {
_inheritsLoose(MaxPooling3D, _Pooling3D);
function MaxPooling3D(args) {
return _Pooling3D.call(this, args) || this;
}
var _proto8 = MaxPooling3D.prototype;
_proto8.poolingFunction = function poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
checkDataFormat(dataFormat);
checkPaddingMode(padding);
return pool3d(inputs, poolSize, strides, padding, dataFormat, 'max');
};
return MaxPooling3D;
}(Pooling3D);
/** @nocollapse */
MaxPooling3D.className = 'MaxPooling3D';
registerClass(MaxPooling3D);
var AveragePooling3D = /*#__PURE__*/function (_Pooling3D2) {
_inheritsLoose(AveragePooling3D, _Pooling3D2);
function AveragePooling3D(args) {
return _Pooling3D2.call(this, args) || this;
}
var _proto9 = AveragePooling3D.prototype;
_proto9.poolingFunction = function poolingFunction(inputs, poolSize, strides, padding, dataFormat) {
checkDataFormat(dataFormat);
checkPaddingMode(padding);
return pool3d(inputs, poolSize, strides, padding, dataFormat, 'avg');
};
return AveragePooling3D;
}(Pooling3D);
/** @nocollapse */
AveragePooling3D.className = 'AveragePooling3D';
registerClass(AveragePooling3D);
/**
* Abstract class for different global pooling 1D layers.
*/
var GlobalPooling1D = /*#__PURE__*/function (_Layer4) {
_inheritsLoose(GlobalPooling1D, _Layer4);
function GlobalPooling1D(args) {
var _this7;
_this7 = _Layer4.call(this, args) || this;
_this7.inputSpec = [new InputSpec({
ndim: 3
})];
return _this7;
}
var _proto10 = GlobalPooling1D.prototype;
_proto10.computeOutputShape = function computeOutputShape(inputShape) {
return [inputShape[0], inputShape[2]];
};
_proto10.call = function call(inputs, kwargs) {
throw new NotImplementedError();
};
return GlobalPooling1D;
}(Layer);
var GlobalAveragePooling1D = /*#__PURE__*/function (_GlobalPooling1D) {
_inheritsLoose(GlobalAveragePooling1D, _GlobalPooling1D);
function GlobalAveragePooling1D(args) {
return _GlobalPooling1D.call(this, args || {}) || this;
}
var _proto11 = GlobalAveragePooling1D.prototype;
_proto11.call = function call(inputs, kwargs) {
return tidy(function () {
var input = getExactlyOneTensor(inputs);
return mean(input, 1);
});
};
return GlobalAveragePooling1D;
}(GlobalPooling1D);
/** @nocollapse */
GlobalAveragePooling1D.className = 'GlobalAveragePooling1D';
registerClass(GlobalAveragePooling1D);
var GlobalMaxPooling1D = /*#__PURE__*/function (_GlobalPooling1D2) {
_inheritsLoose(GlobalMaxPooling1D, _GlobalPooling1D2);
function GlobalMaxPooling1D(args) {
return _GlobalPooling1D2.call(this, args || {}) || this;
}
var _proto12 = GlobalMaxPooling1D.prototype;
_proto12.call = function call(inputs, kwargs) {
return tidy(function () {
var input = getExactlyOneTensor(inputs);
return max$5(input, 1);
});
};
return GlobalMaxPooling1D;
}(GlobalPooling1D);
/** @nocollapse */
GlobalMaxPooling1D.className = 'GlobalMaxPooling1D';
registerClass(GlobalMaxPooling1D);
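  /*
   * Usage sketch, assuming the `tf.layers.globalAveragePooling1d` factory
   * exported elsewhere in this bundle: the temporal dimension is reduced away,
   * leaving one feature vector per example.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.globalAveragePooling1d({inputShape: [10, 8]}));
   * console.log(model.outputShape);  // [null, 8]
   * ```
   */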
/**
* Abstract class for different global pooling 2D layers.
*/
var GlobalPooling2D = /*#__PURE__*/function (_Layer5) {
_inheritsLoose(GlobalPooling2D, _Layer5);
function GlobalPooling2D(args) {
var _this8;
_this8 = _Layer5.call(this, args) || this;
_this8.dataFormat = args.dataFormat == null ? 'channelsLast' : args.dataFormat;
checkDataFormat(_this8.dataFormat);
_this8.inputSpec = [new InputSpec({
ndim: 4
})];
return _this8;
}
var _proto13 = GlobalPooling2D.prototype;
_proto13.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = inputShape;
if (this.dataFormat === 'channelsLast') {
return [inputShape[0], inputShape[3]];
} else {
return [inputShape[0], inputShape[1]];
}
};
_proto13.call = function call(inputs, kwargs) {
throw new NotImplementedError();
};
_proto13.getConfig = function getConfig() {
var config = {
dataFormat: this.dataFormat
};
var baseConfig = _Layer5.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
return GlobalPooling2D;
}(Layer);
var GlobalAveragePooling2D = /*#__PURE__*/function (_GlobalPooling2D) {
_inheritsLoose(GlobalAveragePooling2D, _GlobalPooling2D);
function GlobalAveragePooling2D() {
return _GlobalPooling2D.apply(this, arguments) || this;
}
var _proto14 = GlobalAveragePooling2D.prototype;
_proto14.call = function call(inputs, kwargs) {
var _this9 = this;
return tidy(function () {
var input = getExactlyOneTensor(inputs);
if (_this9.dataFormat === 'channelsLast') {
return mean(input, [1, 2]);
} else {
return mean(input, [2, 3]);
}
});
};
return GlobalAveragePooling2D;
}(GlobalPooling2D);
/** @nocollapse */
GlobalAveragePooling2D.className = 'GlobalAveragePooling2D';
registerClass(GlobalAveragePooling2D);
var GlobalMaxPooling2D = /*#__PURE__*/function (_GlobalPooling2D2) {
_inheritsLoose(GlobalMaxPooling2D, _GlobalPooling2D2);
function GlobalMaxPooling2D() {
return _GlobalPooling2D2.apply(this, arguments) || this;
}
var _proto15 = GlobalMaxPooling2D.prototype;
_proto15.call = function call(inputs, kwargs) {
var _this10 = this;
return tidy(function () {
var input = getExactlyOneTensor(inputs);
if (_this10.dataFormat === 'channelsLast') {
return max$5(input, [1, 2]);
} else {
return max$5(input, [2, 3]);
}
});
};
return GlobalMaxPooling2D;
}(GlobalPooling2D);
/** @nocollapse */
GlobalMaxPooling2D.className = 'GlobalMaxPooling2D';
registerClass(GlobalMaxPooling2D);
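  /*
   * Usage sketch, assuming the `tf.layers.globalMaxPooling2d` factory exported
   * elsewhere in this bundle; with the default 'channelsLast' format the
   * spatial dimensions are reduced away and only the channel axis remains.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.globalMaxPooling2d({inputShape: [7, 7, 64]}));
   * console.log(model.outputShape);  // [null, 64]
   * ```
   */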
/**
* Abstract wrapper base class.
*
* Wrappers take another layer and augment it in various ways.
* Do not use this class as a layer, it is only an abstract base class.
* Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
*/
var Wrapper = /*#__PURE__*/function (_Layer) {
_inheritsLoose(Wrapper, _Layer);
function Wrapper(args) {
var _this;
// Porting Note: In PyKeras, `self.layer` is set prior to the calling
// `super()`. But we can't do that here due to TypeScript's restriction.
// See: https://github.com/Microsoft/TypeScript/issues/8277
// As a result, we have to add checks in `get trainable()` and
// `set trainable()` below in order to prevent using `this.layer` when
// its value is `undefined`. The super constructor does use the getter
// and the setter of `this.layer`.
_this = _Layer.call(this, args) || this;
_this.layer = args.layer;
return _this;
}
var _proto = Wrapper.prototype;
_proto.build = function build(inputShape) {
this.built = true;
} // TODO(cais): Implement activityRegularizer getter.
;
// TODO(cais): Implement getLossesFor().
_proto.getWeights = function getWeights() {
return this.layer.getWeights();
};
_proto.setWeights = function setWeights(weights) {
this.layer.setWeights(weights);
};
_proto.getConfig = function getConfig() {
var config = {
'layer': {
'className': this.layer.getClassName(),
'config': this.layer.getConfig()
}
};
var baseConfig = _Layer.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
};
_proto.setFastWeightInitDuringBuild = function setFastWeightInitDuringBuild(value) {
_Layer.prototype.setFastWeightInitDuringBuild.call(this, value);
if (this.layer != null) {
this.layer.setFastWeightInitDuringBuild(value);
}
}
/** @nocollapse */
;
Wrapper.fromConfig = function fromConfig(cls, config, customObjects) {
if (customObjects === void 0) {
customObjects = {};
}
var layerConfig = config['layer'];
var layer = deserialize$1(layerConfig, customObjects);
delete config['layer'];
var newConfig = {
layer: layer
};
Object.assign(newConfig, config);
return new cls(newConfig);
};
_createClass(Wrapper, [{
key: "trainable",
get: function get() {
// Porting Note: the check of `this.layer` here is necessary due to the
// way the `constructor` of this class is written (see Porting Note
// above).
if (this.layer != null) {
return this.layer.trainable;
} else {
return false;
}
},
set: function set(value) {
// Porting Note: the check of `this.layer` here is necessary due to the
// way the `constructor` of this class is written (see Porting Note
// above).
if (this.layer != null) {
this.layer.trainable = value;
}
}
}, {
key: "trainableWeights",
get: function get() {
return this.layer.trainableWeights;
} // TODO(cais): Implement setter for trainableWeights.
}, {
key: "nonTrainableWeights",
get: function get() {
return this.layer.nonTrainableWeights;
} // TODO(cais): Implement setter for nonTrainableWeights.
}, {
key: "updates",
get: function get() {
// tslint:disable-next-line:no-any
return this.layer._updates;
} // TODO(cais): Implement getUpdatesFor().
}, {
key: "losses",
get: function get() {
return this.layer.losses;
}
}]);
return Wrapper;
}(Layer);
var TimeDistributed = /*#__PURE__*/function (_Wrapper) {
_inheritsLoose(TimeDistributed, _Wrapper);
function TimeDistributed(args) {
var _this2;
_this2 = _Wrapper.call(this, args) || this;
_this2.supportsMasking = true;
return _this2;
}
var _proto2 = TimeDistributed.prototype;
_proto2.build = function build(inputShape) {
inputShape = getExactlyOneShape(inputShape);
if (inputShape.length < 3) {
throw new ValueError("TimeDistributed layer expects an input shape >= 3D, but received " + ("input shape " + JSON.stringify(inputShape)));
}
this.inputSpec = [{
shape: inputShape
}];
var childInputShape = [inputShape[0]].concat(inputShape.slice(2));
if (!this.layer.built) {
this.layer.build(childInputShape);
this.layer.built = true;
}
_Wrapper.prototype.build.call(this, inputShape);
};
_proto2.computeOutputShape = function computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
var childInputShape = [inputShape[0]].concat(inputShape.slice(2));
var childOutputShape = this.layer.computeOutputShape(childInputShape);
var timesteps = inputShape[1];
return [childOutputShape[0], timesteps].concat(childOutputShape.slice(1));
};
_proto2.call = function call(inputs, kwargs) {
var _this3 = this;
return tidy(function () {
// TODO(cais): Add 'training' and 'useLearningPhase' to kwargs.
inputs = getExactlyOneTensor(inputs); // Porting Note: In tfjs-layers, `inputs` are always concrete tensor
// values. Hence the inputs can't have an undetermined first (batch)
// dimension, which is why we always use the K.rnn approach here.
var step = function step(inputs, states) {
// TODO(cais): Add useLearningPhase.
// NOTE(cais): `layer.call` may return a length-1 array of Tensor in
// some cases (e.g., `layer` is a `Sequential` instance), which is
// why `getExactlyOneTensor` is used below.
var output = getExactlyOneTensor(_this3.layer.call(inputs, kwargs));
return [output, []];
};
var rnnOutputs = rnn(step, inputs, [], false
/* goBackwards */
, null
/* mask */
, null
/* constants */
, false
/* unroll */
, true
/* needPerStepOutputs */
);
var y = rnnOutputs[1]; // TODO(cais): Add activity regularization.
// TODO(cais): Add useLearningPhase.
return y;
});
};
return TimeDistributed;
}(Wrapper);
/** @nocollapse */
TimeDistributed.className = 'TimeDistributed';
registerClass(TimeDistributed);
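  /*
   * Usage sketch, assuming the `tf.layers.timeDistributed` and
   * `tf.layers.dense` factories exported elsewhere in this bundle: the wrapped
   * Dense layer is applied independently to each of the 10 timesteps.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.timeDistributed({
   *   layer: tf.layers.dense({units: 8}),
   *   inputShape: [10, 16]
   * }));
   * console.log(model.outputShape);  // [null, 10, 8]
   * ```
   */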
function checkBidirectionalMergeMode(value) {
checkStringTypeUnionValue(VALID_BIDIRECTIONAL_MERGE_MODES, 'BidirectionalMergeMode', value);
}
var DEFAULT_BIDIRECTIONAL_MERGE_MODE = 'concat';
var Bidirectional = /*#__PURE__*/function (_Wrapper2) {
_inheritsLoose(Bidirectional, _Wrapper2);
function Bidirectional(args) {
var _this4;
_this4 = _Wrapper2.call(this, args) || this; // Note: When creating `this.forwardLayer`, the original Layer object
// (`config.layer`) ought to be cloned. This is why we call
// `getConfig()` followed by `deserialize()`. Without this cloning,
// the layer names saved during serialization will incorrectly contain
// the 'forward_' prefix. In Python Keras, this is done using
// `copy.copy` (shallow copy), which does not have a simple equivalent
// in JavaScript. JavaScript's `Object.assign()` does not copy
// methods.
var layerConfig = args.layer.getConfig();
var forwDict = {};
forwDict['className'] = args.layer.getClassName();
forwDict['config'] = layerConfig;
_this4.forwardLayer = deserialize$1(forwDict);
layerConfig['goBackwards'] = layerConfig['goBackwards'] === true ? false : true;
var backDict = {};
backDict['className'] = args.layer.getClassName();
backDict['config'] = layerConfig;
_this4.backwardLayer = deserialize$1(backDict);
_this4.forwardLayer.name = 'forward_' + _this4.forwardLayer.name;
_this4.backwardLayer.name = 'backward_' + _this4.backwardLayer.name;
_this4.mergeMode = args.mergeMode === undefined ? DEFAULT_BIDIRECTIONAL_MERGE_MODE : args.mergeMode;
checkBidirectionalMergeMode(_this4.mergeMode);
if (args.weights) {
throw new NotImplementedError('weights support is not implemented for Bidirectional layer yet.');
}
_this4._stateful = args.layer.stateful;
_this4.returnSequences = args.layer.returnSequences;
_this4.returnState = args.layer.returnState;
_this4.supportsMasking = true;
_this4._trainable = true;
_this4.inputSpec = args.layer.inputSpec;
_this4.numConstants = null;
return _this4;
}
var _proto3 = Bidirectional.prototype;
_proto3.getWeights = function getWeights() {
return this.forwardLayer.getWeights().concat(this.backwardLayer.getWeights());
};
_proto3.setWeights = function setWeights(weights) {
var numWeights = weights.length;
      var numWeightsOver2 = Math.floor(numWeights / 2);
      this.forwardLayer.setWeights(weights.slice(0, numWeightsOver2));
      this.backwardLayer.setWeights(weights.slice(numWeightsOver2));
};
_proto3.computeOutputShape = function computeOutputShape(inputShape) {
var layerShapes = this.forwardLayer.computeOutputShape(inputShape);
if (!(Array.isArray(layerShapes) && Array.isArray(layerShapes[0]))) {
layerShapes = [layerShapes];
}
layerShapes = layerShapes;
var outputShape;
var outputShapes;
var stateShape;
if (this.returnState) {
stateShape = layerShapes.slice(1);
outputShape = layerShapes[0];
} else {
outputShape = layerShapes[0];
}
outputShape = outputShape;
if (this.mergeMode === 'concat') {
outputShape[outputShape.length - 1] *= 2;
outputShapes = [outputShape];
} else if (this.mergeMode == null) {
outputShapes = [outputShape, outputShape.slice()];
} else {
outputShapes = [outputShape];
}
if (this.returnState) {
if (this.mergeMode == null) {
return outputShapes.concat(stateShape).concat(stateShape.slice());
}
return [outputShape].concat(stateShape).concat(stateShape.slice());
}
return singletonOrArray(outputShapes);
};
_proto3.apply = function apply(inputs, kwargs) {
var initialState = kwargs == null ? null : kwargs['initialState'];
var constants = kwargs == null ? null : kwargs['constants'];
if (kwargs == null) {
kwargs = {};
}
var standardized = standardizeArgs(inputs, initialState, constants, this.numConstants);
inputs = standardized.inputs;
initialState = standardized.initialState;
constants = standardized.constants;
if (Array.isArray(inputs)) {
initialState = inputs.slice(1);
inputs = inputs[0];
}
if ((initialState == null || initialState.length === 0) && constants == null) {
return _Wrapper2.prototype.apply.call(this, inputs, kwargs);
}
var additionalInputs = [];
var additionalSpecs = [];
if (initialState != null) {
var numStates = initialState.length;
if (numStates % 2 > 0) {
        throw new ValueError('When passing `initialState` to a Bidirectional RNN, ' + 'the state should be an Array containing the states of ' + 'the underlying RNNs.');
}
kwargs['initialState'] = initialState;
additionalInputs.push.apply(additionalInputs, initialState);
var stateSpecs = initialState.map(function (state) {
return new InputSpec({
shape: state.shape
});
});
this.forwardLayer.stateSpec = stateSpecs.slice(0, numStates / 2);
this.backwardLayer.stateSpec = stateSpecs.slice(numStates / 2);
additionalSpecs.push.apply(additionalSpecs, stateSpecs);
}
if (constants != null) {
throw new NotImplementedError('Support for constants in Bidirectional layers is not ' + 'implemented yet.');
}
var isSymbolicTensor = additionalInputs[0] instanceof SymbolicTensor;
for (var _i = 0, _additionalInputs = additionalInputs; _i < _additionalInputs.length; _i++) {
var tensor = _additionalInputs[_i];
if (tensor instanceof SymbolicTensor !== isSymbolicTensor) {
throw new ValueError('The initial state of a Bidirectional layer cannot be ' + 'specified as a mix of symbolic and non-symbolic tensors');
}
}
if (isSymbolicTensor) {
// Compute the full input and specs, including the states.
var fullInput = [inputs].concat(additionalInputs);
var fullInputSpec = this.inputSpec.concat(additionalSpecs); // Perform the call temporarily and replace inputSpec.
// Note: with initial states symbolic calls and non-symbolic calls to
// this method differ in how the initial states are passed. For
// symbolic calls, the initial states are passed in the first arg, as
// an Array of SymbolicTensors; for non-symbolic calls, they are
// passed in the second arg as a part of the kwargs. Hence the need to
// temporarily modify inputSpec here.
// TODO(cais): Make refactoring so that this hacky code below is no
// longer needed.
var originalInputSpec = this.inputSpec;
this.inputSpec = fullInputSpec;
var output = _Wrapper2.prototype.apply.call(this, fullInput, kwargs);
this.inputSpec = originalInputSpec;
return output;
} else {
return _Wrapper2.prototype.apply.call(this, inputs, kwargs);
}
};
_proto3.call = function call(inputs, kwargs) {
var _this5 = this;
return tidy(function () {
var initialState = kwargs['initialState'];
var y;
var yRev;
if (initialState == null) {
y = _this5.forwardLayer.call(inputs, kwargs);
yRev = _this5.backwardLayer.call(inputs, kwargs);
} else {
var forwardState = initialState.slice(0, initialState.length / 2);
var backwardState = initialState.slice(initialState.length / 2);
y = _this5.forwardLayer.call(inputs, Object.assign(kwargs, {
initialState: forwardState
}));
yRev = _this5.backwardLayer.call(inputs, Object.assign(kwargs, {
initialState: backwardState
}));
}
var states;
if (_this5.returnState) {
if (Array.isArray(y)) {
states = y.slice(1).concat(yRev.slice(1));
} else {}
y = y[0];
yRev = yRev[0];
}
if (_this5.returnSequences) {
yRev = reverse(yRev, 1);
}
var output;
if (_this5.mergeMode === 'concat') {
output = concatenate([y, yRev]);
} else if (_this5.mergeMode === 'sum') {
output = add$1(y, yRev);
} else if (_this5.mergeMode === 'ave') {
output = mul(.5, add$1(y, yRev));
} else if (_this5.mergeMode === 'mul') {
output = mul(y, yRev);
} else if (_this5.mergeMode == null) {
output = [y, yRev];
} // TODO(cais): Properly set learning phase.
if (_this5.returnState) {
if (_this5.mergeMode == null) {
return output.concat(states);
}
return [output].concat(states);
}
return output;
});
};
_proto3.resetStates = function resetStates(states) {
this.forwardLayer.resetStates();
this.backwardLayer.resetStates();
};
_proto3.build = function build(inputShape) {
var _this6 = this;
nameScope(this.forwardLayer.name, function () {
_this6.forwardLayer.build(inputShape);
});
nameScope(this.backwardLayer.name, function () {
_this6.backwardLayer.build(inputShape);
});
this.built = true;
};
_proto3.computeMask = function computeMask(inputs, mask) {
if (Array.isArray(mask)) {
mask = mask[0];
}
var outputMask;
if (this.returnSequences) {
if (this.mergeMode == null) {
outputMask = [mask, mask];
} else {
outputMask = mask;
}
} else {
if (this.mergeMode == null) {
outputMask = [null, null];
} else {
outputMask = null;
}
}
if (this.returnState) {
var states = this.forwardLayer.states;
var stateMask = states.map(function (state) {
return null;
});
if (Array.isArray(outputMask)) {
return outputMask.concat(stateMask).concat(stateMask);
} else {
return [outputMask].concat(stateMask).concat(stateMask);
}
} else {
return outputMask;
}
};
// TODO(cais): Implement constraints().
_proto3.setFastWeightInitDuringBuild = function setFastWeightInitDuringBuild(value) {
_Wrapper2.prototype.setFastWeightInitDuringBuild.call(this, value);
if (this.forwardLayer != null) {
this.forwardLayer.setFastWeightInitDuringBuild(value);
}
if (this.backwardLayer != null) {
this.backwardLayer.setFastWeightInitDuringBuild(value);
}
};
_proto3.getConfig = function getConfig() {
var config = {
'mergeMode': this.mergeMode
}; // TODO(cais): Add logic for `numConstants` once the property is added.
var baseConfig = _Wrapper2.prototype.getConfig.call(this);
Object.assign(config, baseConfig);
return config;
}
/** @nocollapse */
;
Bidirectional.fromConfig = function fromConfig(cls, config) {
var rnnLayer = deserialize$1(config['layer']);
delete config['layer']; // TODO(cais): Add logic for `numConstants` once the property is added.
if (config['numConstants'] != null) {
throw new NotImplementedError("Deserialization of a Bidirectional layer with numConstants " + "present is not supported yet.");
} // tslint:disable-next-line:no-any
var newConfig = config;
newConfig['layer'] = rnnLayer;
return new cls(newConfig);
};
_createClass(Bidirectional, [{
key: "trainable",
get: function get() {
return this._trainable;
},
set: function set(value) {
// Porting Note: the check of `this.layer` here is necessary due to the
// way the `constructor` of this class is written (see Porting Note
// above).
this._trainable = value;
if (this.forwardLayer != null) {
this.forwardLayer.trainable = value;
}
if (this.backwardLayer != null) {
this.backwardLayer.trainable = value;
}
}
}, {
key: "trainableWeights",
get: function get() {
return this.forwardLayer.trainableWeights.concat(this.backwardLayer.trainableWeights);
}
}, {
key: "nonTrainableWeights",
get: function get() {
return this.forwardLayer.nonTrainableWeights.concat(this.backwardLayer.nonTrainableWeights);
}
}]);
return Bidirectional;
}(Wrapper);
/** @nocollapse */
Bidirectional.className = 'Bidirectional';
registerClass(Bidirectional);
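  /*
   * Usage sketch, assuming the `tf.layers.bidirectional` and
   * `tf.layers.simpleRNN` factories exported elsewhere in this bundle: with
   * `mergeMode: 'concat'` the forward and backward outputs are concatenated,
   * doubling the feature dimension.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.bidirectional({
   *   layer: tf.layers.simpleRNN({units: 4, returnSequences: true}),
   *   mergeMode: 'concat',
   *   inputShape: [10, 8]
   * }));
   * console.log(model.outputShape);  // [null, 10, 8]
   * ```
   */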
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
  // class; include executable JavaScript code snippets where applicable
// (b/74074458).
// Input Layer.
/**
* An input layer is an entry point into a `tf.LayersModel`.
*
   * `InputLayer` is generated automatically for `tf.Sequential` models by
   * specifying the `inputShape` or `batchInputShape` for the first layer. It
   * should not be specified explicitly. However, it can be useful sometimes,
   * e.g., when constructing a sequential model from a subset of another
   * sequential model's layers, as the code snippet below shows.
*
* ```js
* // Define a model which simply adds two inputs.
* const model1 = tf.sequential();
* model1.add(tf.layers.dense({inputShape: [4], units: 3, activation: 'relu'}));
* model1.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
* model1.summary();
* model1.predict(tf.zeros([1, 4])).print();
*
* // Construct another model, reusing the second layer of `model1` while
* // not using the first layer of `model1`. Note that you cannot add the second
   * // layer of `model1` directly as the first layer of the new sequential model,
* // because doing so will lead to an error related to the fact that the layer
* // is not an input layer. Instead, you need to create an `inputLayer` and add
* // it to the new sequential model before adding the reused layer.
* const model2 = tf.sequential();
* // Use an inputShape that matches the input shape of `model1`'s second
* // layer.
* model2.add(tf.layers.inputLayer({inputShape: [3]}));
* model2.add(model1.layers[1]);
* model2.summary();
* model2.predict(tf.zeros([1, 3])).print();
* ```
*
* @doc {heading: 'Layers', subheading: 'Inputs', namespace: 'layers'}
*/
function inputLayer(args) {
return new InputLayer(args);
} // Advanced Activation Layers.
/**
   * Exponential Linear Unit (ELU).
*
* It follows:
* `f(x) = alpha * (exp(x) - 1.) for x < 0`,
* `f(x) = x for x >= 0`.
*
* Input shape:
* Arbitrary. Use the configuration `inputShape` when using this layer as the
* first layer in a model.
*
* Output shape:
* Same shape as the input.
*
* References:
* - [Fast and Accurate Deep Network Learning by Exponential Linear Units
* (ELUs)](https://arxiv.org/abs/1511.07289v1)
*
* @doc {
* heading: 'Layers',
* subheading: 'Advanced Activation',
* namespace: 'layers'
* }
*/
function elu$2(args) {
return new ELU(args);
}
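  /*
   * Usage sketch, assuming the `tf.layers.elu` and `tf.layers.dense` factories
   * exported elsewhere in this bundle: the activation is applied element-wise,
   * so the output shape matches the input shape.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.dense({units: 4, inputShape: [3]}));
   * model.add(tf.layers.elu({alpha: 1.0}));
   * model.predict(tf.randomNormal([2, 3])).print();  // shape [2, 4]
   * ```
   */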
/**
* Rectified Linear Unit activation function.
*
* Input shape:
* Arbitrary. Use the config field `inputShape` (Array of integers, does
* not include the sample axis) when using this layer as the first layer
* in a model.
*
* Output shape:
* Same shape as the input.
*
* @doc {
* heading: 'Layers',
* subheading: 'Advanced Activation',
* namespace: 'layers'
* }
*/
function reLU(args) {
return new ReLU(args);
}
/**
* Leaky version of a rectified linear unit.
*
* It allows a small gradient when the unit is not active:
* `f(x) = alpha * x for x < 0.`
* `f(x) = x for x >= 0.`
*
* Input shape:
* Arbitrary. Use the configuration `inputShape` when using this layer as the
* first layer in a model.
*
* Output shape:
* Same shape as the input.
*
* @doc {
* heading: 'Layers',
* subheading: 'Advanced Activation',
* namespace: 'layers'
* }
*/
function leakyReLU(args) {
return new LeakyReLU(args);
}
/**
* Parameterized version of a leaky rectified linear unit.
*
* It follows
* `f(x) = alpha * x for x < 0.`
* `f(x) = x for x >= 0.`
* wherein `alpha` is a trainable weight.
*
* Input shape:
* Arbitrary. Use the configuration `inputShape` when using this layer as the
* first layer in a model.
*
* Output shape:
* Same shape as the input.
*
* @doc {
* heading: 'Layers',
* subheading: 'Advanced Activation',
* namespace: 'layers'
* }
*/
function prelu$1(args) {
return new PReLU(args);
}
/**
* Softmax activation layer.
*
* Input shape:
* Arbitrary. Use the configuration `inputShape` when using this layer as the
* first layer in a model.
*
* Output shape:
* Same shape as the input.
*
* @doc {
* heading: 'Layers',
* subheading: 'Advanced Activation',
* namespace: 'layers'
* }
*/
function softmax$1(args) {
return new Softmax$2(args);
}
/**
* Thresholded Rectified Linear Unit.
*
* It follows:
* `f(x) = x for x > theta`,
* `f(x) = 0 otherwise`.
*
* Input shape:
* Arbitrary. Use the configuration `inputShape` when using this layer as the
* first layer in a model.
*
* Output shape:
* Same shape as the input.
*
* References:
* - [Zero-Bias Autoencoders and the Benefits of Co-Adapting
* Features](http://arxiv.org/abs/1402.3337)
*
* @doc {
* heading: 'Layers',
* subheading: 'Advanced Activation',
* namespace: 'layers'
* }
*/
function thresholdedReLU(args) {
return new ThresholdedReLU(args);
} // Convolutional Layers.
/**
* 1D convolution layer (e.g., temporal convolution).
*
* This layer creates a convolution kernel that is convolved
* with the layer input over a single spatial (or temporal) dimension
* to produce a tensor of outputs.
*
   * If `useBias` is `true`, a bias vector is created and added to the outputs.
*
* If `activation` is not `null`, it is applied to the outputs as well.
*
* When using this layer as the first layer in a model, provide an
* `inputShape` argument `Array` or `null`.
*
* For example, `inputShape` would be:
   * - `[10, 128]` for sequences of 10 128-dimensional vectors
* - `[null, 128]` for variable-length sequences of 128-dimensional vectors.
*
* @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
*/
function conv1d$2(args) {
return new Conv1D(args);
}
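  /*
   * Usage sketch, assuming the `tf.layers.conv1d` factory exported elsewhere
   * in this bundle; with the default 'valid' padding and stride 1, a kernel of
   * size 3 over 10 steps yields 10 - 3 + 1 = 8 output steps.
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.conv1d(
   *     {filters: 16, kernelSize: 3, activation: 'relu', inputShape: [10, 128]}));
   * console.log(model.outputShape);  // [null, 8, 16]
   * ```
   */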
/**
* 2D convolution layer (e.g. spatial convolution over images).
*
* This layer creates a convolution kernel that is convolved
* with the layer input to produce a tensor of outputs.
*
   * If `useBias` is `true`, a bias vector is created and added to the outputs.
*
* If `activation` is not `null`, it is applied to the outputs as well.
*
* When using this layer as the first layer in a model,
* provide the keyword argument `inputShape`
* (Array of integers, does not include the sample axis),
* e.g. `inputShape=[128, 128, 3]` for 128x128 RGB pictures
* in `dataFormat='channelsLast'`.
*
* @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
*/
function conv2d$3(args) {
return new Conv2D$1(args);
}
/**
* Transposed convolutional layer (sometimes called Deconvolution).
*
* The need for transposed convolutions generally arises
* from the desire to use a transformation going in the opposite direction of
* a normal convolution, i.e., from something that has the shape of the output
* of some convolution to something that has the shape of its input while
* maintaining a connectivity pattern that is compatible with said
* convolution.
*
* When using this layer as the first layer in a model, provide the
* configuration `inputShape` (`Array` of integers, does not include the
* sample axis), e.g., `inputShape: [128, 128, 3]` for 128x128 RGB pictures in
* `dataFormat: 'channelsLast'`.
*
* Input shape:
* 4D tensor with shape:
* `[batch, channels, rows, cols]` if `dataFormat` is `'channelsFirst'`.
* or 4D tensor with shape
   *   `[batch, rows, cols, channels]` if `dataFormat` is `'channelsLast'`.
*
* Output shape:
* 4D tensor with shape:
* `[batch, filters, newRows, newCols]` if `dataFormat` is
* `'channelsFirst'`. or 4D tensor with shape:
* `[batch, newRows, newCols, filters]` if `dataFormat` is `'channelsLast'`.
*
* References:
* - [A guide to convolution arithmetic for deep
* learning](https://arxiv.org/abs/1603.07285v1)
* - [Deconvolutional
* Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf)
*
* @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
*/
function conv2dTranspose$1(args) {
return new Conv2DTranspose(args);
}
/**
* 3D convolution layer (e.g. spatial convolution over volumes).
*
* This layer creates a convolution kernel that is convolved
* with the layer input to produce a tensor of outputs.
*
   * If `useBias` is `true`, a bias vector is created and added to the outputs.
*
* If `activation` is not `null`, it is applied to the outputs as well.
*
* When using this layer as the first layer in a model,
* provide the keyword argument `inputShape`
* (Array of integers, does not include the sample axis),
* e.g. `inputShape=[128, 128, 128, 1]` for 128x128x128 grayscale volumes
* in `dataFormat='channelsLast'`.
*
* @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
*/
function conv3d$2(args) {
return new Conv3D$1(args);
}
function conv3dTranspose$1(args) {
return new Conv3DTranspose(args);
}
/**
* Depthwise separable 2D convolution.
*
* Separable convolution consists of first performing
* a depthwise spatial convolution
* (which acts on each input channel separately)
* followed by a pointwise convolution which mixes together the resulting
* output channels. The `depthMultiplier` argument controls how many
* output channels are generated per input channel in the depthwise step.
*
* Intuitively, separable convolutions can be understood as
* a way to factorize a convolution kernel into two smaller kernels,
* or as an extreme version of an Inception block.
*
* Input shape:
* 4D tensor with shape:
   *   `[batch, channels, rows, cols]` if `dataFormat` is `'channelsFirst'`
   * or 4D tensor with shape:
   *   `[batch, rows, cols, channels]` if `dataFormat` is `'channelsLast'`.
   *
   * Output shape:
   *   4D tensor with shape:
   *   `[batch, filters, newRows, newCols]` if `dataFormat` is `'channelsFirst'`
   * or 4D tensor with shape:
   *   `[batch, newRows, newCols, filters]` if `dataFormat` is `'channelsLast'`.
* `rows` and `cols` values might have changed due to padding.
*
* @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
*/
function separableConv2d$1(args) {
return new SeparableConv2D(args);
}
/**
* Cropping layer for 2D input (e.g., image).
*
* This layer can crop an input
* at the top, bottom, left and right side of an image tensor.
*
* Input shape:
* 4D tensor with shape:
* - If `dataFormat` is `"channelsLast"`:
* `[batch, rows, cols, channels]`
   *   - If `dataFormat` is `"channelsFirst"`:
* `[batch, channels, rows, cols]`.
*
* Output shape:
* 4D with shape:
* - If `dataFormat` is `"channelsLast"`:
* `[batch, croppedRows, croppedCols, channels]`
* - If `dataFormat` is `"channelsFirst"`:
* `[batch, channels, croppedRows, croppedCols]`.
*
* Examples
* ```js
*
* const model = tf.sequential();
* model.add(tf.layers.cropping2D({cropping:[[2, 2], [2, 2]],
* inputShape: [128, 128, 3]}));
   * // now output shape is [batch, 124, 124, 3]
* ```
*
* @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
*/
function cropping2D(args) {
return new Cropping2D(args);
}
/**
* Upsampling layer for 2D inputs.
*
* Repeats the rows and columns of the data
* by size[0] and size[1] respectively.
*
*
* Input shape:
* 4D tensor with shape:
* - If `dataFormat` is `"channelsLast"`:
* `[batch, rows, cols, channels]`
* - If `dataFormat` is `"channelsFirst"`:
* `[batch, channels, rows, cols]`
*
* Output shape:
* 4D tensor with shape:
* - If `dataFormat` is `"channelsLast"`:
* `[batch, upsampledRows, upsampledCols, channels]`
* - If `dataFormat` is `"channelsFirst"`:
* `[batch, channels, upsampledRows, upsampledCols]`
*
*
* @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
*/
function upSampling2d(args) {
return new UpSampling2D(args);
} // Convolutional(depthwise) Layers.
/**
* Depthwise separable 2D convolution.
*
   * Depthwise separable convolutions consist of performing just the first step
   * of a depthwise spatial convolution (which acts on each input channel
   * separately). The `depthMultiplier` argument controls how many output channels
* are generated per input channel in the depthwise step.
*
* @doc {heading: 'Layers', subheading: 'Convolutional', namespace: 'layers'}
*/
function depthwiseConv2d$3(args) {
return new DepthwiseConv2D(args);
} // Basic Layers.
/**
* Applies an activation function to an output.
*
* This layer applies element-wise activation function. Other layers, notably
* `dense` can also apply activation functions. Use this isolated activation
* function to extract the values before and after the
* activation. For instance:
*
* ```js
* const input = tf.input({shape: [5]});
* const denseLayer = tf.layers.dense({units: 1});
* const activationLayer = tf.layers.activation({activation: 'relu6'});
*
* // Obtain the output symbolic tensors by applying the layers in order.
* const denseOutput = denseLayer.apply(input);
* const activationOutput = activationLayer.apply(denseOutput);
*
* // Create the model based on the inputs.
* const model = tf.model({
* inputs: input,
* outputs: [denseOutput, activationOutput]
* });
*
* // Collect both outputs and print separately.
* const [denseOut, activationOut] = model.predict(tf.randomNormal([6, 5]));
* denseOut.print();
* activationOut.print();
* ```
*
* @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
*/
function activation(args) {
return new Activation$1(args);
}
/**
* Creates a dense (fully connected) layer.
*
* This layer implements the operation:
* `output = activation(dot(input, kernel) + bias)`
*
* `activation` is the element-wise activation function
* passed as the `activation` argument.
*
* `kernel` is a weights matrix created by the layer.
*
* `bias` is a bias vector created by the layer (only applicable if `useBias`
* is `true`).
*
* **Input shape:**
*
* nD `tf.Tensor` with shape: `(batchSize, ..., inputDim)`.
*
* The most common situation would be
* a 2D input with shape `(batchSize, inputDim)`.
*
* **Output shape:**
*
* nD tensor with shape: `(batchSize, ..., units)`.
*
* For instance, for a 2D input with shape `(batchSize, inputDim)`,
* the output would have shape `(batchSize, units)`.
*
* Note: if the input to the layer has a rank greater than 2, then it is
* flattened prior to the initial dot product with the kernel.
*
* @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
*/
function dense(args) {
return new Dense(args);
}
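  /*
   * Usage sketch, assuming the `tf.layers.dense` factory exported elsewhere in
   * this bundle: a 2-D input of shape [batchSize, 16] is mapped to
   * [batchSize, 32] and then to [batchSize, 1].
   *
   * ```js
   * const model = tf.sequential();
   * model.add(tf.layers.dense({units: 32, activation: 'relu', inputShape: [16]}));
   * model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
   * console.log(model.outputShape);  // [null, 1]
   * model.predict(tf.randomNormal([4, 16])).print();
   * ```
   */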
/**
* Applies
* [dropout](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf) to
* the input.
*
   * Dropout consists of randomly setting a fraction `rate` of input units to 0 at
* each update during training time, which helps prevent overfitting.
*
* @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
*/
function dropout$2(args) {
return new Dropout(args);
}
/**
* Spatial 1D version of Dropout.
*
* This Layer type performs the same function as the Dropout layer, but it drops
* entire 1D feature maps instead of individual elements. For example, if an
* input example consists of 3 timesteps and the feature map for each timestep
* has a size of 4, a `spatialDropout1d` layer may zero out the feature maps
   * of the 1st and 2nd timesteps completely while sparing all feature
* elements of the 3rd timestep.
*
* If adjacent frames (timesteps) are strongly correlated (as is normally the
* case in early convolution layers), regular dropout will not regularize the
   * activations and will merely result in an effective learning
* rate decrease. In this case, `spatialDropout1d` will help promote
* independence among feature maps and should be used instead.
*
* **Arguments:**
* rate: A floating-point number >=0 and <=1. Fraction of the input elements
* to drop.
*
* **Input shape:**
* 3D tensor with shape `(samples, timesteps, channels)`.
*
* **Output shape:**
* Same as the input shape.
*
* References:
* - [Efficient Object Localization Using Convolutional
* Networks](https://arxiv.org/abs/1411.4280)
*
* @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
*/
function spatialDropout1d(args) {
return new SpatialDropout1D(args);
}
/**
* Flattens the input. Does not affect the batch size.
*
* A `Flatten` layer flattens each batch in its inputs to 1D (making the output
* 2D).
*
* For example:
*
* ```js
* const input = tf.input({shape: [4, 3]});
* const flattenLayer = tf.layers.flatten();
* // Inspect the inferred output shape of the flatten layer, which
* // equals `[null, 12]`. The 2nd dimension is 4 * 3, i.e., the result of the
   * // flattening. (The 1st dimension is the undetermined batch size.)
* console.log(JSON.stringify(flattenLayer.apply(input).shape));
* ```
*
* @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
*/
function flatten$2(args) {
return new Flatten(args);
}
/**
* Repeats the input n times in a new dimension.
*
* ```js
* const model = tf.sequential();
* model.add(tf.layers.repeatVector({n: 4, inputShape: [2]}));
* const x = tf.tensor2d([[10, 20]]);
   * // Use the model to do inference on a data point the model hasn't seen.
   * model.predict(x).print();
   * // output shape is now [batch, 4, 2]
* ```
*
* @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
*/
function repeatVector(args) {
return new RepeatVector(args);
}
/**
* Reshapes an input to a certain shape.
*
* ```js
* const input = tf.input({shape: [4, 3]});
* const reshapeLayer = tf.layers.reshape({targetShape: [2, 6]});
* // Inspect the inferred output shape of the Reshape layer, which
   * // equals `[null, 2, 6]`. (The 1st dimension is the undetermined batch size.)
* console.log(JSON.stringify(reshapeLayer.apply(input).shape));
* ```
*
* Input shape:
* Arbitrary, although all dimensions in the input shape must be fixed.
* Use the configuration `inputShape` when using this layer as the
* first layer in a model.
*
*
* Output shape:
* [batchSize, targetShape[0], targetShape[1], ...,
* targetShape[targetShape.length - 1]].
*
* @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
*/
function reshape$1(args) {
return new Reshape$1(args);
}
/**
* Permutes the dimensions of the input according to a given pattern.
*
* Useful for, e.g., connecting RNNs and convnets together.
*
* Example:
*
* ```js
* const model = tf.sequential();
* model.add(tf.layers.permute({
* dims: [2, 1],
* inputShape: [10, 64]
* }));
* console.log(model.outputShape);
* // Now model's output shape is [null, 64, 10], where null is the
* // unpermuted sample (batch) dimension.
* ```
*
* Input shape:
* Arbitrary. Use the configuration field `inputShape` when using this
* layer as the first layer in a model.
*
* Output shape:
* Same rank as the input shape, but with the dimensions re-ordered (i.e.,
* permuted) according to the `dims` configuration of this layer.
*
* @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
*/
function permute(args) {
return new Permute(args);
}
/**
* Maps positive integers (indices) into dense vectors of fixed size.
   * e.g., [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
*
* **Input shape:** 2D tensor with shape: `[batchSize, sequenceLength]`.
*
* **Output shape:** 3D tensor with shape: `[batchSize, sequenceLength,
* outputDim]`.
*
* @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'}
*/
function embedding(args) {
return new Embedding(args);
} // Merge Layers.
/**
* Layer that performs element-wise addition on an `Array` of inputs.
*
* It takes as input a list of tensors, all of the same shape, and returns a
* single tensor (also of the same shape). The inputs are specified as an
* `Array` when the `apply` method of the `Add` layer instance is called. For
* example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const addLayer = tf.layers.add();
* const sum = addLayer.apply([input1, input2]);
* console.log(JSON.stringify(sum.shape));
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
*/
function add$3(args) {
return new Add$1(args);
}
/**
* Layer that performs element-wise averaging on an `Array` of inputs.
*
* It takes as input a list of tensors, all of the same shape, and returns a
* single tensor (also of the same shape). For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const averageLayer = tf.layers.average();
* const average = averageLayer.apply([input1, input2]);
* console.log(JSON.stringify(average.shape));
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
*/
function average$1(args) {
return new Average(args);
}
/**
* Layer that concatenates an `Array` of inputs.
*
* It takes a list of tensors, all of the same shape except for the
* concatenation axis, and returns a single tensor, the concatenation
* of all inputs. For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 3]});
* const concatLayer = tf.layers.concatenate();
* const output = concatLayer.apply([input1, input2]);
* console.log(JSON.stringify(output.shape));
* // You get [null, 2, 5], with the first dimension as the undetermined batch
* // dimension. The last dimension (5) is the result of concatenating the
* // last dimensions of the inputs (2 and 3).
* ```
*
* @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
*/
function concatenate$2(args) {
return new Concatenate(args);
}
/**
 * Layer that computes the element-wise maximum of an `Array` of inputs.
 *
 * It takes as input a list of tensors, all of the same shape, and returns a
* single tensor (also of the same shape). For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const maxLayer = tf.layers.maximum();
* const max = maxLayer.apply([input1, input2]);
* console.log(JSON.stringify(max.shape));
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
*/
function maximum$2(args) {
return new Maximum$1(args);
}
/**
* Layer that computes the element-wise minimum of an `Array` of inputs.
*
 * It takes as input a list of tensors, all of the same shape, and returns a
* single tensor (also of the same shape). For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const minLayer = tf.layers.minimum();
* const min = minLayer.apply([input1, input2]);
* console.log(JSON.stringify(min.shape));
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
* ```
*
* @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
*/
function minimum$2(args) {
return new Minimum$1(args);
}
/**
* Layer that multiplies (element-wise) an `Array` of inputs.
*
* It takes as input an Array of tensors, all of the same
* shape, and returns a single tensor (also of the same shape).
* For example:
*
* ```js
* const input1 = tf.input({shape: [2, 2]});
* const input2 = tf.input({shape: [2, 2]});
* const input3 = tf.input({shape: [2, 2]});
* const multiplyLayer = tf.layers.multiply();
* const product = multiplyLayer.apply([input1, input2, input3]);
* console.log(product.shape);
* // You get [null, 2, 2], with the first dimension as the undetermined batch
* // dimension.
 * ```
 *
* @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
*/
function multiply$2(args) {
return new Multiply$1(args);
}
/**
* Layer that computes a dot product between samples in two tensors.
*
* E.g., if applied to a list of two tensors `a` and `b` both of shape
* `[batchSize, n]`, the output will be a tensor of shape `[batchSize, 1]`,
* where each entry at index `[i, 0]` will be the dot product between
* `a[i, :]` and `b[i, :]`.
*
* Example:
*
* ```js
* const dotLayer = tf.layers.dot({axes: -1});
* const x1 = tf.tensor2d([[10, 20], [30, 40]]);
* const x2 = tf.tensor2d([[-1, -2], [-3, -4]]);
*
* // Invoke the layer's apply() method in eager (imperative) mode.
* const y = dotLayer.apply([x1, x2]);
* y.print();
* ```
*
* @doc {heading: 'Layers', subheading: 'Merge', namespace: 'layers'}
*/
function dot$2(args) {
return new Dot(args);
} // Normalization Layers.
/**
* Batch normalization layer (Ioffe and Szegedy, 2014).
*
 * Normalizes the activations of the previous layer at each batch,
 * i.e., applies a transformation that maintains the mean activation
* close to 0 and the activation standard deviation close to 1.
*
* Input shape:
* Arbitrary. Use the keyword argument `inputShape` (Array of integers, does
* not include the sample axis) when calling the constructor of this class,
* if this layer is used as a first layer in a model.
*
* Output shape:
* Same shape as input.
*
* References:
* - [Batch Normalization: Accelerating Deep Network Training by Reducing
* Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
*
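 * Example (a minimal sketch; the layer sizes are illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({units: 4, inputShape: [8]}));
 * model.add(tf.layers.batchNormalization());
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 4]: the output shape is the same as the input shape.
 * ```
 *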
* @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
*/
function batchNormalization$1(args) {
return new BatchNormalization(args);
}
/**
* Layer-normalization layer (Ba et al., 2016).
*
* Normalizes the activations of the previous layer for each given example in a
* batch independently, instead of across a batch like in `batchNormalization`.
 * In other words, this layer applies a transformation that maintains the mean
 * activation within each example close to 0 and the activation variance close
 * to 1.
*
* Input shape:
* Arbitrary. Use the argument `inputShape` when using this layer as the first
* layer in a model.
*
* Output shape:
* Same as input.
*
* References:
* - [Layer Normalization](https://arxiv.org/abs/1607.06450)
*
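 * Example (a minimal sketch; the layer sizes are illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.dense({units: 4, inputShape: [8]}));
 * model.add(tf.layers.layerNormalization());
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 4]: the output shape is the same as the input shape.
 * ```
 *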
* @doc {heading: 'Layers', subheading: 'Normalization', namespace: 'layers'}
*/
function layerNormalization(args) {
return new LayerNormalization(args);
} // Padding Layers.
/**
* Zero-padding layer for 2D input (e.g., image).
*
* This layer can add rows and columns of zeros
* at the top, bottom, left and right side of an image tensor.
*
* Input shape:
* 4D tensor with shape:
* - If `dataFormat` is `"channelsLast"`:
* `[batch, rows, cols, channels]`
 *   - If `dataFormat` is `"channelsFirst"`:
* `[batch, channels, rows, cols]`.
*
* Output shape:
 *   4D tensor with shape:
* - If `dataFormat` is `"channelsLast"`:
* `[batch, paddedRows, paddedCols, channels]`
* - If `dataFormat` is `"channelsFirst"`:
* `[batch, channels, paddedRows, paddedCols]`.
*
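 * Example (a minimal sketch; the padding amount and input shape are
 * illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.zeroPadding2d({padding: 2, inputShape: [4, 4, 1]}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 8, 8, 1]: two rows/columns of zeros are added on each side.
 * ```
 *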
* @doc {heading: 'Layers', subheading: 'Padding', namespace: 'layers'}
*/
function zeroPadding2d(args) {
return new ZeroPadding2D(args);
} // Pooling Layers.
/**
 * Average pooling operation for temporal data.
*
* Input shape: `[batchSize, inLength, channels]`
*
* Output shape: `[batchSize, pooledLength, channels]`
*
* `tf.avgPool1d` is an alias.
*
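 * Example (a minimal sketch; the pool size and input shape are illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.averagePooling1d({poolSize: 2, inputShape: [8, 3]}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 4, 3]: the length dimension is halved by the pooling.
 * ```
 *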
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function averagePooling1d(args) {
return new AveragePooling1D(args);
}
function avgPool1d(args) {
return averagePooling1d(args);
} // For backwards compatibility.
// See https://github.com/tensorflow/tfjs/issues/152
function avgPooling1d(args) {
return averagePooling1d(args);
}
/**
* Average pooling operation for spatial data.
*
* Input shape:
* - If `dataFormat === CHANNEL_LAST`:
* 4D tensor with shape:
* `[batchSize, rows, cols, channels]`
* - If `dataFormat === CHANNEL_FIRST`:
* 4D tensor with shape:
* `[batchSize, channels, rows, cols]`
*
* Output shape
* - If `dataFormat === CHANNEL_LAST`:
* 4D tensor with shape:
 *     `[batchSize, pooledRows, pooledCols, channels]`
* - If `dataFormat === CHANNEL_FIRST`:
* 4D tensor with shape:
 *     `[batchSize, channels, pooledRows, pooledCols]`
*
* `tf.avgPool2d` is an alias.
*
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function averagePooling2d(args) {
return new AveragePooling2D(args);
}
function avgPool2d(args) {
return averagePooling2d(args);
} // For backwards compatibility.
// See https://github.com/tensorflow/tfjs/issues/152
function avgPooling2d(args) {
return averagePooling2d(args);
}
/**
* Average pooling operation for 3D data.
*
* Input shape
* - If `dataFormat === channelsLast`:
* 5D tensor with shape:
* `[batchSize, depths, rows, cols, channels]`
* - If `dataFormat === channelsFirst`:
 *      5D tensor with shape:
* `[batchSize, channels, depths, rows, cols]`
*
* Output shape
* - If `dataFormat=channelsLast`:
* 5D tensor with shape:
* `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
* - If `dataFormat=channelsFirst`:
* 5D tensor with shape:
* `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
*
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function averagePooling3d(args) {
return new AveragePooling3D(args);
}
function avgPool3d$1(args) {
return averagePooling3d(args);
} // For backwards compatibility.
// See https://github.com/tensorflow/tfjs/issues/152
function avgPooling3d(args) {
return averagePooling3d(args);
}
/**
* Global average pooling operation for temporal data.
*
* Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
*
 * Output Shape: 2D tensor with shape: `[batchSize, features]`.
*
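 * Example (a minimal sketch; the input shape is illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.globalAveragePooling1d({inputShape: [10, 4]}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 4]: the steps dimension is averaged away.
 * ```
 *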
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function globalAveragePooling1d(args) {
return new GlobalAveragePooling1D(args);
}
/**
* Global average pooling operation for spatial data.
*
* Input shape:
* - If `dataFormat` is `CHANNEL_LAST`:
* 4D tensor with shape: `[batchSize, rows, cols, channels]`.
* - If `dataFormat` is `CHANNEL_FIRST`:
* 4D tensor with shape: `[batchSize, channels, rows, cols]`.
*
* Output shape:
* 2D tensor with shape: `[batchSize, channels]`.
*
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function globalAveragePooling2d(args) {
return new GlobalAveragePooling2D(args);
}
/**
* Global max pooling operation for temporal data.
*
* Input Shape: 3D tensor with shape: `[batchSize, steps, features]`.
*
 * Output Shape: 2D tensor with shape: `[batchSize, features]`.
*
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function globalMaxPooling1d(args) {
return new GlobalMaxPooling1D(args);
}
/**
* Global max pooling operation for spatial data.
*
* Input shape:
* - If `dataFormat` is `CHANNEL_LAST`:
* 4D tensor with shape: `[batchSize, rows, cols, channels]`.
* - If `dataFormat` is `CHANNEL_FIRST`:
* 4D tensor with shape: `[batchSize, channels, rows, cols]`.
*
* Output shape:
* 2D tensor with shape: `[batchSize, channels]`.
*
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function globalMaxPooling2d(args) {
return new GlobalMaxPooling2D(args);
}
/**
* Max pooling operation for temporal data.
*
* Input shape: `[batchSize, inLength, channels]`
*
* Output shape: `[batchSize, pooledLength, channels]`
*
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function maxPooling1d(args) {
return new MaxPooling1D(args);
}
/**
* Max pooling operation for spatial data.
*
* Input shape
* - If `dataFormat === CHANNEL_LAST`:
* 4D tensor with shape:
* `[batchSize, rows, cols, channels]`
* - If `dataFormat === CHANNEL_FIRST`:
* 4D tensor with shape:
* `[batchSize, channels, rows, cols]`
*
* Output shape
* - If `dataFormat=CHANNEL_LAST`:
* 4D tensor with shape:
 *     `[batchSize, pooledRows, pooledCols, channels]`
* - If `dataFormat=CHANNEL_FIRST`:
* 4D tensor with shape:
 *     `[batchSize, channels, pooledRows, pooledCols]`
*
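 * Example (a minimal sketch; the pool size and input shape are illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.maxPooling2d({poolSize: [2, 2], inputShape: [8, 8, 3]}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 4, 4, 3]: rows and columns are halved by the pooling.
 * ```
 *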
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function maxPooling2d(args) {
return new MaxPooling2D(args);
}
/**
* Max pooling operation for 3D data.
*
* Input shape
* - If `dataFormat === channelsLast`:
* 5D tensor with shape:
* `[batchSize, depths, rows, cols, channels]`
* - If `dataFormat === channelsFirst`:
* 5D tensor with shape:
* `[batchSize, channels, depths, rows, cols]`
*
* Output shape
* - If `dataFormat=channelsLast`:
* 5D tensor with shape:
* `[batchSize, pooledDepths, pooledRows, pooledCols, channels]`
* - If `dataFormat=channelsFirst`:
* 5D tensor with shape:
* `[batchSize, channels, pooledDepths, pooledRows, pooledCols]`
*
* @doc {heading: 'Layers', subheading: 'Pooling', namespace: 'layers'}
*/
function maxPooling3d(args) {
return new MaxPooling3D(args);
} // Recurrent Layers.
/**
* Gated Recurrent Unit - Cho et al. 2014.
*
* This is an `RNN` layer consisting of one `GRUCell`. However, unlike
 * the underlying `GRUCell`, the `apply` method of `GRU` operates
* on a sequence of inputs. The shape of the input (not including the first,
* batch dimension) needs to be at least 2-D, with the first dimension being
* time steps. For example:
*
* ```js
* const rnn = tf.layers.gru({units: 8, returnSequences: true});
*
* // Create an input with 10 time steps.
* const input = tf.input({shape: [10, 20]});
* const output = rnn.apply(input);
*
* console.log(JSON.stringify(output.shape));
* // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
* // same as the sequence length of `input`, due to `returnSequences`: `true`;
* // 3rd dimension is the `GRUCell`'s number of units.
 * ```
 *
* @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
*/
function gru(args) {
return new GRU(args);
}
/**
* Cell class for `GRU`.
*
* `GRUCell` is distinct from the `RNN` subclass `GRU` in that its
* `apply` method takes the input data of only a single time step and returns
* the cell's output at the time step, while `GRU` takes the input data
* over a number of time steps. For example:
*
* ```js
* const cell = tf.layers.gruCell({units: 2});
* const input = tf.input({shape: [10]});
* const output = cell.apply(input);
*
* console.log(JSON.stringify(output.shape));
* // [null, 10]: This is the cell's output at a single time step. The 1st
* // dimension is the unknown batch size.
* ```
*
* Instance(s) of `GRUCell` can be used to construct `RNN` layers. The
* most typical use of this workflow is to combine a number of cells into a
* stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
* RNN. For example:
*
* ```js
* const cells = [
* tf.layers.gruCell({units: 4}),
* tf.layers.gruCell({units: 8}),
* ];
* const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
*
* // Create an input with 10 time steps and a length-20 vector at each step.
* const input = tf.input({shape: [10, 20]});
* const output = rnn.apply(input);
*
* console.log(JSON.stringify(output.shape));
* // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
* // same as the sequence length of `input`, due to `returnSequences`: `true`;
* // 3rd dimension is the last `gruCell`'s number of units.
* ```
*
 * To create an `RNN` consisting of only *one* `GRUCell`, use
 * `tf.layers.gru`.
*
* @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
*/
function gruCell(args) {
return new GRUCell(args);
}
/**
 * Long Short-Term Memory layer - Hochreiter 1997.
*
* This is an `RNN` layer consisting of one `LSTMCell`. However, unlike
* the underlying `LSTMCell`, the `apply` method of `LSTM` operates
* on a sequence of inputs. The shape of the input (not including the first,
* batch dimension) needs to be at least 2-D, with the first dimension being
* time steps. For example:
*
* ```js
* const lstm = tf.layers.lstm({units: 8, returnSequences: true});
*
* // Create an input with 10 time steps.
* const input = tf.input({shape: [10, 20]});
* const output = lstm.apply(input);
*
* console.log(JSON.stringify(output.shape));
* // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
* // same as the sequence length of `input`, due to `returnSequences`: `true`;
* // 3rd dimension is the `LSTMCell`'s number of units.
 * ```
 *
* @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
*/
function lstm(args) {
return new LSTM(args);
}
/**
* Cell class for `LSTM`.
*
* `LSTMCell` is distinct from the `RNN` subclass `LSTM` in that its
* `apply` method takes the input data of only a single time step and returns
* the cell's output at the time step, while `LSTM` takes the input data
* over a number of time steps. For example:
*
* ```js
* const cell = tf.layers.lstmCell({units: 2});
* const input = tf.input({shape: [10]});
* const output = cell.apply(input);
*
* console.log(JSON.stringify(output.shape));
* // [null, 10]: This is the cell's output at a single time step. The 1st
* // dimension is the unknown batch size.
* ```
*
* Instance(s) of `LSTMCell` can be used to construct `RNN` layers. The
* most typical use of this workflow is to combine a number of cells into a
* stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
* RNN. For example:
*
* ```js
* const cells = [
* tf.layers.lstmCell({units: 4}),
* tf.layers.lstmCell({units: 8}),
* ];
* const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
*
* // Create an input with 10 time steps and a length-20 vector at each step.
* const input = tf.input({shape: [10, 20]});
* const output = rnn.apply(input);
*
* console.log(JSON.stringify(output.shape));
* // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
* // same as the sequence length of `input`, due to `returnSequences`: `true`;
* // 3rd dimension is the last `lstmCell`'s number of units.
* ```
*
 * To create an `RNN` consisting of only *one* `LSTMCell`, use
 * `tf.layers.lstm`.
*
* @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
*/
function lstmCell(args) {
return new LSTMCell(args);
}
/**
* Fully-connected RNN where the output is to be fed back to input.
*
* This is an `RNN` layer consisting of one `SimpleRNNCell`. However, unlike
* the underlying `SimpleRNNCell`, the `apply` method of `SimpleRNN` operates
* on a sequence of inputs. The shape of the input (not including the first,
* batch dimension) needs to be at least 2-D, with the first dimension being
* time steps. For example:
*
* ```js
* const rnn = tf.layers.simpleRNN({units: 8, returnSequences: true});
*
* // Create an input with 10 time steps.
* const input = tf.input({shape: [10, 20]});
* const output = rnn.apply(input);
*
* console.log(JSON.stringify(output.shape));
* // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
* // same as the sequence length of `input`, due to `returnSequences`: `true`;
* // 3rd dimension is the `SimpleRNNCell`'s number of units.
* ```
*
* @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
*/
function simpleRNN(args) {
return new SimpleRNN(args);
}
/**
* Cell class for `SimpleRNN`.
*
* `SimpleRNNCell` is distinct from the `RNN` subclass `SimpleRNN` in that its
* `apply` method takes the input data of only a single time step and returns
* the cell's output at the time step, while `SimpleRNN` takes the input data
* over a number of time steps. For example:
*
* ```js
* const cell = tf.layers.simpleRNNCell({units: 2});
* const input = tf.input({shape: [10]});
* const output = cell.apply(input);
*
* console.log(JSON.stringify(output.shape));
* // [null, 10]: This is the cell's output at a single time step. The 1st
* // dimension is the unknown batch size.
* ```
*
* Instance(s) of `SimpleRNNCell` can be used to construct `RNN` layers. The
* most typical use of this workflow is to combine a number of cells into a
* stacked RNN cell (i.e., `StackedRNNCell` internally) and use it to create an
* RNN. For example:
*
* ```js
* const cells = [
* tf.layers.simpleRNNCell({units: 4}),
* tf.layers.simpleRNNCell({units: 8}),
* ];
* const rnn = tf.layers.rnn({cell: cells, returnSequences: true});
*
* // Create an input with 10 time steps and a length-20 vector at each step.
* const input = tf.input({shape: [10, 20]});
* const output = rnn.apply(input);
*
* console.log(JSON.stringify(output.shape));
* // [null, 10, 8]: 1st dimension is unknown batch size; 2nd dimension is the
* // same as the sequence length of `input`, due to `returnSequences`: `true`;
* // 3rd dimension is the last `SimpleRNNCell`'s number of units.
* ```
*
 * To create an `RNN` consisting of only *one* `SimpleRNNCell`, use
 * `tf.layers.simpleRNN`.
*
* @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
*/
function simpleRNNCell(args) {
return new SimpleRNNCell(args);
}
/**
* Convolutional LSTM layer - Xingjian Shi 2015.
*
 * This is a `ConvRNN2D` layer consisting of one `ConvLSTM2DCell`. However,
* unlike the underlying `ConvLSTM2DCell`, the `apply` method of `ConvLSTM2D`
* operates on a sequence of inputs. The shape of the input (not including the
* first, batch dimension) needs to be 4-D, with the first dimension being time
* steps. For example:
*
* ```js
* const filters = 3;
* const kernelSize = 3;
*
* const batchSize = 4;
* const sequenceLength = 2;
* const size = 5;
* const channels = 3;
*
* const inputShape = [batchSize, sequenceLength, size, size, channels];
* const input = tf.ones(inputShape);
*
* const layer = tf.layers.convLstm2d({filters, kernelSize});
*
* const output = layer.apply(input);
* ```
*/
/** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */
function convLstm2d(args) {
return new ConvLSTM2D(args);
}
/**
* Cell class for `ConvLSTM2D`.
*
* `ConvLSTM2DCell` is distinct from the `ConvRNN2D` subclass `ConvLSTM2D` in
* that its `call` method takes the input data of only a single time step and
* returns the cell's output at the time step, while `ConvLSTM2D` takes the
* input data over a number of time steps. For example:
*
* ```js
* const filters = 3;
* const kernelSize = 3;
*
* const sequenceLength = 1;
* const size = 5;
* const channels = 3;
*
* const inputShape = [sequenceLength, size, size, channels];
* const input = tf.ones(inputShape);
*
* const cell = tf.layers.convLstm2dCell({filters, kernelSize});
*
* cell.build(input.shape);
*
* const outputSize = size - kernelSize + 1;
* const outShape = [sequenceLength, outputSize, outputSize, filters];
*
* const initialH = tf.zeros(outShape);
* const initialC = tf.zeros(outShape);
*
* const [o, h, c] = cell.call([input, initialH, initialC], {});
* ```
*/
/** @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'} */
function convLstm2dCell(args) {
return new ConvLSTM2DCell(args);
}
/**
* Base class for recurrent layers.
*
* Input shape:
* 3D tensor with shape `[batchSize, timeSteps, inputDim]`.
*
* Output shape:
* - if `returnState`, an Array of tensors (i.e., `tf.Tensor`s). The first
* tensor is the output. The remaining tensors are the states at the
* last time step, each with shape `[batchSize, units]`.
* - if `returnSequences`, the output will have shape
* `[batchSize, timeSteps, units]`.
* - else, the output will have shape `[batchSize, units]`.
*
* Masking:
* This layer supports masking for input data with a variable number
* of timesteps. To introduce masks to your data,
 * use an embedding layer with the `maskZero` parameter
 * set to `true`.
*
* Notes on using statefulness in RNNs:
* You can set RNN layers to be 'stateful', which means that the states
* computed for the samples in one batch will be reused as initial states
* for the samples in the next batch. This assumes a one-to-one mapping
* between samples in different successive batches.
*
 * To enable statefulness:
 *   - Specify `stateful: true` in the layer constructor.
 *   - Specify a fixed batch size for your model:
 *     - For a sequential model, pass `batchInputShape: [...]` to the first
 *       layer of your model.
 *     - For a functional model with one or more `Input` layers, pass
 *       `batchShape: [...]` to all the first layers of your model.
 *     This is the expected shape of your inputs *including the batch size*;
 *     it should be an array of integers, e.g., `[32, 10, 100]`.
 *   - Specify `shuffle: false` when calling `fit()`.
*
* To reset the states of your model, call `.resetStates()` on either
* a specific layer, or on your entire model.
*
* Note on specifying the initial state of RNNs
* You can specify the initial state of RNN layers symbolically by
* calling them with the option `initialState`. The value of
* `initialState` should be a tensor or list of tensors representing
* the initial state of the RNN layer.
*
 * You can specify the initial state of RNN layers numerically by
 * calling `resetStates` with the argument `states`. The value of
 * `states` should be a tensor or an array of tensors representing
 * the initial state of the RNN layer.
*
* Note on passing external constants to RNNs
* You can pass "external" constants to the cell using the `constants`
* keyword argument of `RNN.call` method. This requires that the `cell.call`
* method accepts the same keyword argument `constants`. Such constants
 * can be used to condition the cell transformation on additional static inputs
 * (not changing over time), a.k.a. an attention mechanism.
*
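 * Example (a minimal sketch of wrapping a cell in an `RNN`; the unit count
 * and shapes are illustrative):
 *
 * ```js
 * const cell = tf.layers.simpleRNNCell({units: 4});
 * const rnn = tf.layers.rnn({cell: cell, returnSequences: true});
 *
 * // Create an input with 10 time steps and a length-8 vector at each step.
 * const input = tf.input({shape: [10, 8]});
 * const output = rnn.apply(input);
 *
 * console.log(JSON.stringify(output.shape));
 * // [null, 10, 4]: unknown batch size, 10 time steps, 4 units.
 * ```
 *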
* @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
*/
function rnn$1(args) {
return new RNN(args);
}
/**
* Wrapper allowing a stack of RNN cells to behave as a single cell.
*
* Used to implement efficient stacked RNNs.
*
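 * Example (a minimal sketch; the unit counts and shapes are illustrative):
 *
 * ```js
 * const cells = [
 *   tf.layers.simpleRNNCell({units: 4}),
 *   tf.layers.simpleRNNCell({units: 8}),
 * ];
 * const stacked = tf.layers.stackedRNNCells({cells: cells});
 * const rnn = tf.layers.rnn({cell: stacked, returnSequences: true});
 *
 * const input = tf.input({shape: [10, 20]});
 * const output = rnn.apply(input);
 * console.log(JSON.stringify(output.shape));
 * // [null, 10, 8]: the last cell's number of units determines the output.
 * ```
 *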
* @doc {heading: 'Layers', subheading: 'Recurrent', namespace: 'layers'}
*/
function stackedRNNCells(args) {
return new StackedRNNCells(args);
} // Wrapper Layers.
/** @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'} */
function bidirectional(args) {
return new Bidirectional(args);
}
/**
* This wrapper applies a layer to every temporal slice of an input.
*
* The input should be at least 3D, and the dimension of the index `1` will be
* considered to be the temporal dimension.
*
* Consider a batch of 32 samples, where each sample is a sequence of 10 vectors
* of 16 dimensions. The batch input shape of the layer is then `[32, 10,
* 16]`, and the `inputShape`, not including the sample dimension, is
* `[10, 16]`.
*
* You can then use `TimeDistributed` to apply a `Dense` layer to each of the 10
* timesteps, independently:
*
* ```js
* const model = tf.sequential();
* model.add(tf.layers.timeDistributed({
* layer: tf.layers.dense({units: 8}),
* inputShape: [10, 16],
* }));
*
* // Now model.outputShape = [null, 10, 8].
* // The output will then have shape `[32, 10, 8]`.
*
* // In subsequent layers, there is no need for `inputShape`:
* model.add(tf.layers.timeDistributed({layer: tf.layers.dense({units: 32})}));
* console.log(JSON.stringify(model.outputs[0].shape));
* // Now model.outputShape = [null, 10, 32].
* ```
*
* The output will then have shape `[32, 10, 32]`.
*
* `TimeDistributed` can be used with arbitrary layers, not just `Dense`, for
* instance a `Conv2D` layer.
*
* ```js
* const model = tf.sequential();
* model.add(tf.layers.timeDistributed({
* layer: tf.layers.conv2d({filters: 64, kernelSize: [3, 3]}),
* inputShape: [10, 299, 299, 3],
* }));
* console.log(JSON.stringify(model.outputs[0].shape));
* ```
*
* @doc {heading: 'Layers', subheading: 'Wrapper', namespace: 'layers'}
*/
function timeDistributed(args) {
return new TimeDistributed(args);
} // Aliases for pooling.
var globalMaxPool1d = globalMaxPooling1d;
var globalMaxPool2d = globalMaxPooling2d;
var maxPool1d = maxPooling1d;
var maxPool2d = maxPooling2d;
/**
* Apply additive zero-centered Gaussian noise.
*
* As it is a regularization layer, it is only active at training time.
*
* This is useful to mitigate overfitting
* (you could see it as a form of random data augmentation).
 * Gaussian noise is a natural choice as a corruption process
 * for real-valued inputs.
*
 * Arguments:
 *   - `stddev`: float, standard deviation of the noise distribution.
 *
 * Input shape:
 *   Arbitrary. Use the keyword argument `inputShape`
 *   (tuple of integers, does not include the samples axis)
 *   when using this layer as the first layer in a model.
 *
 * Output shape:
 *   Same shape as input.
*
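 * Example (a minimal sketch; the `stddev` and input shape are illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.gaussianNoise({stddev: 0.1, inputShape: [4]}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 4]: the output shape is the same as the input shape.
 * ```
 *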
* @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
*/
function gaussianNoise(args) {
return new GaussianNoise(args);
}
/**
* Apply multiplicative 1-centered Gaussian noise.
*
* As it is a regularization layer, it is only active at training time.
*
* Arguments:
* - `rate`: float, drop probability (as with `Dropout`).
* The multiplicative noise will have
* standard deviation `sqrt(rate / (1 - rate))`.
*
* Input shape:
* Arbitrary. Use the keyword argument `inputShape`
* (tuple of integers, does not include the samples axis)
* when using this layer as the first layer in a model.
*
* Output shape:
* Same shape as input.
*
* References:
* - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting](
* http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
*
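 * Example (a minimal sketch; the `rate` and input shape are illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.gaussianDropout({rate: 0.25, inputShape: [4]}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 4]: the output shape is the same as the input shape.
 * ```
 *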
* @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
*/
function gaussianDropout(args) {
return new GaussianDropout(args);
}
/**
* Applies Alpha Dropout to the input.
*
* As it is a regularization layer, it is only active at training time.
*
* Alpha Dropout is a `Dropout` that keeps mean and variance of inputs
* to their original values, in order to ensure the self-normalizing property
* even after this dropout.
* Alpha Dropout fits well to Scaled Exponential Linear Units
* by randomly setting activations to the negative saturation value.
*
* Arguments:
* - `rate`: float, drop probability (as with `Dropout`).
* The multiplicative noise will have
* standard deviation `sqrt(rate / (1 - rate))`.
 *   - `noiseShape`: A 1-D `Tensor` of type `int32`, representing the
 *     shape for randomly generated keep/drop flags.
*
* Input shape:
* Arbitrary. Use the keyword argument `inputShape`
* (tuple of integers, does not include the samples axis)
* when using this layer as the first layer in a model.
*
* Output shape:
* Same shape as input.
*
* References:
* - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
*
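 * Example (a minimal sketch; the `rate` and input shape are illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.alphaDropout({rate: 0.25, inputShape: [4]}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 4]: the output shape is the same as the input shape.
 * ```
 *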
* @doc {heading: 'Layers', subheading: 'Noise', namespace: 'layers'}
*/
function alphaDropout(args) {
return new AlphaDropout(args);
}
/**
* Masks a sequence by using a mask value to skip timesteps.
*
 * If all features for a given sample timestep are equal to `maskValue`,
* then the sample timestep will be masked (skipped) in all downstream layers
* (as long as they support masking).
*
* If any downstream layer does not support masking yet receives such
* an input mask, an exception will be raised.
*
* Arguments:
 *   - `maskValue`: Either `null` or the mask value to skip.
*
* Input shape:
* Arbitrary. Use the keyword argument `inputShape`
* (tuple of integers, does not include the samples axis)
* when using this layer as the first layer in a model.
*
* Output shape:
* Same shape as input.
*
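 * Example (a minimal sketch; the input shape is illustrative):
 *
 * ```js
 * const model = tf.sequential();
 * model.add(tf.layers.masking({maskValue: 0, inputShape: [10, 4]}));
 * console.log(JSON.stringify(model.outputs[0].shape));
 * // [null, 10, 4]: the output shape is the same as the input shape;
 * // timesteps whose features are all 0 will be masked downstream.
 * ```
 *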
* @doc {heading: 'Layers', subheading: 'Mask', namespace: 'layers'}
*/
function masking(args) {
return new Masking(args);
}
var exports_layers = {
__proto__: null,
inputLayer: inputLayer,
elu: elu$2,
reLU: reLU,
leakyReLU: leakyReLU,
prelu: prelu$1,
softmax: softmax$1,
thresholdedReLU: thresholdedReLU,
conv1d: conv1d$2,
conv2d: conv2d$3,
conv2dTranspose: conv2dTranspose$1,
conv3d: conv3d$2,
conv3dTranspose: conv3dTranspose$1,
separableConv2d: separableConv2d$1,
cropping2D: cropping2D,
upSampling2d: upSampling2d,
depthwiseConv2d: depthwiseConv2d$3,
activation: activation,
dense: dense,
dropout: dropout$2,
spatialDropout1d: spatialDropout1d,
flatten: flatten$2,
repeatVector: repeatVector,
reshape: reshape$1,
permute: permute,
embedding: embedding,
add: add$3,
average: average$1,
concatenate: concatenate$2,
maximum: maximum$2,
minimum: minimum$2,
multiply: multiply$2,
dot: dot$2,
batchNormalization: batchNormalization$1,
layerNormalization: layerNormalization,
zeroPadding2d: zeroPadding2d,
averagePooling1d: averagePooling1d,
avgPool1d: avgPool1d,
avgPooling1d: avgPooling1d,
averagePooling2d: averagePooling2d,
avgPool2d: avgPool2d,
avgPooling2d: avgPooling2d,
averagePooling3d: averagePooling3d,
avgPool3d: avgPool3d$1,
avgPooling3d: avgPooling3d,
globalAveragePooling1d: globalAveragePooling1d,
globalAveragePooling2d: globalAveragePooling2d,
globalMaxPooling1d: globalMaxPooling1d,
globalMaxPooling2d: globalMaxPooling2d,
maxPooling1d: maxPooling1d,
maxPooling2d: maxPooling2d,
maxPooling3d: maxPooling3d,
gru: gru,
gruCell: gruCell,
lstm: lstm,
lstmCell: lstmCell,
simpleRNN: simpleRNN,
simpleRNNCell: simpleRNNCell,
convLstm2d: convLstm2d,
convLstm2dCell: convLstm2dCell,
rnn: rnn$1,
stackedRNNCells: stackedRNNCells,
bidirectional: bidirectional,
timeDistributed: timeDistributed,
globalMaxPool1d: globalMaxPool1d,
globalMaxPool2d: globalMaxPool2d,
maxPool1d: maxPool1d,
maxPool2d: maxPool2d,
Layer: Layer,
RNN: RNN,
RNNCell: RNNCell,
input: input,
gaussianNoise: gaussianNoise,
gaussianDropout: gaussianDropout,
alphaDropout: alphaDropout,
masking: masking
};
/**
* Binary accuracy metric function.
*
* `yTrue` and `yPred` can have 0-1 values. Example:
* ```js
* const x = tf.tensor2d([[1, 1, 1, 1], [0, 0, 0, 0]], [2, 4]);
* const y = tf.tensor2d([[1, 0, 1, 0], [0, 0, 0, 1]], [2, 4]);
* const accuracy = tf.metrics.binaryAccuracy(x, y);
* accuracy.print();
* ```
*
 * `yTrue` and `yPred` can also have floating-point values between 0 and 1, in
 * which case the values will be thresholded at 0.5 to yield 0-1 values (i.e.,
 * a value >= 0.5 and <= 1.0 is interpreted as 1).
 *
 * Example:
* ```js
* const x = tf.tensor1d([1, 1, 1, 1, 0, 0, 0, 0]);
* const y = tf.tensor1d([0.2, 0.4, 0.6, 0.8, 0.2, 0.3, 0.4, 0.7]);
* const accuracy = tf.metrics.binaryAccuracy(x, y);
* accuracy.print();
* ```
*
* @param yTrue Binary Tensor of truth.
* @param yPred Binary Tensor of prediction.
* @return Accuracy Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function binaryAccuracy$1(yTrue, yPred) {
return binaryAccuracy(yTrue, yPred);
}
/**
* Binary crossentropy metric function.
*
* Example:
* ```js
* const x = tf.tensor2d([[0], [1], [1], [1]]);
* const y = tf.tensor2d([[0], [0], [0.5], [1]]);
* const crossentropy = tf.metrics.binaryCrossentropy(x, y);
* crossentropy.print();
* ```
*
* @param yTrue Binary Tensor of truth.
* @param yPred Binary Tensor of prediction, probabilities for the `1` case.
 * @return Binary crossentropy Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function binaryCrossentropy$2(yTrue, yPred) {
return binaryCrossentropy$1(yTrue, yPred);
}
/**
* Sparse categorical accuracy metric function.
*
* Example:
* ```js
*
* const yTrue = tf.tensor1d([1, 1, 2, 2, 0]);
* const yPred = tf.tensor2d(
* [[0, 1, 0], [1, 0, 0], [0, 0.4, 0.6], [0, 0.6, 0.4], [0.7, 0.3, 0]]);
 * const accuracy = tf.metrics.sparseCategoricalAccuracy(yTrue, yPred);
 * accuracy.print();
* ```
*
* @param yTrue True labels: indices.
* @param yPred Predicted probabilities or logits.
* @returns Accuracy tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function sparseCategoricalAccuracy$1(yTrue, yPred) {
return sparseCategoricalAccuracy(yTrue, yPred);
}
/**
* Categorical accuracy metric function.
*
* Example:
* ```js
* const x = tf.tensor2d([[0, 0, 0, 1], [0, 0, 0, 1]]);
* const y = tf.tensor2d([[0.1, 0.8, 0.05, 0.05], [0.1, 0.05, 0.05, 0.8]]);
* const accuracy = tf.metrics.categoricalAccuracy(x, y);
* accuracy.print();
* ```
*
* @param yTrue Binary Tensor of truth: one-hot encoding of categories.
* @param yPred Binary Tensor of prediction: probabilities or logits for the
* same categories as in `yTrue`.
* @return Accuracy Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function categoricalAccuracy$1(yTrue, yPred) {
return categoricalAccuracy(yTrue, yPred);
}
/**
* Categorical crossentropy between an output tensor and a target tensor.
*
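 * Example (a minimal sketch; the tensors are illustrative):
 *
 * ```js
 * const yTrue = tf.tensor2d([[0, 0, 0, 1], [0, 0, 0, 1]]);
 * const yPred = tf.tensor2d([[0.1, 0.8, 0.05, 0.05], [0.1, 0.05, 0.05, 0.8]]);
 * const crossentropy = tf.metrics.categoricalCrossentropy(yTrue, yPred);
 * crossentropy.print();
 * ```
 *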
 * @param yTrue True labels: one-hot encoding of categories.
 * @param yPred Predicted probabilities for the same categories as in `yTrue`,
 *   e.g., the output of a softmax.
 * @return Categorical crossentropy Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function categoricalCrossentropy$2(yTrue, yPred) {
return categoricalCrossentropy$1(yTrue, yPred);
}
/**
* Computes the precision of the predictions with respect to the labels.
*
* Example:
* ```js
* const x = tf.tensor2d(
* [
* [0, 0, 0, 1],
* [0, 1, 0, 0],
* [0, 0, 0, 1],
* [1, 0, 0, 0],
* [0, 0, 1, 0]
* ]
* );
*
* const y = tf.tensor2d(
* [
* [0, 0, 1, 0],
* [0, 1, 0, 0],
* [0, 0, 0, 1],
* [0, 1, 0, 0],
* [0, 1, 0, 0]
* ]
* );
*
* const precision = tf.metrics.precision(x, y);
* precision.print();
* ```
*
 * @param yTrue The ground truth values. Expected to contain only 0-1 values.
 * @param yPred The predicted values. Expected to contain only 0-1 values.
* @return Precision Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function precision$1(yTrue, yPred) {
return precision(yTrue, yPred);
}
/**
* Computes the recall of the predictions with respect to the labels.
*
* Example:
* ```js
* const x = tf.tensor2d(
* [
* [0, 0, 0, 1],
* [0, 1, 0, 0],
* [0, 0, 0, 1],
* [1, 0, 0, 0],
* [0, 0, 1, 0]
* ]
* );
*
* const y = tf.tensor2d(
* [
* [0, 0, 1, 0],
* [0, 1, 0, 0],
* [0, 0, 0, 1],
* [0, 1, 0, 0],
* [0, 1, 0, 0]
* ]
* );
*
* const recall = tf.metrics.recall(x, y);
* recall.print();
* ```
*
 * @param yTrue The ground truth values. Expected to contain only 0-1 values.
 * @param yPred The predicted values. Expected to contain only 0-1 values.
* @return Recall Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function recall$1(yTrue, yPred) {
return recall(yTrue, yPred);
}
/**
* Loss or metric function: Cosine proximity.
*
* Mathematically, cosine proximity is defined as:
* `-sum(l2Normalize(yTrue) * l2Normalize(yPred))`,
* wherein `l2Normalize()` normalizes the L2 norm of the input to 1 and `*`
* represents element-wise multiplication.
*
* ```js
* const yTrue = tf.tensor2d([[1, 0], [1, 0]]);
* const yPred = tf.tensor2d([[1 / Math.sqrt(2), 1 / Math.sqrt(2)], [0, 1]]);
* const proximity = tf.metrics.cosineProximity(yTrue, yPred);
* proximity.print();
* ```
*
* @param yTrue Truth Tensor.
* @param yPred Prediction Tensor.
* @return Cosine proximity Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function cosineProximity$1(yTrue, yPred) {
return cosineProximity(yTrue, yPred);
}
/**
* Loss or metric function: Mean absolute error.
*
* Mathematically, mean absolute error is defined as:
* `mean(abs(yPred - yTrue))`,
* wherein the `mean` is applied over feature dimensions.
*
* ```js
* const yTrue = tf.tensor2d([[0, 1], [0, 0], [2, 3]]);
* const yPred = tf.tensor2d([[0, 1], [0, 1], [-2, -3]]);
 * const mae = tf.metrics.meanAbsoluteError(yTrue, yPred);
 * mae.print();
* ```
*
* @param yTrue Truth Tensor.
* @param yPred Prediction Tensor.
* @return Mean absolute error Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function meanAbsoluteError$1(yTrue, yPred) {
return meanAbsoluteError(yTrue, yPred);
}
/**
* Loss or metric function: Mean absolute percentage error.
*
* ```js
* const yTrue = tf.tensor2d([[0, 1], [10, 20]]);
* const yPred = tf.tensor2d([[0, 1], [11, 24]]);
 * const mape = tf.metrics.meanAbsolutePercentageError(yTrue, yPred);
 * mape.print();
* ```
*
* Aliases: `tf.metrics.MAPE`, `tf.metrics.mape`.
*
* @param yTrue Truth Tensor.
* @param yPred Prediction Tensor.
* @return Mean absolute percentage error Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function meanAbsolutePercentageError$1(yTrue, yPred) {
return meanAbsolutePercentageError(yTrue, yPred);
}
function MAPE$2(yTrue, yPred) {
return meanAbsolutePercentageError(yTrue, yPred);
}
function mape$2(yTrue, yPred) {
return meanAbsolutePercentageError(yTrue, yPred);
}
/**
* Loss or metric function: Mean squared error.
*
* ```js
* const yTrue = tf.tensor2d([[0, 1], [3, 4]]);
* const yPred = tf.tensor2d([[0, 1], [-3, -4]]);
* const mse = tf.metrics.meanSquaredError(yTrue, yPred);
* mse.print();
* ```
*
* Aliases: `tf.metrics.MSE`, `tf.metrics.mse`.
*
* @param yTrue Truth Tensor.
* @param yPred Prediction Tensor.
* @return Mean squared error Tensor.
*
* @doc {heading: 'Metrics', namespace: 'metrics'}
*/
function meanSquaredError$2(yTrue, yPred) {
return meanSquaredError$1(yTrue, yPred);
}
function MSE$2(yTrue, yPred) {
return meanSquaredError$1(yTrue, yPred);
}
function mse$2(yTrue, yPred) {
return meanSquaredError$1(yTrue, yPred);
}
var exports_metrics = {
__proto__: null,
binaryAccuracy: binaryAccuracy$1,
binaryCrossentropy: binaryCrossentropy$2,
sparseCategoricalAccuracy: sparseCategoricalAccuracy$1,
categoricalAccuracy: categoricalAccuracy$1,
categoricalCrossentropy: categoricalCrossentropy$2,
precision: precision$1,
recall: recall$1,
cosineProximity: cosineProximity$1,
meanAbsoluteError: meanAbsoluteError$1,
meanAbsolutePercentageError: meanAbsolutePercentageError$1,
MAPE: MAPE$2,
mape: mape$2,
meanSquaredError: meanSquaredError$2,
MSE: MSE$2,
mse: mse$2
};
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
var exports_models = {
__proto__: null,
modelFromJSON: modelFromJSON
};
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* Regularizer for L1 and L2 regularization.
*
* Adds a term to the loss to penalize large weights:
* loss += sum(l1 * abs(x)) + sum(l2 * x^2)
*
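 * Example (a minimal sketch; the layer and the regularization factors are
 * illustrative):
 *
 * ```js
 * const layer = tf.layers.dense({
 *   units: 4,
 *   inputShape: [8],
 *   kernelRegularizer: tf.regularizers.l1l2({l1: 0.01, l2: 0.01})
 * });
 * ```
 *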
* @doc {heading: 'Regularizers', namespace: 'regularizers'}
*/
function l1l2(config) {
return new L1L2(config);
}
/**
* Regularizer for L1 regularization.
*
* Adds a term to the loss to penalize large weights:
* loss += sum(l1 * abs(x))
* @param args l1 config.
*
* @doc {heading: 'Regularizers', namespace: 'regularizers'}
*/
function l1$1(config) {
return l1(config);
}
/**
* Regularizer for L2 regularization.
*
* Adds a term to the loss to penalize large weights:
* loss += sum(l2 * x^2)
* @param args l2 config.
*
* @doc {heading: 'Regularizers', namespace: 'regularizers'}
*/
function l2$1(config) {
return l2(config);
}
var exports_regularizers = {
__proto__: null,
l1l2: l1l2,
l1: l1$1,
l2: l2$1
};
var Callback = /*#__PURE__*/function (_BaseCallback) {
_inheritsLoose(Callback, _BaseCallback);
function Callback() {
var _this;
_this = _BaseCallback.apply(this, arguments) || this;
    /** Instance of `LayersModel`. Reference to the model being trained. */
_this.model = null;
return _this;
}
var _proto = Callback.prototype;
_proto.setModel = function setModel(model) {
if (!(model instanceof LayersModel)) {
throw new Error('model must be a LayersModel, not some other Container');
}
this.model = model;
};
return Callback;
}(BaseCallback);
function less$1(currVal, prevVal) {
return currVal < prevVal;
}
function greater$1(currVal, prevVal) {
return currVal > prevVal;
}
/**
* A Callback that stops training when a monitored quantity has stopped
* improving.
*/
var EarlyStopping = /*#__PURE__*/function (_Callback) {
_inheritsLoose(EarlyStopping, _Callback);
function EarlyStopping(args) {
var _this2;
_this2 = _Callback.call(this) || this;
if (args == null) {
args = {};
}
if (args.restoreBestWeights) {
throw new NotImplementedError('restoreBestWeights = True is not implemented in EarlyStopping yet.');
}
_this2.monitor = args.monitor || 'val_loss';
_this2.minDelta = Math.abs(args.minDelta || 0);
_this2.patience = args.patience || 0;
_this2.verbose = args.verbose || 0;
_this2.mode = args.mode || 'auto';
_this2.baseline = args.baseline;
if (['auto', 'min', 'max'].indexOf(_this2.mode) === -1) {
console.warn("EarlyStopping mode '" + _this2.mode + "' is invalid. " + "Falling back to mode 'auto'.");
_this2.mode = 'auto';
}
if (_this2.mode === 'min') {
_this2.monitorFunc = less$1;
} else if (_this2.mode === 'max') {
_this2.monitorFunc = greater$1;
} else {
// For mode === 'auto'.
if (_this2.monitor.indexOf('acc') !== -1) {
_this2.monitorFunc = greater$1;
} else {
_this2.monitorFunc = less$1;
}
}
if (_this2.monitorFunc === less$1) {
_this2.minDelta *= -1;
}
return _this2;
}
var _proto2 = EarlyStopping.prototype;
_proto2.onTrainBegin = /*#__PURE__*/function () {
var _onTrainBegin = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(logs) {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
this.wait = 0;
this.stoppedEpoch = 0;
if (this.baseline != null) {
this.best = this.baseline;
} else {
this.best = this.monitorFunc === less$1 ? Infinity : -Infinity;
}
case 3:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function onTrainBegin(_x) {
return _onTrainBegin.apply(this, arguments);
}
return onTrainBegin;
}();
_proto2.onEpochEnd = /*#__PURE__*/function () {
var _onEpochEnd = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(epoch, logs) {
var current;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return resolveScalarsInLogs(logs);
case 2:
current = this.getMonitorValue(logs);
if (!(current == null)) {
_context2.next = 5;
break;
}
return _context2.abrupt("return");
case 5:
if (this.monitorFunc(current - this.minDelta, this.best)) {
this.best = current;
this.wait = 0; // TODO(cais): Logic for restoreBestWeights.
} else {
this.wait++;
if (this.wait >= this.patience) {
this.stoppedEpoch = epoch;
this.model.stopTraining = true;
} // TODO(cais): Logic for restoreBestWeights.
}
case 6:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function onEpochEnd(_x2, _x3) {
return _onEpochEnd.apply(this, arguments);
}
return onEpochEnd;
}();
_proto2.onTrainEnd = /*#__PURE__*/function () {
var _onTrainEnd = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(logs) {
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
if (this.stoppedEpoch > 0 && this.verbose) {
console.log("Epoch " + this.stoppedEpoch + ": early stopping.");
}
case 1:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function onTrainEnd(_x4) {
return _onTrainEnd.apply(this, arguments);
}
return onTrainEnd;
}();
_proto2.getMonitorValue = function getMonitorValue(logs) {
if (logs == null) {
logs = {};
}
var monitorValue = logs[this.monitor];
if (monitorValue == null) {
console.warn("Metric for EarlyStopping " + this.monitor + " is not available. " + ("Available metrics are: " + Object.keys(logs)));
}
return monitorValue;
};
return EarlyStopping;
}(Callback);
/**
* Factory function for a Callback that stops training when a monitored
* quantity has stopped improving.
*
 * Early stopping is a type of regularization, and protects the model against
 * overfitting.
*
* The following example based on fake data illustrates how this callback
* can be used during `tf.LayersModel.fit()`:
*
* ```js
* const model = tf.sequential();
* model.add(tf.layers.dense({
* units: 3,
* activation: 'softmax',
* kernelInitializer: 'ones',
* inputShape: [2]
* }));
* const xs = tf.tensor2d([1, 2, 3, 4], [2, 2]);
* const ys = tf.tensor2d([[1, 0, 0], [0, 1, 0]], [2, 3]);
* const xsVal = tf.tensor2d([4, 3, 2, 1], [2, 2]);
* const ysVal = tf.tensor2d([[0, 0, 1], [0, 1, 0]], [2, 3]);
* model.compile(
* {loss: 'categoricalCrossentropy', optimizer: 'sgd', metrics: ['acc']});
*
* // Without the EarlyStopping callback, the val_acc value would be:
* // 0.5, 0.5, 0.5, 0.5, ...
* // With val_acc being monitored, training should stop after the 2nd epoch.
* const history = await model.fit(xs, ys, {
* epochs: 10,
* validationData: [xsVal, ysVal],
* callbacks: tf.callbacks.earlyStopping({monitor: 'val_acc'})
* });
*
* // Expect to see a length-2 array.
* console.log(history.history.val_acc);
* ```
*
* @doc {
* heading: 'Callbacks',
* namespace: 'callbacks'
* }
*/
function earlyStopping(args) {
return new EarlyStopping(args);
}
var callbacks = {
earlyStopping: earlyStopping
};
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/** DataType enum. */
var DataType;
(function (DataType) {
DataType[DataType["DT_INVALID"] = 0] = "DT_INVALID";
DataType[DataType["DT_FLOAT"] = 1] = "DT_FLOAT";
DataType[DataType["DT_DOUBLE"] = 2] = "DT_DOUBLE";
DataType[DataType["DT_INT32"] = 3] = "DT_INT32";
DataType[DataType["DT_UINT8"] = 4] = "DT_UINT8";
DataType[DataType["DT_INT16"] = 5] = "DT_INT16";
DataType[DataType["DT_INT8"] = 6] = "DT_INT8";
DataType[DataType["DT_STRING"] = 7] = "DT_STRING";
DataType[DataType["DT_COMPLEX64"] = 8] = "DT_COMPLEX64";
DataType[DataType["DT_INT64"] = 9] = "DT_INT64";
DataType[DataType["DT_BOOL"] = 10] = "DT_BOOL";
DataType[DataType["DT_QINT8"] = 11] = "DT_QINT8";
DataType[DataType["DT_QUINT8"] = 12] = "DT_QUINT8";
DataType[DataType["DT_QINT32"] = 13] = "DT_QINT32";
DataType[DataType["DT_BFLOAT16"] = 14] = "DT_BFLOAT16";
DataType[DataType["DT_FLOAT_REF"] = 101] = "DT_FLOAT_REF";
DataType[DataType["DT_DOUBLE_REF"] = 102] = "DT_DOUBLE_REF";
DataType[DataType["DT_INT32_REF"] = 103] = "DT_INT32_REF";
DataType[DataType["DT_UINT8_REF"] = 104] = "DT_UINT8_REF";
DataType[DataType["DT_INT16_REF"] = 105] = "DT_INT16_REF";
DataType[DataType["DT_INT8_REF"] = 106] = "DT_INT8_REF";
DataType[DataType["DT_STRING_REF"] = 107] = "DT_STRING_REF";
DataType[DataType["DT_COMPLEX64_REF"] = 108] = "DT_COMPLEX64_REF";
DataType[DataType["DT_INT64_REF"] = 109] = "DT_INT64_REF";
DataType[DataType["DT_BOOL_REF"] = 110] = "DT_BOOL_REF";
DataType[DataType["DT_QINT8_REF"] = 111] = "DT_QINT8_REF";
DataType[DataType["DT_QUINT8_REF"] = 112] = "DT_QUINT8_REF";
DataType[DataType["DT_QINT32_REF"] = 113] = "DT_QINT32_REF";
DataType[DataType["DT_BFLOAT16_REF"] = 114] = "DT_BFLOAT16_REF";
})(DataType || (DataType = {}));
var SaverDef;
(function (SaverDef) {
/** CheckpointFormatVersion enum. */
var CheckpointFormatVersion;
(function (CheckpointFormatVersion) {
CheckpointFormatVersion[CheckpointFormatVersion["LEGACY"] = 0] = "LEGACY";
CheckpointFormatVersion[CheckpointFormatVersion["V1"] = 1] = "V1";
CheckpointFormatVersion[CheckpointFormatVersion["V2"] = 2] = "V2";
})(CheckpointFormatVersion = SaverDef.CheckpointFormatVersion || (SaverDef.CheckpointFormatVersion = {}));
})(SaverDef || (SaverDef = {}));
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var CUSTOM_OPS = {};
/**
 * Register an Op for graph model executor. This allows you to register a
 * TensorFlow custom op or override an existing op.
*
* Here is an example of registering a new MatMul Op.
* ```js
* const customMatmul = (node) =>
* tf.matMul(
* node.inputs[0], node.inputs[1],
* node.attrs['transpose_a'], node.attrs['transpose_b']);
*
* tf.registerOp('MatMul', customMatmul);
* ```
 * The inputs and attrs of the node object are based on the TensorFlow op
* registry.
*
* @param name The Tensorflow Op name.
* @param opFunc An op function which is called with the current graph node
* during execution and needs to return a tensor or a list of tensors. The node
* has the following attributes:
* - attr: A map from attribute name to its value
* - inputs: A list of input tensors
*
* @doc {heading: 'Models', subheading: 'Op Registry'}
*/
function registerOp(name, opFunc) {
var opMapper = {
tfOpName: name,
category: 'custom',
inputs: [],
attrs: [],
customExecutor: opFunc
};
CUSTOM_OPS[name] = opMapper;
}
/**
* Retrieve the OpMapper object for the registered op.
*
* @param name The Tensorflow Op name.
*
* @doc {heading: 'Models', subheading: 'Op Registry'}
*/
function getRegisteredOp(name) {
return CUSTOM_OPS[name];
}
/**
* Deregister the Op for graph model executor.
*
* @param name The Tensorflow Op name.
*
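 * Example (a minimal sketch; 'MyMatMul' is a hypothetical custom op name,
 * not part of the built-in op set):
 *
 * ```js
 * const customMatmul = (node) =>
 *     tf.matMul(node.inputs[0], node.inputs[1]);
 * tf.registerOp('MyMatMul', customMatmul);
 * // ...later, remove the custom op:
 * tf.deregisterOp('MyMatMul');
 * ```
 *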
* @doc {heading: 'Models', subheading: 'Op Registry'}
*/
function deregisterOp(name) {
delete CUSTOM_OPS[name];
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function getParamValue(paramName, node, tensorMap, context, resourceManager) {
var inputParam = node.inputParams[paramName];
if (inputParam && inputParam.inputIndexStart !== undefined) {
var start = inputParam.inputIndexStart;
var end = inputParam.inputIndexEnd === 0 ? undefined : inputParam.inputIndexEnd === undefined ? start + 1 : inputParam.inputIndexEnd;
if (inputParam.type === 'tensor') {
return getTensor(node.inputNames[inputParam.inputIndexStart], tensorMap, context, resourceManager);
}
if (inputParam.type === 'tensors') {
var inputs = node.inputNames.slice(start, end);
return inputs.map(function (name) {
return getTensor(name, tensorMap, context, resourceManager);
});
}
var tensor = getTensor(node.inputNames.slice(start)[0], tensorMap, context, resourceManager);
var data = tensor.dataSync();
return inputParam.type === 'number' ? data[0] : toNestedArray(tensor.shape, data);
}
var attrParam = node.attrParams[paramName];
return attrParam && attrParam.value;
}
/**
* Retrieve the tensor from tensorsMap based on input name.
* @param name Node input name
* @param tensorsMap Tensors map keyed by the node
* @param context contains tensors and information for running the current node.
* @param resourceManager Optional. Contains global resources of the model.
*/
function getTensor(name, tensorsMap, context, resourceManager) {
var _parseNodeName = parseNodeName(name),
nodeName = _parseNodeName[0],
index = _parseNodeName[1];
if (resourceManager != null) {
var tensor = resourceManager.getHashTableHandleByName(nodeName);
if (tensor != null) {
return tensor;
}
}
var contextId = context.currentContextIds.find(function (contextId) {
return !!tensorsMap[getNodeNameWithContextId(nodeName, contextId)];
});
return contextId !== undefined ? tensorsMap[getNodeNameWithContextId(nodeName, contextId)][index] : undefined;
}
/**
* Retrieve the tensors based on input name for current context.
* @param name Node input name
* @param tensorsMap Tensors map keyed by the node
*/
function getTensorsForCurrentContenxt(name, tensorsMap, context) {
return tensorsMap[getNodeNameWithContextId(name, context.currentContextId)];
}
/**
* Returns the node name, outputName and index from the Node input name.
* @param inputName The input name of the node, in format of
 * node_name:output_index, e.g. MatMul:0. If the output_index is not set, it
 * defaults to 0.
* If the input name contains output name i.e. StringSplit:indices:0, it will
* return ['StringSplit', 0, 'indices'].
*/
function getNodeNameAndIndex(inputName, context) {
var _parseNodeName2 = parseNodeName(inputName),
nodeName = _parseNodeName2[0],
index = _parseNodeName2[1],
outputName = _parseNodeName2[2];
return [getNodeNameWithContextId(nodeName, context && context.currentContextId), index, outputName];
}
function getNodeNameWithContextId(name, contextId) {
return !!contextId ? name + "-" + contextId : name;
}
function parseNodeName(name) {
var parts = name.split(':');
if (parts.length === 1) {
return [name, 0, undefined];
}
var nodeName = parts[0];
var outputName = parts.length === 3 ? parts[1] : undefined;
var index = Number(parts[parts.length - 1]);
return [nodeName, index, outputName];
}
function split$2(arr, size) {
var res = [];
for (var i = 0; i < arr.length; i += size) {
res.push(arr.slice(i, i + size));
}
return res;
}
function getPadding(node, tensorMap, context) {
var pad = getParamValue('pad', node, tensorMap, context);
if (pad === 'explicit') {
    // This is a 1D array; we need to convert it to a 2D array.
pad = getParamValue('explicitPaddings', node, tensorMap, context);
var explicitPadding = [[0, 0], [0, 0], [0, 0], [0, 0]];
for (var i = 0; i < 4; i++) {
explicitPadding[i][0] = pad[i * 2];
explicitPadding[i][1] = pad[i * 2 + 1];
}
return explicitPadding;
}
return pad;
}
/**
* Reuse the tensor if it is marked as keep, otherwise clone the tensor to
* avoid disposal. This is important for TensorArray and TensorList ops, since
* internally they use a tensor as the id for TensorArray and TensorList, and
* to simplify lookup, they also use Tensor.id as the key to the internal map.
 * These id tensors have been marked as kept in the backend; we need to avoid
 * cloning them so that new Tensor.ids are not created.
* @param tensor
*/
function cloneTensor(tensor) {
return tensor.kept ? tensor : clone(tensor);
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json = [{
'tfOpName': 'Add',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'AddV2',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'AddN',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'end': 0,
'name': 'tensors',
'type': 'tensors'
}]
}, {
'tfOpName': 'BiasAdd',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'notSupported': true
}]
}, {
'tfOpName': 'Sub',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'RealDiv',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Div',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'DivNoNan',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'FloorDiv',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Mul',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Maximum',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Minimum',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Pow',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'SquaredDifference',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Mod',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'FloorMod',
'category': 'arithmetic',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}];
var arithmetic = {
__proto__: null,
json: json
};
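// Rough sketch of how an op mapper entry above is read (inferred from the
// shape of the data, not a normative spec):
// - 'tfOpName': the TensorFlow op the entry maps, grouped by 'category'.
// - 'inputs': which node inputs feed each named parameter; 'start' is the
//   input index, and for 'tensors'-typed parameters an 'end' of 0 means
//   "through the last input".
// - 'attrs': maps a TF attribute ('tfName') to a tfjs parameter ('name') with
//   a 'type' and optional 'defaultValue'; 'notSupported' flags attributes the
//   converter does not handle.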
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$1 = [{
'tfOpName': 'Abs',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Acos',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Asin',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Atan',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Atan2',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'y',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Ceil',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'ClipByValue',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'clipValueMin',
'type': 'number'
}, {
'start': 2,
'name': 'clipValueMax',
'type': 'number'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Complex',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'real',
'type': 'tensor'
}, {
'start': 1,
'name': 'imag',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'ComplexAbs',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Cos',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Cosh',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Elu',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Exp',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Floor',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Log',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Imag',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'Tout',
'name': 'outputType',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Neg',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Real',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'Tout',
'name': 'outputType',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Prelu',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'alpha',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Relu',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Relu6',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Selu',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Sigmoid',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Sin',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Sinh',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Sqrt',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Rsqrt',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Square',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Tan',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Tanh',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Sign',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Round',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Expm1',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Log1p',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Reciprocal',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Softplus',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Asinh',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Acosh',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Atanh',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Erf',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Prod',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axes',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'keep_dims',
'name': 'keepDims',
'type': 'bool',
'notSupported': true
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'LeakyRelu',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'alpha',
'name': 'alpha',
'type': 'number',
'defaultValue': 0.2
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'IsNan',
'category': 'basic_math',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}];
var basicMath = {
__proto__: null,
json: json$1
};
var json$2 = [{
'tfOpName': 'EmptyTensorList',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'elementShape',
'type': 'shape'
}, {
'start': 1,
'name': 'maxNumElements',
'type': 'number'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'LoopCond',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'pred',
'type': 'tensor'
}]
}, {
'tfOpName': 'Switch',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'data',
'type': 'tensor'
}, {
'start': 1,
'name': 'pred',
'type': 'tensor'
}]
}, {
'tfOpName': 'Merge',
'category': 'control',
'inputs': [{
'start': 0,
'end': 0,
'name': 'tensors',
'type': 'tensors'
}]
}, {
'tfOpName': 'Enter',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensor',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'frame_name',
'name': 'frameName',
'type': 'string'
}, {
'tfName': 'is_constant',
'name': 'isConstant',
'type': 'bool'
}]
}, {
'tfOpName': 'Exit',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensor',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'NextIteration',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensor',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'TensorArrayV3',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'size',
'type': 'number'
}],
'attrs': [{
'tfName': 'dtype',
'name': 'dtype',
'type': 'dtype'
}, {
'tfName': 'element_shape',
'name': 'elementShape',
'type': 'shape'
}, {
'tfName': 'dynamic_size',
'name': 'dynamicSize',
'type': 'bool'
}, {
'tfName': 'clear_after_read',
'name': 'clearAfterRead',
'type': 'bool'
}, {
'tfName': 'identical_element_shapes',
'name': 'identicalElementShapes',
'type': 'bool'
}, {
'tfName': 'tensor_array_name',
'name': 'name',
'type': 'string'
}]
}, {
'tfOpName': 'TensorArrayWriteV3',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorArrayId',
'type': 'tensor'
}, {
'start': 1,
'name': 'index',
'type': 'number'
}, {
'start': 2,
'name': 'tensor',
'type': 'tensor'
}, {
'start': 3,
'name': 'flowIn',
'type': 'number'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'TensorArrayReadV3',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorArrayId',
'type': 'tensor'
}, {
'start': 1,
'name': 'index',
'type': 'number'
}, {
'start': 2,
'name': 'flowIn',
'type': 'number'
}],
'attrs': [{
'tfName': 'dtype',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'TensorArrayGatherV3',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorArrayId',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'number[]'
}, {
'start': 2,
'name': 'flowIn',
'type': 'number'
}],
'attrs': [{
'tfName': 'dtype',
'name': 'dtype',
'type': 'dtype'
}, {
'tfName': 'element_shape',
'name': 'elementShape',
'type': 'shape'
}]
}, {
'tfOpName': 'TensorArrayScatterV3',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorArrayId',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'number[]'
}, {
'start': 2,
'name': 'tensor',
'type': 'tensor'
}, {
'start': 3,
'name': 'flowIn',
'type': 'number'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorArrayConcatV3',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorArrayId',
'type': 'tensor'
}, {
'start': 1,
'name': 'flowIn',
'type': 'number'
}],
'attrs': [{
'tfName': 'dtype',
'name': 'dtype',
'type': 'dtype'
}, {
'tfName': 'element_shape_except0',
'name': 'elementShapeExcept0',
'type': 'shape',
'notSupported': true
}]
}, {
'tfOpName': 'TensorArraySplitV3',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorArrayId',
'type': 'tensor'
}, {
'start': 1,
'name': 'tensor',
'type': 'tensor'
}, {
'start': 2,
'name': 'lengths',
'type': 'number[]'
}, {
'start': 3,
'name': 'flowIn',
'type': 'number'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorArraySizeV3',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorArrayId',
'type': 'tensor'
}, {
'start': 1,
'name': 'flowIn',
'type': 'number'
}]
}, {
'tfOpName': 'TensorArrayCloseV3',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorArrayId',
'type': 'tensor'
}]
}, {
'tfOpName': 'StatelessIf',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'cond',
'type': 'tensor'
}, {
'start': 1,
'end': 0,
'name': 'args',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'then_branch',
'name': 'thenBranch',
'type': 'func'
}, {
'tfName': 'else_branch',
'name': 'elseBranch',
'type': 'func'
}]
}, {
'tfOpName': 'If',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'cond',
'type': 'tensor'
}, {
'start': 1,
'end': 0,
'name': 'args',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'then_branch',
'name': 'thenBranch',
'type': 'func'
}, {
'tfName': 'else_branch',
'name': 'elseBranch',
'type': 'func'
}]
}, {
'tfOpName': 'StatelessWhile',
'category': 'control',
'inputs': [{
'start': 0,
'end': 0,
'name': 'args',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'cond',
'name': 'cond',
'type': 'func'
}, {
'tfName': 'body',
'name': 'body',
'type': 'func'
}]
}, {
'tfOpName': 'While',
'category': 'control',
'inputs': [{
'start': 0,
'end': 0,
'name': 'args',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'cond',
'name': 'cond',
'type': 'func'
}, {
'tfName': 'body',
'name': 'body',
'type': 'func'
}]
}, {
'tfOpName': 'TensorListScatter',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensor',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'number[]'
}, {
'start': 2,
'name': 'elementShape',
'type': 'shape'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListScatterV2',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensor',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'number[]'
}, {
'start': 2,
'name': 'elementShape',
'type': 'shape'
}, {
'start': 3,
'name': 'numElements',
'type': 'number'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListGather',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorListId',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'number[]'
}, {
'start': 2,
'name': 'elementShape',
'type': 'shape'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListGetItem',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorListId',
'type': 'tensor'
}, {
'start': 1,
'name': 'index',
'type': 'number'
}, {
'start': 2,
'name': 'elementShape',
'type': 'shape'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListSetItem',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorListId',
'type': 'tensor'
}, {
'start': 1,
'name': 'index',
'type': 'number'
}, {
'start': 2,
'name': 'tensor',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListReserve',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'elementShape',
'type': 'shape'
}, {
'start': 1,
'name': 'numElements',
'type': 'number'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListFromTensor',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensor',
'type': 'tensor'
}, {
'start': 1,
'name': 'elementShape',
'type': 'shape'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListStack',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorListId',
'type': 'tensor'
}, {
'start': 1,
'name': 'elementShape',
'type': 'shape'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}, {
'tfName': 'num_elements',
'name': 'numElements',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListSplit',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensor',
'type': 'tensor'
}, {
'start': 1,
'name': 'elementShape',
'type': 'shape'
}, {
'start': 2,
'name': 'lengths',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListConcat',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorListId',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'element_shape',
'name': 'elementShape',
'type': 'shape'
}, {
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListPopBack',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorListId',
'type': 'tensor'
}, {
'start': 1,
'name': 'elementShape',
'type': 'shape'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'TensorListPushBack',
'category': 'control',
'inputs': [{
'start': 0,
'name': 'tensorListId',
'type': 'tensor'
}, {
'start': 1,
'name': 'tensor',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'element_dtype',
'name': 'elementDType',
'type': 'dtype'
}]
}];
var control = {
__proto__: null,
json: json$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$3 = [{
'tfOpName': 'AvgPool',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'notSupported': true
}, {
'tfName': 'ksize',
'name': 'kernelSize',
'type': 'number[]'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'MaxPool',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'notSupported': true
}, {
'tfName': 'ksize',
'name': 'kernelSize',
'type': 'number[]'
}, {
'tfName': 'explicit_paddings',
'name': 'explicitPaddings',
'type': 'number[]',
'defaultValue': [],
'notSupported': true
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'MaxPoolWithArgmax',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'ksize',
'name': 'kernelSize',
'type': 'number[]'
}, {
'tfName': 'include_batch_in_index',
'name': 'includeBatchInIndex',
'type': 'bool'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'AvgPool3D',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'notSupported': true
}, {
'tfName': 'ksize',
'name': 'kernelSize',
'type': 'number[]'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'MaxPool3D',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'notSupported': true
}, {
'tfName': 'ksize',
'name': 'kernelSize',
'type': 'number[]'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Conv1D',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'filter',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'stride',
'name': 'stride',
'type': 'number'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'defaultValue': 'NWC'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'dilation',
'name': 'dilation',
'type': 'number',
'defaultValue': 1
}]
}, {
'tfOpName': 'Conv2D',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'filter',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'useCudnnOnGpu',
'name': 'useCudnnOnGpu',
'type': 'bool'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'defaultValue': 'NHWC'
}, {
'tfName': 'explicit_paddings',
'name': 'explicitPaddings',
'type': 'number[]',
'defaultValue': []
}, {
'tfName': 'dilations',
'name': 'dilations',
'type': 'number[]'
}]
}, {
'tfOpName': '_FusedConv2D',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'filter',
'type': 'tensor'
}, {
'start': 2,
'end': 0,
'name': 'args',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'num_args',
'name': 'numArgs',
'type': 'number'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'explicit_paddings',
'name': 'explicitPaddings',
'type': 'number[]',
'defaultValue': []
}, {
'tfName': 'use_cudnn_on_gpu',
'name': 'useCudnnOnGpu',
'type': 'bool',
'defaultValue': true
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'defaultValue': 'NHWC'
}, {
'tfName': 'dilations',
'name': 'dilations',
'type': 'number[]',
'defaultValue': [1, 1, 1, 1]
}, {
'tfName': 'fused_ops',
'name': 'fusedOps',
'type': 'string[]',
'defaultValue': []
}, {
'tfName': 'epsilon',
'name': 'epsilon',
'type': 'number',
'defaultValue': 0.0001
}, {
'tfName': 'leakyrelu_alpha',
'name': 'leakyreluAlpha',
'type': 'number'
}]
}, {
'tfOpName': 'Conv2DBackpropInput',
'category': 'convolution',
'inputs': [{
'start': 2,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'filter',
'type': 'tensor'
}, {
'start': 0,
'name': 'outputShape',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'notSupported': true
}, {
'tfName': 'explicit_paddings',
'name': 'explicitPaddings',
'type': 'number[]',
'defaultValue': []
}, {
'tfName': 'dilations',
'name': 'dilations',
'type': 'number[]',
'notSupported': true
}]
}, {
'tfOpName': 'DepthwiseConv2d',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'input',
'type': 'tensor'
}, {
'start': 1,
'name': 'filter',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'defaultValue': 'NHWC'
}, {
'tfName': 'explicit_paddings',
'name': 'explicitPaddings',
'type': 'number[]',
'defaultValue': []
}, {
'tfName': 'dilations',
'name': 'dilations',
'type': 'number[]'
}]
}, {
'tfOpName': 'DepthwiseConv2dNative',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'input',
'type': 'tensor'
}, {
'start': 1,
'name': 'filter',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'defaultValue': 'NHWC'
}, {
'tfName': 'explicit_paddings',
'name': 'explicitPaddings',
'type': 'number[]',
'defaultValue': []
}, {
'tfName': 'dilations',
'name': 'dilations',
'type': 'number[]'
}]
}, {
'tfOpName': 'FusedDepthwiseConv2dNative',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'filter',
'type': 'tensor'
}, {
'start': 2,
'end': 0,
'name': 'args',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'num_args',
'name': 'numArgs',
'type': 'number'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'defaultValue': 'NHWC'
}, {
'tfName': 'dilations',
'name': 'dilations',
'type': 'number[]',
'defaultValue': [1, 1, 1, 1]
}, {
'tfName': 'fused_ops',
'name': 'fusedOps',
'type': 'string[]',
'defaultValue': []
}, {
'tfName': 'explicit_paddings',
'name': 'explicitPaddings',
'type': 'number[]',
'defaultValue': []
}]
}, {
'tfOpName': 'Conv3D',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'filter',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'defaultValue': 'NHWC'
}, {
'tfName': 'dilations',
'name': 'dilations',
'type': 'number[]'
}]
}, {
'tfOpName': 'Dilation2D',
'category': 'convolution',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'filter',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'strides',
'name': 'strides',
'type': 'number[]'
}, {
'tfName': 'rates',
'name': 'dilations',
'type': 'number[]'
}, {
'tfName': 'padding',
'name': 'pad',
'type': 'string'
}]
}];
var convolution = {
__proto__: null,
json: json$3
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$4 = [{
'tfOpName': 'Fill',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'shape',
'type': 'number[]'
}, {
'start': 1,
'name': 'value',
'type': 'number'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'LinSpace',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'start',
'type': 'number'
}, {
'start': 1,
'name': 'stop',
'type': 'number'
}, {
'start': 2,
'name': 'num',
'type': 'number'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'OneHot',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'indices',
'type': 'tensor'
}, {
'start': 1,
'name': 'depth',
'type': 'number'
}, {
'start': 2,
'name': 'onValue',
'type': 'number',
'defaultValue': 1
}, {
'start': 3,
'name': 'offValue',
'type': 'number',
'defaultValue': 0
}],
'attrs': [{
'tfName': 'axis',
'name': 'axis',
'type': 'number',
'notSupported': true
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Ones',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'shape',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'OnesLike',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'dtype',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'RandomUniform',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'shape',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'minval',
'name': 'minval',
'type': 'number',
'defaultValue': 0
}, {
'tfName': 'maxval',
'name': 'maxval',
'type': 'number',
'defaultValue': 1
}, {
'tfName': 'dtype',
'name': 'dtype',
'type': 'dtype'
}, {
'tfName': 'seed',
'name': 'seed',
'type': 'number',
'defaultValue': 0
}, {
'tfName': 'seed2',
'name': 'seed2',
'type': 'number',
'defaultValue': 0,
'notSupported': true
}, {
'tfName': 'T',
'name': 'T',
'type': 'number',
'notSupported': true
}]
}, {
'tfOpName': 'Range',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'start',
'type': 'number'
}, {
'start': 1,
'name': 'stop',
'type': 'number'
}, {
'start': 2,
'name': 'step',
'type': 'number',
'defaultValue': 0
}],
'attrs': [{
'tfName': 'Tidx',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'TruncatedNormal',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'shape',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'means',
'name': 'mean',
'type': 'number',
'defaultValue': 0.0
}, {
'tfName': 'stddev',
'name': 'stdDev',
'type': 'number',
'defaultValue': 1.0
}, {
'tfName': 'seed',
'name': 'seed',
'type': 'number'
}, {
'tfName': 'seed2',
'name': 'seed2',
'type': 'number',
'defaultValue': 0,
'notSupported': true
}, {
'tfName': 'dtype',
'name': 'dtype',
'type': 'dtype'
}, {
'tfName': 'T',
'name': 'T',
'type': 'number',
'notSupported': true
}]
}, {
'tfOpName': 'Zeros',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'shape',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'ZerosLike',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'Multinomial',
'category': 'creation',
'inputs': [{
'start': 0,
'name': 'logits',
'type': 'tensor'
}, {
'start': 1,
'name': 'numSamples',
'type': 'number'
}],
'attrs': [{
'tfName': 'seed',
'name': 'seed',
'type': 'number'
}, {
'tfName': 'seed2',
'name': 'seed2',
'type': 'number'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype'
}, {
'tfName': 'output_dtype',
'name': 'output_dtype',
'type': 'dtype'
}]
}];
var creation = {
__proto__: null,
json: json$4
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$5 = [{
'tfOpName': 'NonMaxSuppressionV2',
'category': 'dynamic',
'inputs': [{
'start': 0,
'name': 'boxes',
'type': 'tensor'
}, {
'start': 1,
'name': 'scores',
'type': 'tensor'
}, {
'start': 2,
'name': 'maxOutputSize',
'type': 'number'
}, {
'start': 3,
'name': 'iouThreshold',
'type': 'number'
}]
}, {
'tfOpName': 'NonMaxSuppressionV3',
'category': 'dynamic',
'inputs': [{
'start': 0,
'name': 'boxes',
'type': 'tensor'
}, {
'start': 1,
'name': 'scores',
'type': 'tensor'
}, {
'start': 2,
'name': 'maxOutputSize',
'type': 'number'
}, {
'start': 3,
'name': 'iouThreshold',
'type': 'number'
}, {
'start': 4,
'name': 'scoreThreshold',
'type': 'number'
}]
}, {
'tfOpName': 'NonMaxSuppressionV4',
'category': 'dynamic',
'inputs': [{
'start': 0,
'name': 'boxes',
'type': 'tensor'
}, {
'start': 1,
'name': 'scores',
'type': 'tensor'
}, {
'start': 2,
'name': 'maxOutputSize',
'type': 'number'
}, {
'start': 3,
'name': 'iouThreshold',
'type': 'number'
}, {
'start': 4,
'name': 'scoreThreshold',
'type': 'number'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'T_threshold',
'name': 'threshold',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'pad_to_max_output_size',
'name': 'padToMaxOutputSize',
'type': 'bool'
}]
}, {
'tfOpName': 'NonMaxSuppressionV5',
'category': 'dynamic',
'inputs': [{
'start': 0,
'name': 'boxes',
'type': 'tensor'
}, {
'start': 1,
'name': 'scores',
'type': 'tensor'
}, {
'start': 2,
'name': 'maxOutputSize',
'type': 'number'
}, {
'start': 3,
'name': 'iouThreshold',
'type': 'number'
}, {
'start': 4,
'name': 'scoreThreshold',
'type': 'number'
}, {
'start': 5,
'name': 'softNmsSigma',
'type': 'number'
}]
}, {
'tfOpName': 'Where',
'category': 'dynamic',
'inputs': [{
'start': 0,
'name': 'condition',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'ListDiff',
'category': 'dynamic',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'y',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}];
var dynamic = {
__proto__: null,
json: json$5
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$6 = [{
'tfOpName': 'TopKV2',
'category': 'evaluation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'k',
'type': 'number'
}],
'attrs': [{
'tfName': 'sorted',
'name': 'sorted',
'type': 'bool'
}]
}, {
'tfOpName': 'Unique',
'category': 'evaluation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'UniqueV2',
'category': 'evaluation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number'
}]
}];
var evaluation = {
__proto__: null,
json: json$6
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$7 = [{
'tfOpName': 'PlaceholderWithDefault',
'category': 'graph',
'inputs': [{
'start': 0,
'name': 'default',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'shape',
'name': 'shape',
'type': 'shape'
}, {
'tfName': 'dtype',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'Placeholder',
'category': 'graph',
'attrs': [{
'tfName': 'shape',
'name': 'shape',
'type': 'shape'
}, {
'tfName': 'dtype',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'Const',
'category': 'graph'
}, {
'tfOpName': 'Identity',
'category': 'graph',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'IdentityN',
'category': 'graph',
'inputs': [{
'start': 0,
'end': 0,
'name': 'x',
'type': 'tensors'
}]
}, {
'tfOpName': 'Snapshot',
'category': 'graph',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'Rank',
'category': 'graph',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'Size',
'category': 'graph',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'Shape',
'category': 'graph',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'ShapeN',
'category': 'graph',
'inputs': [{
'start': 0,
'end': 0,
'name': 'x',
'type': 'tensors'
}]
}, {
'tfOpName': 'Print',
'category': 'graph',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'data',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'message',
'name': 'message',
'type': 'string'
}, {
'tfName': 'first_n',
'name': 'firstN',
'type': 'number',
'notSupported': true
}, {
'tfName': 'summarize',
'name': 'summarize',
'type': 'number',
'defaultValue': 3
}]
}, {
'tfOpName': 'NoOp',
'category': 'graph',
'inputs': []
}, {
'tfOpName': 'StopGradient',
'category': 'graph',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'FakeQuantWithMinMaxVars',
'category': 'graph',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'min',
'name': 'min',
'type': 'number'
}, {
'tfName': 'max',
'name': 'max',
'type': 'number'
}]
}];
var graph = {
__proto__: null,
json: json$7
};
var json$8 = [{
'tfOpName': 'HashTable',
'category': 'hash_table',
'inputs': [],
'attrs': [{
'tfName': 'shared_name',
'name': 'sharedName',
'type': 'string'
}, {
'tfName': 'use_node_name_sharing',
'name': 'useNodeNameSharing',
'type': 'bool'
}, {
'tfName': 'key_dtype',
'name': 'keyDType',
'type': 'dtype'
}, {
'tfName': 'value_dtype',
'name': 'valueDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'HashTableV2',
'category': 'hash_table',
'inputs': [],
'attrs': [{
'tfName': 'shared_name',
'name': 'sharedName',
'type': 'string'
}, {
'tfName': 'use_node_name_sharing',
'name': 'useNodeNameSharing',
'type': 'bool'
}, {
'tfName': 'key_dtype',
'name': 'keyDType',
'type': 'dtype'
}, {
'tfName': 'value_dtype',
'name': 'valueDType',
'type': 'dtype'
}]
}, {
'tfOpName': 'LookupTableImport',
'category': 'hash_table',
'inputs': [{
'start': 0,
'name': 'tableHandle',
'type': 'tensor'
}, {
'start': 1,
'name': 'keys',
'type': 'tensor'
}, {
'start': 2,
'name': 'values',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'Tin',
'name': 'tIn',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'Tout',
'name': 'tOut',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'LookupTableImportV2',
'category': 'hash_table',
'inputs': [{
'start': 0,
'name': 'tableHandle',
'type': 'tensor'
}, {
'start': 1,
'name': 'keys',
'type': 'tensor'
}, {
'start': 2,
'name': 'values',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'Tin',
'name': 'tIn',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'Tout',
'name': 'tOut',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'LookupTableFind',
'category': 'hash_table',
'inputs': [{
'start': 0,
'name': 'tableHandle',
'type': 'tensor'
}, {
'start': 1,
'name': 'keys',
'type': 'tensor'
}, {
'start': 2,
'name': 'defaultValue',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'Tin',
'name': 'tIn',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'Tout',
'name': 'tOut',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'LookupTableFindV2',
'category': 'hash_table',
'inputs': [{
'start': 0,
'name': 'tableHandle',
'type': 'tensor'
}, {
'start': 1,
'name': 'keys',
'type': 'tensor'
}, {
'start': 2,
'name': 'defaultValue',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'Tin',
'name': 'tIn',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'Tout',
'name': 'tOut',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'LookupTableSize',
'category': 'hash_table',
'inputs': [{
'start': 0,
'name': 'tableHandle',
'type': 'tensor'
}]
}, {
'tfOpName': 'LookupTableSizeV2',
'category': 'hash_table',
'inputs': [{
'start': 0,
'name': 'tableHandle',
'type': 'tensor'
}]
}];
var hashTable = {
__proto__: null,
json: json$8
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$9 = [{
'tfOpName': 'ResizeBilinear',
'category': 'image',
'inputs': [{
'start': 0,
'name': 'images',
'type': 'tensor'
}, {
'start': 1,
'name': 'size',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'align_corners',
'name': 'alignCorners',
'type': 'bool'
}, {
'tfName': 'half_pixel_centers',
'name': 'halfPixelCenters',
'type': 'bool'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'ResizeNearestNeighbor',
'category': 'image',
'inputs': [{
'start': 0,
'name': 'images',
'type': 'tensor'
}, {
'start': 1,
'name': 'size',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'align_corners',
'name': 'alignCorners',
'type': 'bool'
}, {
'tfName': 'half_pixel_centers',
'name': 'halfPixelCenters',
'type': 'bool'
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'CropAndResize',
'category': 'image',
'inputs': [{
'start': 0,
'name': 'image',
'type': 'tensor'
}, {
'start': 1,
'name': 'boxes',
'type': 'tensor'
}, {
'start': 2,
'name': 'boxInd',
'type': 'tensor'
}, {
'start': 3,
'name': 'cropSize',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'method',
'name': 'method',
'type': 'string'
}, {
'tfName': 'extrapolation_value',
'name': 'extrapolationValue',
'type': 'number'
}]
}];
var image$1 = {
__proto__: null,
json: json$9
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$a = [{
'tfOpName': 'Equal',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'NotEqual',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Greater',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'GreaterEqual',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Less',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'LessEqual',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'LogicalAnd',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'LogicalNot',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'LogicalOr',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Select',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'condition',
'type': 'tensor'
}, {
'start': 1,
'name': 'a',
'type': 'tensor'
}, {
'start': 2,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'SelectV2',
'category': 'logical',
'inputs': [{
'start': 0,
'name': 'condition',
'type': 'tensor'
}, {
'start': 1,
'name': 'a',
'type': 'tensor'
}, {
'start': 2,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}];
var logical = {
__proto__: null,
json: json$a
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$b = [{
'tfOpName': '_FusedMatMul',
'category': 'matrices',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}, {
'start': 2,
'end': 0,
'name': 'args',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'num_args',
'name': 'numArgs',
'type': 'number'
}, {
'tfName': 'fused_ops',
'name': 'fusedOps',
'type': 'string[]',
'defaultValue': []
}, {
'tfName': 'epsilon',
'name': 'epsilon',
'type': 'number',
'defaultValue': 0.0001
}, {
'tfName': 'transpose_a',
'name': 'transposeA',
'type': 'bool',
'defaultValue': false
}, {
'tfName': 'transpose_b',
'name': 'transposeB',
'type': 'bool',
'defaultValue': false
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'MatMul',
'category': 'matrices',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'transpose_a',
'name': 'transposeA',
'type': 'bool',
'defaultValue': false
}, {
'tfName': 'transpose_b',
'name': 'transposeB',
'type': 'bool',
'defaultValue': false
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'BatchMatMul',
'category': 'matrices',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'adj_x',
'name': 'transposeA',
'type': 'bool',
'defaultValue': false
}, {
'tfName': 'adj_y',
'name': 'transposeB',
'type': 'bool',
'defaultValue': false
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'BatchMatMulV2',
'category': 'matrices',
'inputs': [{
'start': 0,
'name': 'a',
'type': 'tensor'
}, {
'start': 1,
'name': 'b',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'adj_x',
'name': 'transposeA',
'type': 'bool',
'defaultValue': false
}, {
'tfName': 'adj_y',
'name': 'transposeB',
'type': 'bool',
'defaultValue': false
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Transpose',
'category': 'matrices',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'perm',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'Einsum',
'category': 'matrices',
'inputs': [{
'start': 0,
'end': 0,
'name': 'tensors',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'equation',
'name': 'equation',
'type': 'string'
}, {
'tfName': 'N',
'name': 'n',
'type': 'number',
'defaultValue': 2
}, {
'tfName': 'T',
'name': 'dtype',
'type': 'dtype'
}]
}];
var matrices = {
__proto__: null,
json: json$b
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$c = [{
'tfOpName': 'FusedBatchNorm',
'category': 'normalization',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'scale',
'type': 'tensor'
}, {
'start': 2,
'name': 'offset',
'type': 'tensor'
}, {
'start': 3,
'name': 'mean',
'type': 'tensor'
}, {
'start': 4,
'name': 'variance',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'epsilon',
'name': 'epsilon',
'type': 'number',
'defaultValue': 0.001
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'notSupported': true
}]
}, {
'tfOpName': 'FusedBatchNormV2',
'category': 'normalization',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'scale',
'type': 'tensor'
}, {
'start': 2,
'name': 'offset',
'type': 'tensor'
}, {
'start': 3,
'name': 'mean',
'type': 'tensor'
}, {
'start': 4,
'name': 'variance',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'epsilon',
'name': 'epsilon',
'type': 'number',
'defaultValue': 0.001
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'notSupported': true
}]
}, {
'tfOpName': 'FusedBatchNormV3',
'category': 'normalization',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'scale',
'type': 'tensor'
}, {
'start': 2,
'name': 'offset',
'type': 'tensor'
}, {
'start': 3,
'name': 'mean',
'type': 'tensor'
}, {
'start': 4,
'name': 'variance',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'epsilon',
'name': 'epsilon',
'type': 'number',
'defaultValue': 0.001
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string',
'notSupported': true
}]
}, {
'tfOpName': 'LRN',
'category': 'normalization',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'depth_radius',
'name': 'radius',
'type': 'number',
'defaultValue': 5
}, {
'tfName': 'bias',
'name': 'bias',
'type': 'number',
'defaultValue': 1.0
}, {
'tfName': 'alpha',
'name': 'alpha',
'type': 'number',
'defaultValue': 1.0
}, {
'tfName': 'beta',
'name': 'beta',
'type': 'number',
'defaultValue': 0.5
}]
}, {
'tfOpName': 'Softmax',
'category': 'normalization',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'LogSoftmax',
'category': 'normalization',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'SparseToDense',
'category': 'normalization',
'inputs': [{
'start': 0,
'name': 'sparseIndices',
'type': 'tensor'
}, {
'start': 1,
'name': 'outputShape',
'type': 'number[]'
}, {
'start': 2,
'name': 'sparseValues',
'type': 'tensor'
}, {
'start': 3,
'name': 'defaultValue',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'validate_indices',
'name': 'validateIndices',
'type': 'bool',
'defaultValue': true,
'notSupported': true
}]
}];
var normalization = {
__proto__: null,
json: json$c
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$d = [{
'tfOpName': 'Bincount',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'size',
'type': 'number'
}, {
'start': 2,
'name': 'weights',
'type': 'tensor'
}]
}, {
'tfOpName': 'DenseBincount',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'size',
'type': 'number'
}, {
'start': 2,
'name': 'weights',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'binary_output',
'name': 'binaryOutput',
'type': 'bool'
}]
}, {
'tfOpName': 'Max',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'keep_dims',
'name': 'keepDims',
'type': 'bool'
}]
}, {
'tfOpName': 'Mean',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'keep_dims',
'name': 'keepDims',
'type': 'bool'
}]
}, {
'tfOpName': 'Min',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'keep_dims',
'name': 'keepDims',
'type': 'bool'
}]
}, {
'tfOpName': 'Sum',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'keep_dims',
'name': 'keepDims',
'type': 'bool'
}]
}, {
'tfOpName': 'All',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'keep_dims',
'name': 'keepDims',
'type': 'bool'
}]
}, {
'tfOpName': 'Any',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'keep_dims',
'name': 'keepDims',
'type': 'bool'
}]
}, {
'tfOpName': 'ArgMax',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number'
}]
}, {
'tfOpName': 'ArgMin',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number'
}]
}, {
'tfOpName': 'Prod',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'keep_dims',
'name': 'keepDims',
'type': 'bool'
}]
}, {
'tfOpName': 'Cumsum',
'category': 'reduction',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number'
}],
'attrs': [{
'tfName': 'exclusive',
'name': 'exclusive',
'type': 'bool'
}, {
'tfName': 'reverse',
'name': 'reverse',
'type': 'bool'
}]
}];
var reduction = {
__proto__: null,
json: json$d
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$e = [{
'tfOpName': 'ConcatV2',
'category': 'slice_join',
'inputs': [{
'start': 0,
'end': -1,
'name': 'tensors',
'type': 'tensors'
}, {
'start': -1,
'name': 'axis',
'type': 'number'
}],
'attrs': [{
'tfName': 'N',
'name': 'n',
'type': 'number',
'defaultValue': 2
}]
}, {
'tfOpName': 'Concat',
'category': 'slice_join',
'inputs': [{
'start': 1,
'end': 0,
'name': 'tensors',
'type': 'tensors'
}, {
'start': 0,
'name': 'axis',
'type': 'number'
}],
'attrs': [{
'tfName': 'N',
'name': 'n',
'type': 'number',
'defaultValue': 2
}]
}, {
'tfOpName': 'GatherV2',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'tensor'
}, {
'start': 2,
'name': 'axis',
'type': 'number',
'defaultValue': 0
}],
'attrs': [{
'tfName': 'batch_dims',
'name': 'batchDims',
'type': 'number',
'defaultValue': 0
}]
}, {
'tfOpName': 'Gather',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'validate_indices',
'name': 'validateIndices',
'type': 'bool',
'notSupported': true
}]
}, {
'tfOpName': 'Reverse',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'dims',
'type': 'bool[]'
}]
}, {
'tfOpName': 'ReverseV2',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number[]'
}]
}, {
'tfOpName': 'Slice',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'begin',
'type': 'number[]'
}, {
'start': 2,
'name': 'size',
'type': 'number[]'
}]
}, {
'tfOpName': 'StridedSlice',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'begin',
'type': 'number[]'
}, {
'start': 2,
'name': 'end',
'type': 'number[]'
}, {
'start': 3,
'name': 'strides',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'begin_mask',
'name': 'beginMask',
'type': 'number',
'defaultValue': 0
}, {
'tfName': 'end_mask',
'name': 'endMask',
'type': 'number',
'defaultValue': 0
}, {
'tfName': 'new_axis_mask',
'name': 'newAxisMask',
'type': 'number',
'defaultValue': 0
}, {
'tfName': 'ellipsis_mask',
'name': 'ellipsisMask',
'type': 'number',
'defaultValue': 0
}, {
'tfName': 'shrink_axis_mask',
'name': 'shrinkAxisMask',
'type': 'number',
'defaultValue': 0
}]
}, {
'tfOpName': 'Pack',
'category': 'slice_join',
'inputs': [{
'start': 0,
'end': 0,
'name': 'tensors',
'type': 'tensors'
}],
'attrs': [{
'tfName': 'axis',
'name': 'axis',
'type': 'number',
'defaultValue': 0
}]
}, {
'tfOpName': 'Unpack',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'tensor',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'axis',
'name': 'axis',
'type': 'number',
'defaultValue': 0
}, {
'tfName': 'num',
'name': 'num',
'type': 'number',
'defaultValue': 0,
'notSupported': true
}]
}, {
'tfOpName': 'Tile',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'reps',
'type': 'number[]'
}]
}, {
'tfOpName': 'Split',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'axis',
'type': 'number',
'defaultValue': 0
}, {
'start': 1,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'num_split',
'name': 'numOrSizeSplits',
'type': 'number',
'defaultValue': 1
}]
}, {
'tfOpName': 'SplitV',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'numOrSizeSplits',
'type': 'number[]'
}, {
'start': 2,
'name': 'axis',
'type': 'number',
'defaultValue': 0
}]
}, {
'tfOpName': 'ScatterNd',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'indices',
'type': 'tensor'
}, {
'start': 1,
'name': 'values',
'type': 'tensor'
}, {
'start': 2,
'name': 'shape',
'type': 'number[]'
}]
}, {
'tfOpName': 'GatherNd',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'tensor'
}]
}, {
'tfOpName': 'SparseToDense',
'category': 'slice_join',
'inputs': [{
'start': 0,
'name': 'sparseIndices',
'type': 'tensor'
}, {
'start': 1,
'name': 'outputShape',
'type': 'number[]'
}, {
'start': 2,
'name': 'sparseValues',
'type': 'tensor'
}, {
'start': 3,
'name': 'defaultValue',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'validate_indices',
'name': 'validateIndices',
'type': 'bool',
'defaultValue': false,
'notSupported': true
}]
}];
var sliceJoin = {
__proto__: null,
json: json$e
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$f = [{
'tfOpName': 'SparseFillEmptyRows',
'category': 'sparse',
'inputs': [{
'start': 0,
'name': 'indices',
'type': 'tensor'
}, {
'start': 1,
'name': 'values',
'type': 'tensor'
}, {
'start': 2,
'name': 'denseShape',
'type': 'tensor'
}, {
'start': 3,
'name': 'defaultValue',
'type': 'tensor'
}]
}, {
'tfOpName': 'SparseReshape',
'category': 'sparse',
'inputs': [{
'start': 0,
'name': 'inputIndices',
'type': 'tensor'
}, {
'start': 1,
'name': 'inputShape',
'type': 'tensor'
}, {
'start': 2,
'name': 'newShape',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'T',
'name': 'dtype',
'type': 'dtype',
'notSupported': true
}]
}, {
'tfOpName': 'SparseSegmentMean',
'category': 'sparse',
'inputs': [{
'start': 0,
'name': 'data',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'tensor'
}, {
'start': 2,
'name': 'segmentIds',
'type': 'tensor'
}]
}, {
'tfOpName': 'SparseSegmentSum',
'category': 'sparse',
'inputs': [{
'start': 0,
'name': 'data',
'type': 'tensor'
}, {
'start': 1,
'name': 'indices',
'type': 'tensor'
}, {
'start': 2,
'name': 'segmentIds',
'type': 'tensor'
}]
}];
var sparse$1 = {
__proto__: null,
json: json$f
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$g = [{
'tfOpName': 'FFT',
'category': 'spectral',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'IFFT',
'category': 'spectral',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}]
}, {
'tfOpName': 'RFFT',
'category': 'spectral',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'fft_length',
'type': 'number',
'notSupported': true
}]
}, {
'tfOpName': 'IRFFT',
'category': 'spectral',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'fft_length',
'type': 'number',
'notSupported': true
}]
}];
var spectral$1 = {
__proto__: null,
json: json$g
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$h = [{
'tfOpName': 'StringNGrams',
'category': 'string',
'inputs': [{
'start': 0,
'name': 'data',
'type': 'tensor'
}, {
'start': 1,
'name': 'dataSplits',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'separator',
'name': 'separator',
'type': 'string'
}, {
'tfName': 'ngram_widths',
'name': 'nGramWidths',
'type': 'number[]'
}, {
'tfName': 'left_pad',
'name': 'leftPad',
'type': 'string'
}, {
'tfName': 'right_pad',
'name': 'rightPad',
'type': 'string'
}, {
'tfName': 'pad_width',
'name': 'padWidth',
'type': 'number'
}, {
'tfName': 'preserve_short_sequences',
'name': 'preserveShortSequences',
'type': 'bool'
}],
'outputs': ['ngrams', 'ngrams_splits']
}, {
'tfOpName': 'StringSplit',
'category': 'string',
'inputs': [{
'start': 0,
'name': 'input',
'type': 'tensor'
}, {
'start': 1,
'name': 'delimiter',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'skip_empty',
'name': 'skipEmpty',
'type': 'bool'
}],
'outputs': ['indices', 'values', 'shape']
}, {
'tfOpName': 'StringToHashBucketFast',
'category': 'string',
'inputs': [{
'start': 0,
'name': 'input',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'num_buckets',
'name': 'numBuckets',
'type': 'number'
}]
}];
var string$1 = {
__proto__: null,
json: json$h
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var json$i = [{
'tfOpName': 'Cast',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'SrcT',
'name': 'sdtype',
'type': 'dtype',
'notSupported': true
}, {
'tfName': 'DstT',
'name': 'dtype',
'type': 'dtype'
}]
}, {
'tfOpName': 'ExpandDims',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'axis',
'type': 'number'
}]
}, {
'tfOpName': 'MirrorPad',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'padding',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'mode',
'name': 'mode',
'type': 'string'
}]
}, {
'tfOpName': 'Pad',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'padding',
'type': 'number[]'
}],
'attrs': [{
'tfName': 'constant_value',
'name': 'constantValue',
'type': 'number',
'defaultValue': 0
}]
}, {
'tfOpName': 'PadV2',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'padding',
'type': 'number[]'
}, {
'start': 2,
'name': 'constantValue',
'type': 'number',
'defaultValue': 0
}]
}, {
'tfOpName': 'Reshape',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'shape',
'type': 'number[]'
}]
}, {
'tfOpName': 'Squeeze',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'axis',
'tfDeprecatedName': 'squeeze_dims',
'name': 'axis',
'type': 'number[]'
}]
}, {
'tfOpName': 'SpaceToBatchND',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'blockShape',
'type': 'number[]'
}, {
'start': 2,
'name': 'paddings',
'type': 'number[]'
}]
}, {
'tfOpName': 'BatchToSpaceND',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'blockShape',
'type': 'number[]'
}, {
'start': 2,
'name': 'crops',
'type': 'number[]'
}]
}, {
'tfOpName': 'DepthToSpace',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}],
'attrs': [{
'tfName': 'block_size',
'name': 'blockSize',
'type': 'number'
}, {
'tfName': 'data_format',
'name': 'dataFormat',
'type': 'string'
}]
}, {
'tfOpName': 'BroadcastTo',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 'x',
'type': 'tensor'
}, {
'start': 1,
'name': 'shape',
'type': 'number[]'
}],
'attrs': []
}, {
'tfOpName': 'BroadcastArgs',
'category': 'transformation',
'inputs': [{
'start': 0,
'name': 's0',
'type': 'tensor'
}, {
'start': 1,
'name': 's1',
'type': 'tensor'
}],
'attrs': []
}];
var transformation = {
__proto__: null,
json: json$i
};
var OperationMapper = /*#__PURE__*/function () {
// Loads the op mapping from the JSON file.
function OperationMapper() {
var _ref;
var ops = [arithmetic, basicMath, control, convolution, creation, dynamic, evaluation, graph, hashTable, image$1, logical, matrices, normalization, reduction, sliceJoin, sparse$1, spectral$1, string$1, transformation];
var mappersJson = (_ref = []).concat.apply(_ref, ops.map(function (op) {
return op.json;
}));
this.opMappers = mappersJson.reduce(function (map, mapper) {
map[mapper.tfOpName] = mapper;
return map;
}, {});
} // Converts the model inference graph from a TensorFlow GraphDef to the
// local representation used by the TensorFlow.js API.
var _proto = OperationMapper.prototype;
_proto.transformGraph = function transformGraph(graph, signature) {
var _this = this;
if (signature === void 0) {
signature = {};
}
var tfNodes = graph.node;
var placeholders = [];
var weights = [];
var initNodes = [];
var nodes = tfNodes.reduce(function (map, node) {
map[node.name] = _this.mapNode(node);
if (node.op.startsWith('Placeholder')) {
placeholders.push(map[node.name]);
} else if (node.op === 'Const') {
weights.push(map[node.name]);
} else if (node.input == null || node.input.length === 0) {
initNodes.push(map[node.name]);
}
return map;
}, {});
var inputs = [];
var outputs = [];
var inputNodeNameToKey = {};
var outputNodeNameToKey = {};
if (signature != null) {
inputNodeNameToKey = this.mapSignatureEntries(signature.inputs);
outputNodeNameToKey = this.mapSignatureEntries(signature.outputs);
}
var allNodes = Object.keys(nodes);
allNodes.forEach(function (key) {
var node = nodes[key];
node.inputNames.forEach(function (name, index) {
var _getNodeNameAndIndex = getNodeNameAndIndex(name),
nodeName = _getNodeNameAndIndex[0],
outputName = _getNodeNameAndIndex[2];
var inputNode = nodes[nodeName];
if (inputNode.outputs != null) {
var outputIndex = inputNode.outputs.indexOf(outputName);
if (outputIndex !== -1) {
var inputName = nodeName + ":" + outputIndex; // update the input name to use the mapped output index directly.
node.inputNames[index] = inputName;
}
}
node.inputs.push(inputNode);
inputNode.children.push(node);
});
}); // if the signature has no outputs set, treat any node that has no
// children (i.e. whose output is not consumed) as a graph output.
if (Object.keys(outputNodeNameToKey).length === 0) {
allNodes.forEach(function (key) {
var node = nodes[key];
if (node.children.length === 0) {
outputs.push(node);
}
});
} else {
Object.keys(outputNodeNameToKey).forEach(function (name) {
var _getNodeNameAndIndex2 = getNodeNameAndIndex(name),
nodeName = _getNodeNameAndIndex2[0];
var node = nodes[nodeName];
if (node != null) {
node.signatureKey = outputNodeNameToKey[name];
outputs.push(node);
}
});
}
if (Object.keys(inputNodeNameToKey).length > 0) {
Object.keys(inputNodeNameToKey).forEach(function (name) {
var _getNodeNameAndIndex3 = getNodeNameAndIndex(name),
nodeName = _getNodeNameAndIndex3[0];
var node = nodes[nodeName];
if (node) {
node.signatureKey = inputNodeNameToKey[name];
inputs.push(node);
}
});
} else {
inputs = placeholders;
}
var functions = {};
if (graph.library != null && graph.library.function != null) {
functions = graph.library.function.reduce(function (functions, func) {
functions[func.signature.name] = _this.mapFunction(func);
return functions;
}, {});
}
var result = {
nodes: nodes,
inputs: inputs,
outputs: outputs,
weights: weights,
placeholders: placeholders,
signature: signature,
functions: functions
};
if (initNodes.length > 0) {
result.initNodes = initNodes;
}
return result;
};
_proto.mapSignatureEntries = function mapSignatureEntries(entries) {
return Object.keys(entries || {}).reduce(function (prev, curr) {
prev[entries[curr].name] = curr;
return prev;
}, {});
};
_proto.mapNode = function mapNode(node) {
// Unsupported ops will cause an error at run-time (not parse time), since
// they may not be used by the actual execution subgraph.
var mapper = getRegisteredOp(node.op) || this.opMappers[node.op] || {};
if (node.attr == null) {
node.attr = {};
}
var newNode = {
name: node.name,
op: node.op,
category: mapper.category,
inputNames: (node.input || []).map(function (input) {
return input.startsWith('^') ? input.substr(1) : input;
}),
inputs: [],
children: [],
inputParams: {},
attrParams: {},
rawAttrs: node.attr,
outputs: mapper.outputs
};
if (mapper.inputs != null) {
newNode.inputParams = mapper.inputs.reduce(function (map, param) {
map[param.name] = {
type: param.type,
inputIndexStart: param.start,
inputIndexEnd: param.end
};
return map;
}, {});
}
if (mapper.attrs != null) {
newNode.attrParams = mapper.attrs.reduce(function (map, param) {
var type = param.type;
var value = undefined;
switch (param.type) {
case 'string':
value = getStringParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getStringParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'string[]':
value = getStringArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getStringArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'number':
value = getNumberParam(node.attr, param.tfName, param.defaultValue || 0);
if (value === undefined && !!param.tfDeprecatedName) {
value = getNumberParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'number[]':
value = getNumericArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getNumericArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'bool':
value = getBoolParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getBoolParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'bool[]':
value = getBoolArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getBoolArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'shape':
value = getTensorShapeParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getTensorShapeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'shape[]':
value = getTensorShapeArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getTensorShapeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'dtype':
value = getDtypeParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getDtypeParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'dtype[]':
value = getDtypeArrayParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getDtypeArrayParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'func':
value = getFuncParam(node.attr, param.tfName, param.defaultValue);
if (value === undefined && !!param.tfDeprecatedName) {
value = getFuncParam(node.attr, param.tfDeprecatedName, param.defaultValue);
}
break;
case 'tensor':
case 'tensors':
break;
default:
throw new Error("Unsupported param type: " + param.type + " for op: " + node.op);
}
map[param.name] = {
value: value,
type: type
};
return map;
}, {});
}
return newNode;
} // Maps a TensorFlow FunctionDef to a TFJS graph object.
;
_proto.mapFunction = function mapFunction(functionDef) {
var _this2 = this;
var tfNodes = functionDef.nodeDef;
var placeholders = [];
var weights = [];
var nodes = {};
if (tfNodes != null) {
nodes = tfNodes.reduce(function (map, node) {
map[node.name] = _this2.mapNode(node);
if (node.op === 'Const') {
weights.push(map[node.name]);
}
return map;
}, {});
}
var inputs = [];
var outputs = [];
functionDef.signature.inputArg.forEach(function (arg) {
var _getNodeNameAndIndex4 = getNodeNameAndIndex(arg.name),
nodeName = _getNodeNameAndIndex4[0];
var node = {
name: nodeName,
op: 'Placeholder',
inputs: [],
inputNames: [],
category: 'graph',
inputParams: {},
attrParams: {
dtype: {
value: parseDtypeParam(arg.type),
type: 'dtype'
}
},
children: []
};
node.signatureKey = arg.name;
inputs.push(node);
nodes[nodeName] = node;
});
var allNodes = Object.keys(nodes);
allNodes.forEach(function (key) {
var node = nodes[key];
node.inputNames.forEach(function (name, index) {
var _getNodeNameAndIndex5 = getNodeNameAndIndex(name),
nodeName = _getNodeNameAndIndex5[0],
outputName = _getNodeNameAndIndex5[2];
var inputNode = nodes[nodeName];
if (inputNode.outputs != null) {
var outputIndex = inputNode.outputs.indexOf(outputName);
if (outputIndex !== -1) {
var inputName = nodeName + ":" + outputIndex; // update the input name to use the mapped output index directly.
node.inputNames[index] = inputName;
}
}
node.inputs.push(inputNode);
inputNode.children.push(node);
});
});
var returnNodeMap = functionDef.ret;
functionDef.signature.outputArg.forEach(function (output) {
var _getNodeNameAndIndex6 = getNodeNameAndIndex(returnNodeMap[output.name]),
nodeName = _getNodeNameAndIndex6[0],
index = _getNodeNameAndIndex6[1];
var node = nodes[nodeName];
if (node != null) {
node.defaultOutput = index;
outputs.push(node);
}
});
var signature = this.mapArgsToSignature(functionDef);
return {
nodes: nodes,
inputs: inputs,
outputs: outputs,
weights: weights,
placeholders: placeholders,
signature: signature
};
};
_proto.mapArgsToSignature = function mapArgsToSignature(functionDef) {
var _this3 = this;
return {
methodName: functionDef.signature.name,
inputs: functionDef.signature.inputArg.reduce(function (map, arg) {
map[arg.name] = _this3.mapArgToTensorInfo(arg);
return map;
}, {}),
outputs: functionDef.signature.outputArg.reduce(function (map, arg) {
map[arg.name] = _this3.mapArgToTensorInfo(arg, functionDef.ret);
return map;
}, {})
};
};
_proto.mapArgToTensorInfo = function mapArgToTensorInfo(arg, nameMap) {
var name = arg.name;
if (nameMap != null) {
name = nameMap[name];
}
return {
name: name,
dtype: arg.type
};
};
_createClass(OperationMapper, null, [{
key: "Instance",
// Singleton instance for the mapper
get: function get() {
return this._instance || (this._instance = new this());
}
}]);
return OperationMapper;
}();
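/*
 * A minimal sketch of how the op-mapper JSON above is applied, assuming the
 * module-local `OperationMapper` binding is in scope; the node below is a
 * hypothetical GraphDef-style node, not taken from a real model.
 *
 *   const tfNode = {
 *     name: 'matmul_1',
 *     op: 'MatMul',
 *     input: ['x', 'w'],
 *     attr: {transpose_a: {b: false}, transpose_b: {b: true}}
 *   };
 *   const mapped = OperationMapper.Instance.mapNode(tfNode);
 *   // mapped.inputParams.a -> {type: 'tensor', inputIndexStart: 0, ...}
 *   // mapped.attrParams.transposeB -> {value: true, type: 'bool'}
 *
 * `transformGraph` then wires the mapped nodes together via their
 * `inputNames` and derives the graph inputs/outputs from the signature (or
 * from Placeholder and leaf nodes when no signature is given).
 */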
function decodeBase64(text) {
var global = env().global;
if (typeof global.atob !== 'undefined') {
return global.atob(text);
} else if (typeof Buffer !== 'undefined') {
// `Buffer.from` replaces the deprecated `new Buffer()` constructor.
return Buffer.from(text, 'base64').toString();
} else {
throw new Error('Unable to decode base64 in this environment. ' + 'Missing built-in atob() or Buffer()');
}
}
function parseStringParam(s, keepCase) {
var value = Array.isArray(s) ? String.fromCharCode.apply(null, s) : decodeBase64(s);
return keepCase ? value : value.toLowerCase();
}
function getStringParam(attrs, name, def, keepCase) {
if (keepCase === void 0) {
keepCase = false;
}
var param = attrs[name];
if (param != null) {
return parseStringParam(param.s, keepCase);
}
return def;
}
function getBoolParam(attrs, name, def) {
var param = attrs[name];
return param ? param.b : def;
}
function getNumberParam(attrs, name, def) {
var param = attrs[name] || {};
var value = param['i'] != null ? param['i'] : param['f'] != null ? param['f'] : def;
return typeof value === 'number' ? value : parseInt(value, 10);
}
function parseDtypeParam(value) {
if (typeof value === 'string') {
// tslint:disable-next-line:no-any
value = DataType[value];
}
switch (value) {
case DataType.DT_FLOAT:
return 'float32';
case DataType.DT_INT32:
case DataType.DT_INT64:
case DataType.DT_INT8:
case DataType.DT_UINT8:
return 'int32';
case DataType.DT_BOOL:
return 'bool';
case DataType.DT_DOUBLE:
return 'float32';
case DataType.DT_STRING:
return 'string';
default:
// Unknown dtype error will happen at runtime (instead of parse time),
// since these nodes might not be used by the actual subgraph execution.
return null;
}
}
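/*
 * A minimal sketch of the dtype narrowing performed above, assuming the
 * module-local `DataType` enum is in scope:
 *
 *   parseDtypeParam(DataType.DT_INT64);  // 'int32'
 *   parseDtypeParam(DataType.DT_DOUBLE); // 'float32'
 *
 * 64-bit integer and double attributes are mapped onto the narrower dtypes
 * that TensorFlow.js tensors support.
 */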
function getFuncParam(attrs, name, def) {
var param = attrs[name];
if (param && param.func) {
return param.func.name;
}
return def;
}
function getDtypeParam(attrs, name, def) {
var param = attrs[name];
if (param && param.type) {
return parseDtypeParam(param.type);
}
return def;
}
function getDtypeArrayParam(attrs, name, def) {
var param = attrs[name];
if (param && param.list && param.list.type) {
return param.list.type.map(function (v) {
return parseDtypeParam(v);
});
}
return def;
}
function parseTensorShapeParam(shape) {
if (shape.unknownRank) {
return undefined;
}
if (shape.dim != null) {
return shape.dim.map(function (dim) {
return typeof dim.size === 'number' ? dim.size : parseInt(dim.size, 10);
});
}
return [];
}
function getTensorShapeParam(attrs, name, def) {
var param = attrs[name];
if (param && param.shape) {
return parseTensorShapeParam(param.shape);
}
return def;
}
function getNumericArrayParam(attrs, name, def) {
var param = attrs[name];
if (param) {
return ((param.list.f && param.list.f.length ? param.list.f : param.list.i) || []).map(function (v) {
return typeof v === 'number' ? v : parseInt(v, 10);
});
}
return def;
}
function getStringArrayParam(attrs, name, def, keepCase) {
if (keepCase === void 0) {
keepCase = false;
}
var param = attrs[name];
if (param && param.list && param.list.s) {
return param.list.s.map(function (v) {
return parseStringParam(v, keepCase);
});
}
return def;
}
function getTensorShapeArrayParam(attrs, name, def) {
var param = attrs[name];
if (param && param.list && param.list.shape) {
return param.list.shape.map(function (v) {
return parseTensorShapeParam(v);
});
}
return def;
}
function getBoolArrayParam(attrs, name, def) {
var param = attrs[name];
if (param && param.list && param.list.b) {
return param.list.b;
}
return def;
}
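/*
 * A minimal sketch of how the helpers above decode AttrValue-style objects;
 * the attr map below is hypothetical, with illustrative values only:
 *
 *   const attrs = {
 *     ksize: {list: {i: [1, 2, 2, 1]}},
 *     epsilon: {f: 0.001},
 *     keep_dims: {b: true}
 *   };
 *   getNumericArrayParam(attrs, 'ksize', []);  // [1, 2, 2, 1]
 *   getNumberParam(attrs, 'epsilon', 0);       // 0.001
 *   getBoolParam(attrs, 'keep_dims', false);   // true
 *
 * Scalars may arrive in either the `i` (int) or `f` (float) field, and
 * array-valued attributes arrive nested under `list`.
 */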
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Helper class for looking up the inputs and params of nodes in the model graph.
*/
var NodeValueImpl = /*#__PURE__*/function () {
function NodeValueImpl(node, tensorMap, context) {
var _this = this;
this.node = node;
this.tensorMap = tensorMap;
this.context = context;
this.inputs = [];
this.attrs = {};
this.inputs = node.inputNames.map(function (name) {
return _this.getInput(name);
});
if (node.rawAttrs != null) {
this.attrs = Object.keys(node.rawAttrs).reduce(function (attrs, key) {
attrs[key] = _this.getAttr(key);
return attrs;
}, {});
}
}
/**
* Return the input tensor for the given input name.
* @param name String: name of the input.
*/
var _proto = NodeValueImpl.prototype;
_proto.getInput = function getInput(name) {
return getTensor(name, this.tensorMap, this.context);
}
/**
* Return the decoded value of the named attribute, falling back to
* `defaultValue` if the attribute is not set.
* @param name String: name of the attribute.
*/
;
_proto.getAttr = function getAttr(name, defaultValue) {
var value = this.node.rawAttrs[name];
if (value.tensor != null) {
return getTensor(name, this.tensorMap, this.context);
}
if (value.i != null || value.f != null) {
return getNumberParam(this.node.rawAttrs, name, defaultValue);
}
if (value.s != null) {
return getStringParam(this.node.rawAttrs, name, defaultValue);
}
if (value.b != null) {
return getBoolParam(this.node.rawAttrs, name, defaultValue);
}
if (value.shape != null) {
return getTensorShapeParam(this.node.rawAttrs, name, defaultValue);
}
if (value.type != null) {
return getDtypeParam(this.node.rawAttrs, name, defaultValue);
}
if (value.list != null) {
if (value.list.i != null || value.list.f != null) {
return getNumericArrayParam(this.node.rawAttrs, name, defaultValue);
}
if (value.list.s != null) {
return getStringArrayParam(this.node.rawAttrs, name, defaultValue);
}
if (value.list.shape != null) {
return getTensorShapeArrayParam(this.node.rawAttrs, name, defaultValue);
}
if (value.list.b != null) {
return getBoolArrayParam(this.node.rawAttrs, name, defaultValue);
}
if (value.list.type != null) {
return getDtypeArrayParam(this.node.rawAttrs, name, defaultValue);
}
}
return defaultValue;
};
return NodeValueImpl;
}();
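/*
 * A minimal sketch of where NodeValueImpl is typically consumed: it backs the
 * `GraphNode` handed to custom ops registered through the converter's
 * `registerOp` (assuming that binding is available in this scope). The op
 * name and attribute below are hypothetical.
 *
 *   registerOp('MyCustomOp', (node) => {
 *     const x = node.inputs[0];          // first decoded input tensor
 *     const alpha = node.attrs['alpha']; // decoded attribute value
 *     return [mul(x, scalar(alpha))];
 *   });
 */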
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var tfOps = {
__proto__: null,
abs: abs$8,
acos: acos,
acosh: acosh,
add: add$1,
addN: addN,
all: all,
any: any,
argMax: argMax,
argMin: argMin,
asin: asin,
asinh: asinh$1,
atan: atan,
atan2: atan2,
atanh: atanh,
avgPool: avgPool,
avgPool3d: avgPool3d,
basicLSTMCell: basicLSTMCell,
batchToSpaceND: batchToSpaceND,
batchNorm: batchNorm,
batchNorm2d: batchNorm2d,
batchNorm3d: batchNorm3d,
batchNorm4d: batchNorm4d,
bincount: bincount,
broadcastArgs: broadcastArgs,
broadcastTo: broadcastTo,
buffer: buffer,
cast: cast,
ceil: ceil$3,
clipByValue: clipByValue,
clone: clone,
complex: complex,
concat: concat,
concat1d: concat1d,
concat2d: concat2d,
concat3d: concat3d,
concat4d: concat4d,
conv1d: conv1d,
conv2d: conv2d,
conv2dTranspose: conv2dTranspose,
conv3d: conv3d,
conv3dTranspose: conv3dTranspose,
cos: cos,
cosh: cosh,
cumsum: cumsum,
denseBincount: denseBincount,
depthToSpace: depthToSpace,
depthwiseConv2d: depthwiseConv2d,
diag: diag,
dilation2d: dilation2d,
div: div,
divNoNan: divNoNan,
dot: dot,
einsum: einsum,
elu: elu,
equal: equal,
erf: erf,
exp: exp$3,
expandDims: expandDims,
expm1: expm1,
eye: eye,
fill: fill,
floor: floor$a,
floorDiv: floorDiv,
gather: gather,
greater: greater,
greaterEqual: greaterEqual,
imag: imag,
isFinite: isFinite$1,
isInf: isInf,
isNaN: isNaN$1,
leakyRelu: leakyRelu,
less: less,
lessEqual: lessEqual,
linspace: linspace,
localResponseNormalization: localResponseNormalization,
log: log$a,
log1p: log1p,
logSigmoid: logSigmoid,
logSoftmax: logSoftmax,
logSumExp: logSumExp,
logicalAnd: logicalAnd,
logicalNot: logicalNot,
logicalOr: logicalOr,
logicalXor: logicalXor,
matMul: matMul,
max: max$5,
maxPool: maxPool,
maxPool3d: maxPool3d,
maxPoolWithArgmax: maxPoolWithArgmax,
maximum: maximum,
mean: mean,
meshgrid: meshgrid,
min: min$9,
minimum: minimum,
mirrorPad: mirrorPad,
mod: mod,
moments: moments,
mul: mul,
multiRNNCell: multiRNNCell,
multinomial: multinomial,
neg: neg,
notEqual: notEqual,
oneHot: oneHot,
ones: ones$1,
onesLike: onesLike,
outerProduct: outerProduct,
pad: pad,
pad1d: pad1d,
pad2d: pad2d,
pad3d: pad3d,
pad4d: pad4d,
pool: pool,
pow: pow$5,
prelu: prelu,
print: print,
prod: prod,
rand: rand,
randomGamma: randomGamma,
randomNormal: randomNormal,
randomUniform: randomUniform,
range: range,
real: real,
reciprocal: reciprocal,
relu: relu,
relu6: relu6,
reshape: reshape,
reverse: reverse,
reverse1d: reverse1d,
reverse2d: reverse2d,
reverse3d: reverse3d,
reverse4d: reverse4d,
round: round$1,
rsqrt: rsqrt,
scalar: scalar,
selu: selu,
separableConv2d: separableConv2d,
setdiff1dAsync: setdiff1dAsync,
sigmoid: sigmoid,
sign: sign,
sin: sin,
sinh: sinh,
slice: slice$2,
slice1d: slice1d,
slice2d: slice2d,
slice3d: slice3d,
slice4d: slice4d,
softmax: softmax,
softplus: softplus,
spaceToBatchND: spaceToBatchND,
fft: fft,
ifft: ifft,
irfft: irfft,
rfft: rfft,
split: split$1,
sqrt: sqrt$3,
square: square,
squaredDifference: squaredDifference,
squeeze: squeeze,
stack: stack,
step: step,
stridedSlice: stridedSlice,
sub: sub,
sum: sum$1,
tan: tan,
tanh: tanh$1,
tensor: tensor,
tensor1d: tensor1d,
tensor2d: tensor2d,
tensor3d: tensor3d,
tensor4d: tensor4d,
tensor5d: tensor5d,
tensor6d: tensor6d,
tile: tile,
topk: topk,
truncatedNormal: truncatedNormal,
unique: unique,
unsortedSegmentSum: unsortedSegmentSum,
unstack: unstack,
variable: variable,
where: where,
whereAsync: whereAsync,
zeros: zeros,
zerosLike: zerosLike,
op: op,
OP_SCOPE_SUFFIX: OP_SCOPE_SUFFIX,
booleanMaskAsync: booleanMaskAsync,
transpose: transpose,
norm: norm,
movingAverage: movingAverage,
scatterND: scatterND,
sparseToDense: sparseToDense,
gatherND: gatherND,
dropout: dropout,
enclosingPowerOfTwo: enclosingPowerOfTwo,
cosineWindow: cosineWindow,
inTopKAsync: inTopKAsync,
image: image,
linalg: linalg,
losses: losses,
spectral: spectral,
fused: fused_ops,
signal: signal,
sparse: sparse,
string: string
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'BiasAdd':
case 'AddV2':
case 'Add':
{
return [add$1(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'AddN':
{
return [addN(getParamValue('tensors', node, tensorMap, context))];
}
case 'FloorMod':
case 'Mod':
return [mod(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
case 'Mul':
return [mul(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
case 'RealDiv':
case 'Div':
{
return [div(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'DivNoNan':
{
return [divNoNan(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'FloorDiv':
{
return [floorDiv(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'Sub':
{
return [sub(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'Minimum':
{
return [minimum(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'Maximum':
{
return [maximum(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'Pow':
{
return [pow$5(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'SquaredDifference':
{
return [squaredDifference(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY = 'arithmetic';
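/*
 * A minimal sketch of the dispatch above: each executor receives a mapped
 * node together with the tensor map and execution context, resolves the
 * node's named params (e.g. 'a' and 'b') through `getParamValue`, and
 * returns the outputs as an array of tensors. The node variable below is
 * assumed to be a mapped 'AddV2' node.
 *
 *   const outputs = executeOp(addNode, tensorMap, context);
 *   // outputs -> [add(a, b)], i.e. a single output tensor
 */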
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$1 = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'Abs':
case 'ComplexAbs':
return [abs$8(getParamValue('x', node, tensorMap, context))];
case 'Acos':
return [acos(getParamValue('x', node, tensorMap, context))];
case 'Acosh':
return [acosh(getParamValue('x', node, tensorMap, context))];
case 'Asin':
return [asin(getParamValue('x', node, tensorMap, context))];
case 'Asinh':
return [asinh$1(getParamValue('x', node, tensorMap, context))];
case 'Atan':
return [atan(getParamValue('x', node, tensorMap, context))];
case 'Atan2':
return [atan2(getParamValue('x', node, tensorMap, context), getParamValue('y', node, tensorMap, context))];
case 'Atanh':
return [atanh(getParamValue('x', node, tensorMap, context))];
case 'Ceil':
return [ceil$3(getParamValue('x', node, tensorMap, context))];
case 'Complex':
return [complex(getParamValue('real', node, tensorMap, context), getParamValue('imag', node, tensorMap, context))];
case 'Cos':
return [cos(getParamValue('x', node, tensorMap, context))];
case 'Cosh':
return [cosh(getParamValue('x', node, tensorMap, context))];
case 'Elu':
return [elu(getParamValue('x', node, tensorMap, context))];
case 'Erf':
return [erf(getParamValue('x', node, tensorMap, context))];
case 'Exp':
return [exp$3(getParamValue('x', node, tensorMap, context))];
case 'Expm1':
{
return [expm1(getParamValue('x', node, tensorMap, context))];
}
case 'Floor':
return [floor$a(getParamValue('x', node, tensorMap, context))];
case 'Log':
return [log$a(getParamValue('x', node, tensorMap, context))];
case 'Log1p':
{
return [log1p(getParamValue('x', node, tensorMap, context))];
}
case 'Imag':
return [imag(getParamValue('x', node, tensorMap, context))];
case 'Neg':
return [neg(getParamValue('x', node, tensorMap, context))];
case 'Reciprocal':
{
return [reciprocal(getParamValue('x', node, tensorMap, context))];
}
case 'Real':
return [real(getParamValue('x', node, tensorMap, context))];
case 'Relu':
return [relu(getParamValue('x', node, tensorMap, context))];
case 'Round':
{
return [round$1(getParamValue('x', node, tensorMap, context))];
}
case 'Selu':
return [selu(getParamValue('x', node, tensorMap, context))];
case 'Sigmoid':
return [sigmoid(getParamValue('x', node, tensorMap, context))];
case 'Sin':
return [sin(getParamValue('x', node, tensorMap, context))];
case 'Sign':
{
return [sign(getParamValue('x', node, tensorMap, context))];
}
case 'Sinh':
{
return [sinh(getParamValue('x', node, tensorMap, context))];
}
case 'Softplus':
{
return [softplus(getParamValue('x', node, tensorMap, context))];
}
case 'Sqrt':
{
return [sqrt$3(getParamValue('x', node, tensorMap, context))];
}
case 'Square':
{
return [square(getParamValue('x', node, tensorMap, context))];
}
case 'Tanh':
{
return [tanh$1(getParamValue('x', node, tensorMap, context))];
}
case 'Tan':
return [tan(getParamValue('x', node, tensorMap, context))];
case 'ClipByValue':
return [clipByValue(getParamValue('x', node, tensorMap, context), getParamValue('clipValueMin', node, tensorMap, context), getParamValue('clipValueMax', node, tensorMap, context))];
case 'Relu6':
return [relu6(getParamValue('x', node, tensorMap, context))];
case 'Rsqrt':
return [rsqrt(getTensor(node.inputNames[0], tensorMap, context))];
case 'Prod':
return [prod(getParamValue('x', node, tensorMap, context), getParamValue('axes', node, tensorMap, context))];
case 'LeakyRelu':
return [leakyRelu(getParamValue('x', node, tensorMap, context), getParamValue('alpha', node, tensorMap, context))];
case 'Prelu':
return [prelu(getParamValue('x', node, tensorMap, context), getParamValue('alpha', node, tensorMap, context))];
case 'IsNan':
return [isNaN$1(getTensor(node.inputNames[0], tensorMap, context))];
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$1 = 'basic_math';
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Used by TensorList and TensorArray to verify that element shapes match;
* negative dimension values are treated as unknown and match any size.
* @param shapeA
* @param shapeB
* @param errorMessagePrefix
*/
function assertShapesMatchAllowUndefinedSize(shapeA, shapeB, errorMessagePrefix) {
if (errorMessagePrefix === void 0) {
errorMessagePrefix = '';
}
// a shape given as a single number means the rank is unknown; nothing to check
if (typeof shapeA === 'number' || typeof shapeB === 'number') {
return;
}
assert(shapeA.length === shapeB.length, function () {
return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match");
});
for (var i = 0; i < shapeA.length; i++) {
var dim0 = shapeA[i];
var dim1 = shapeB[i];
assert(dim0 < 0 || dim1 < 0 || dim0 === dim1, function () {
return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match");
});
}
}
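/*
 * A minimal sketch of the matching rule above: negative dimensions act as
 * wildcards, and a shape given as a single number (unknown rank) skips the
 * check entirely.
 *
 *   assertShapesMatchAllowUndefinedSize([-1, 2], [5, 2]); // ok, -1 matches 5
 *   assertShapesMatchAllowUndefinedSize([3, 2], [5, 2]);  // throws, 3 !== 5
 */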
function fullDefinedShape(elementShape) {
if (typeof elementShape === 'number' || elementShape.some(function (dim) {
return dim < 0;
})) {
return false;
}
return true;
}
/**
* Generate the output element shape from the list's elementShape, the
* tensors currently in the list, and the input elementShape param.
* @param listElementShape
* @param tensors
* @param elementShape
*/
function inferElementShape(listElementShape, tensors, elementShape) {
var partialShape = mergeElementShape(listElementShape, elementShape);
var notfullDefinedShape = !fullDefinedShape(partialShape);
if (notfullDefinedShape && tensors.length === 0) {
throw new Error("Tried to calculate elements of an empty list" + (" with non-fully-defined elementShape: " + partialShape));
}
if (notfullDefinedShape) {
tensors.forEach(function (tensor) {
partialShape = mergeElementShape(tensor.shape, partialShape);
});
}
if (!fullDefinedShape(partialShape)) {
throw new Error("Non-fully-defined elementShape: " + partialShape);
}
return partialShape;
}
function mergeElementShape(elementShapeA, elementShapeB) {
if (typeof elementShapeA === 'number') {
return elementShapeB;
}
if (typeof elementShapeB === 'number') {
return elementShapeA;
}
if (elementShapeA.length !== elementShapeB.length) {
throw new Error("Incompatible ranks during merge: " + elementShapeA + " vs. " + elementShapeB);
}
var result = [];
for (var i = 0; i < elementShapeA.length; ++i) {
var dim0 = elementShapeA[i];
var dim1 = elementShapeB[i];
if (dim0 >= 0 && dim1 >= 0 && dim0 !== dim1) {
throw new Error("Incompatible shape during merge: " + elementShapeA + " vs. " + elementShapeB);
}
result[i] = dim0 >= 0 ? dim0 : dim1;
}
return result;
}
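/*
 * A minimal sketch of the shape inference above, assuming the module-local
 * `zeros` op is in scope: merging fills unknown (-1) dimensions from
 * whichever side knows them, and `inferElementShape` can also pull missing
 * dimensions from the concrete tensors in the list.
 *
 *   mergeElementShape([-1, 2], [4, -1]);             // [4, 2]
 *   inferElementShape([-1, 2], [zeros([4, 2])], -1); // [4, 2]
 */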
/**
* The TensorArray object keeps an array of Tensors. It
* allows reading from the array and writing to the array.
*/
var TensorArray = /*#__PURE__*/function () {
function TensorArray(name, dtype, maxSize, elementShape, identicalElementShapes, dynamicSize, clearAfterRead) {
this.name = name;
this.dtype = dtype;
this.maxSize = maxSize;
this.elementShape = elementShape;
this.identicalElementShapes = identicalElementShapes;
this.dynamicSize = dynamicSize;
this.clearAfterRead = clearAfterRead;
this.tensors = [];
this.closed_ = false;
this.idTensor = scalar(0);
keep(this.idTensor);
}
var _proto = TensorArray.prototype;
/**
* Dispose the tensors and idTensor and mark the TensorArray as closed.
*/
_proto.clearAndClose = function clearAndClose(keepIds) {
this.tensors.forEach(function (tensor) {
if (keepIds == null || !keepIds.has(tensor.tensor.id)) {
tensor.tensor.dispose();
}
});
this.tensors = [];
this.closed_ = true;
this.idTensor.dispose();
};
_proto.size = function size() {
return this.tensors.length;
}
/**
* Read the value at location index in the TensorArray.
* @param index Number the index to read from.
*/
;
_proto.read = function read(index) {
if (this.closed_) {
throw new Error("TensorArray " + this.name + " has already been closed.");
}
if (index < 0 || index >= this.size()) {
throw new Error("Tried to read from index " + index + ", but array size is: " + this.size());
}
var tensorWithState = this.tensors[index];
if (tensorWithState.cleared) {
throw new Error("TensorArray " + this.name + ": Could not read index " + index + " twice because it was cleared after a previous read " + "(perhaps try setting clear_after_read = false?).");
}
if (this.clearAfterRead) {
tensorWithState.cleared = true;
}
tensorWithState.read = true;
return tensorWithState.tensor;
}
/**
* Helper method to read multiple tensors from the specified indices.
*/
;
_proto.readMany = function readMany(indices) {
var _this = this;
return indices.map(function (index) {
return _this.read(index);
});
}
/**
* Write a value at the given index of the TensorArray.
* @param index number the index to write to.
* @param tensor
*/
;
_proto.write = function write(index, tensor) {
if (this.closed_) {
throw new Error("TensorArray " + this.name + " has already been closed.");
}
if (index < 0 || !this.dynamicSize && index >= this.maxSize) {
throw new Error("Tried to write to index " + index + ", but array is not resizeable and size is: " + this.maxSize);
}
var t = this.tensors[index] || {};
if (tensor.dtype !== this.dtype) {
throw new Error("TensorArray " + this.name + ": Could not write to TensorArray index " + index + ",\n because the value dtype is " + tensor.dtype + ", but TensorArray dtype is " + this.dtype + ".");
} // On the first write to a tensor array with an unknown element shape,
// adopt the shape of the written tensor.
if (this.size() === 0 && (this.elementShape == null || this.elementShape.length === 0)) {
this.elementShape = tensor.shape;
}
assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, "TensorArray " + this.name + ": Could not write to TensorArray index " + index + ".");
if (t.read) {
throw new Error("TensorArray " + this.name + ": Could not write to TensorArray index " + index + ", because it has already been read.");
}
if (t.written) {
throw new Error("TensorArray " + this.name + ": Could not write to TensorArray index " + index + ", because it has already been written.");
}
t.tensor = tensor;
keep(tensor);
t.written = true;
this.tensors[index] = t;
}
/**
* Helper method to write multiple tensors to the specified indices.
*/
;
_proto.writeMany = function writeMany(indices, tensors) {
var _this2 = this;
if (indices.length !== tensors.length) {
throw new Error("TensorArray " + this.name + ": could not write multiple tensors," + ("because the index size: " + indices.length + " is not the same as tensors size: " + tensors.length + "."));
}
indices.forEach(function (i, index) {
return _this2.write(i, tensors[index]);
});
}
/**
* Return selected values in the TensorArray as a packed Tensor. All of the
* selected values must have been written and their shapes must all match.
* @param [indices] number[] Optional. Values in [0, max_value). If the
* TensorArray is not dynamic, max_value = size(). If not specified, returns
* all tensors in the original order.
* @param [dtype]
*/
;
_proto.gather = function gather(indices, dtype) {
if (!!dtype && dtype !== this.dtype) {
throw new Error("TensorArray dtype is " + this.dtype + " but gather requested dtype " + dtype);
}
if (!indices) {
indices = [];
for (var i = 0; i < this.size(); i++) {
indices.push(i);
}
} else {
indices = indices.slice(0, this.size());
}
if (indices.length === 0) {
return tensor([], [0].concat(this.elementShape));
} // Read all of the selected tensors into an array to keep track of
// their memory.
var tensors = this.readMany(indices);
assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');
return stack(tensors, 0);
}
/**
* Return the values in the TensorArray as a concatenated Tensor.
*/
;
_proto.concat = function concat$1(dtype) {
if (!!dtype && dtype !== this.dtype) {
throw new Error("TensorArray dtype is " + this.dtype + " but concat requested dtype " + dtype);
}
if (this.size() === 0) {
return tensor([], [0].concat(this.elementShape));
}
var indices = [];
for (var i = 0; i < this.size(); i++) {
indices.push(i);
} // Collect all the tensors from the tensors array.
var tensors = this.readMany(indices);
assertShapesMatchAllowUndefinedSize(this.elementShape, tensors[0].shape, "TensorArray shape mismatch: tensor array shape (" + this.elementShape + ") vs first tensor shape (" + tensors[0].shape + ")");
return concat(tensors, 0);
}
/**
* Scatter the values of a Tensor in specific indices of a TensorArray.
* @param indices number[] values in [0, max_value). If the
* TensorArray is not dynamic, max_value=size().
* @param tensor Tensor input tensor.
*/
;
_proto.scatter = function scatter(indices, tensor) {
if (tensor.dtype !== this.dtype) {
throw new Error("TensorArray dtype is " + this.dtype + " but tensor has dtype " + tensor.dtype);
}
if (indices.length !== tensor.shape[0]) {
throw new Error("Expected len(indices) == tensor.shape[0], but saw: " + indices.length + " vs. " + tensor.shape[0]);
}
var maxIndex = Math.max.apply(Math, indices);
if (!this.dynamicSize && maxIndex >= this.maxSize) {
throw new Error("Max index must be < array size (" + maxIndex + " vs. " + this.maxSize + ")");
}
this.writeMany(indices, unstack(tensor, 0));
}
/**
* Split the values of a Tensor into the TensorArray.
* @param length number[] with the lengths to use when splitting value along
* its first dimension.
* @param tensor Tensor, the tensor to split.
*/
;
_proto.split = function split(length, tensor) {
var _this3 = this;
if (tensor.dtype !== this.dtype) {
throw new Error("TensorArray dtype is " + this.dtype + " but tensor has dtype " + tensor.dtype);
}
var totalLength = 0;
var cumulativeLengths = length.map(function (len) {
totalLength += len;
return totalLength;
});
if (totalLength !== tensor.shape[0]) {
throw new Error("Expected sum of lengths to be equal to\n tensor.shape[0], but sum of lengths is\n " + totalLength + ", and tensor's shape is: " + tensor.shape);
}
if (!this.dynamicSize && length.length !== this.maxSize) {
throw new Error("TensorArray's size is not equal to the size of lengths (" + this.maxSize + " vs. " + length.length + "), " + 'and the TensorArray is not marked as dynamically resizeable');
}
var elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
var tensors = [];
tidy(function () {
tensor = reshape(tensor, [1, totalLength, elementPerRow]);
for (var i = 0; i < length.length; ++i) {
var previousLength = i === 0 ? 0 : cumulativeLengths[i - 1];
var _indices = [0, previousLength, 0];
var sizes = [1, length[i], elementPerRow];
tensors[i] = reshape(slice$2(tensor, _indices, sizes), _this3.elementShape);
}
return tensors;
});
var indices = [];
for (var i = 0; i < length.length; i++) {
indices[i] = i;
}
this.writeMany(indices, tensors);
};
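/*
 * A short worked example of the offset arithmetic above (an illustrative
 * sketch, not part of the library): splitting a value of shape [5, 2] with
 * length = [2, 3] gives cumulativeLengths = [2, 5] and
 * elementPerRow = 10 / 5 = 2. The value is reshaped to [1, 5, 2]; piece i is
 * sliced at [0, previousLength, 0] with size [1, length[i], 2] (rows 0-1,
 * then rows 2-4), and each piece is then reshaped to the array's
 * elementShape, which must therefore be compatible with both [2, 2] and
 * [3, 2] (for example [-1, 2]).
 */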
_createClass(TensorArray, [{
key: "id",
get: function get() {
return this.idTensor.id;
}
}, {
key: "closed",
get: function get() {
return this.closed_;
}
}]);
return TensorArray;
}();
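/*
 * Usage sketch for the TensorArray class above (illustrative only, not part
 * of the library). The constructor arguments mirror the TensorArrayV3 handler
 * further down in this file; the literal shapes and values are assumptions
 * for the example, and `tensor` refers to the module-local tensor factory
 * (tf.tensor in the public API).
 *
 * ```js
 * const ta = new TensorArray('example', 'float32', 2, [2], true, false, false);
 * ta.write(0, tensor([1, 2]));   // keeps the tensor at index 0
 * ta.write(1, tensor([3, 4]));
 * const packed = ta.gather();    // shape [2, 2]: elements stacked on a new first axis
 * ta.clearAndClose();            // disposes the tensors held by the array
 * ```
 */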
/**
* TensorList stores a container of `tf.Tensor` objects, which are accessible
* via the tensors field.
*
* In order to get a copy of the underlying list, use the copy method:
* ```
* const b = a.copy();
* b.pushBack(t); // This does not modify a.tensors.
* ```
*
* Note that this is not a deep copy: the copied list still references the
* same underlying tensors as the original.
*/
var TensorList = /*#__PURE__*/function () {
/**
*
* @param tensors list of tensors
* @param elementShape shape of each tensor; this can be a single number (any
* shape is allowed) or a partial shape (dim = -1).
* @param elementDtype data type of each tensor
* @param maxNumElements The maximum allowed size of `tensors`. Defaults to -1
* meaning that the size of `tensors` is unbounded.
*/
function TensorList(tensors, elementShape, elementDtype, maxNumElements) {
if (maxNumElements === void 0) {
maxNumElements = -1;
}
this.tensors = tensors;
this.elementShape = elementShape;
this.elementDtype = elementDtype;
if (tensors != null) {
tensors.forEach(function (tensor) {
if (elementDtype !== tensor.dtype) {
throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + tensor.dtype);
}
assertShapesMatchAllowUndefinedSize(elementShape, tensor.shape, 'TensorList shape mismatch: ');
keep(tensor);
});
}
this.idTensor = scalar(0);
this.maxNumElements = maxNumElements;
keep(this.idTensor);
}
var _proto = TensorList.prototype;
/**
* Get a new TensorList containing a copy of the underlying tensor container.
*/
_proto.copy = function copy() {
return new TensorList([].concat(this.tensors), this.elementShape, this.elementDtype);
}
/**
* Dispose the tensors and idTensor and clear the tensor list.
*/
;
_proto.clearAndClose = function clearAndClose(keepIds) {
this.tensors.forEach(function (tensor) {
if (keepIds == null || !keepIds.has(tensor.id)) {
tensor.dispose();
}
});
this.tensors.length = 0;
this.idTensor.dispose();
}
/**
* The number of tensors in the tensor list.
*/
;
_proto.size = function size() {
return this.tensors.length;
}
/**
* Return a tensor that stacks a list of rank-R tf.Tensors into one rank-(R+1)
* tf.Tensor.
* @param elementShape shape of each tensor
* @param elementDtype data type of each tensor
* @param numElements the number of elements to stack
*/
;
_proto.stack = function stack$1(elementShape, elementDtype, numElements) {
var _this = this;
if (numElements === void 0) {
numElements = -1;
}
if (elementDtype !== this.elementDtype) {
throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + this.elementDtype);
}
if (numElements !== -1 && this.tensors.length !== numElements) {
throw new Error("Operation expected a list with " + numElements + " elements but got a list with " + this.tensors.length + " elements.");
}
assertShapesMatchAllowUndefinedSize(elementShape, this.elementShape, 'TensorList shape mismatch: ');
var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
return tidy(function () {
var reshapedTensors = _this.tensors.map(function (tensor) {
return reshape(tensor, outputElementShape);
});
return stack(reshapedTensors, 0);
});
}
/**
* Pop a tensor from the end of the list.
* @param elementShape shape of the tensor
* @param elementDtype data type of the tensor
*/
;
_proto.popBack = function popBack(elementShape, elementDtype) {
if (elementDtype !== this.elementDtype) {
throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + this.elementDtype);
}
if (this.size() === 0) {
throw new Error('Trying to pop from an empty list.');
}
var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
var tensor = this.tensors.pop();
assertShapesMatchAllowUndefinedSize(tensor.shape, elementShape, 'TensorList shape mismatch: ');
return reshape(tensor, outputElementShape);
}
/**
* Push a tensor to the end of the list.
* @param tensor Tensor to be pushed.
*/
;
_proto.pushBack = function pushBack(tensor) {
if (tensor.dtype !== this.elementDtype) {
throw new Error("Invalid data types; op elements " + tensor.dtype + ", but list elements " + this.elementDtype);
}
assertShapesMatchAllowUndefinedSize(tensor.shape, this.elementShape, 'TensorList shape mismatch: ');
if (this.maxNumElements === this.size()) {
throw new Error("Trying to push element into a full list.");
}
keep(tensor);
this.tensors.push(tensor);
}
/**
* Update the size of the list.
* @param size the new size of the list.
*/
;
_proto.resize = function resize(size) {
if (size < 0) {
throw new Error("TensorListResize expects size to be non-negative. Got: " + size);
}
if (this.maxNumElements !== -1 && size > this.maxNumElements) {
throw new Error("TensorListResize input size " + size + " is greater maxNumElement " + this.maxNumElements + ".");
}
this.tensors.length = size;
}
/**
* Retrieve the element at the provided index
* @param elementShape shape of the tensor
* @param elementDtype dtype of the tensor
* @param elementIndex index of the tensor
*/
;
_proto.getItem = function getItem(elementIndex, elementShape, elementDtype) {
if (elementDtype !== this.elementDtype) {
throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + this.elementDtype);
}
if (elementIndex < 0 || elementIndex > this.tensors.length) {
throw new Error("Trying to access element " + elementIndex + " in a list with " + this.tensors.length + " elements.");
}
if (this.tensors[elementIndex] == null) {
throw new Error("element at index " + elementIndex + " is null.");
}
assertShapesMatchAllowUndefinedSize(this.tensors[elementIndex].shape, elementShape, 'TensorList shape mismatch: ');
var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
return reshape(this.tensors[elementIndex], outputElementShape);
}
/**
* Set the tensor at the index
* @param elementIndex index of the tensor
* @param tensor the tensor to be inserted into the list
*/
;
_proto.setItem = function setItem(elementIndex, tensor) {
if (tensor.dtype !== this.elementDtype) {
throw new Error("Invalid data types; op elements " + tensor.dtype + ", but list elements " + this.elementDtype);
}
if (elementIndex < 0 || this.maxNumElements !== -1 && elementIndex >= this.maxNumElements) {
throw new Error("Trying to set element " + elementIndex + " in a list with max " + this.maxNumElements + " elements.");
}
assertShapesMatchAllowUndefinedSize(this.elementShape, tensor.shape, 'TensorList shape mismatch: ');
keep(tensor);
this.tensors[elementIndex] = tensor;
}
/**
* Return selected values in the TensorList as a stacked Tensor. All of the
* selected values must have been written and their shapes must all match.
* @param indices indices of tensors to gather
* @param elementDtype output tensor dtype
* @param elementShape output tensor element shape
*/
;
_proto.gather = function gather(indices, elementDtype, elementShape) {
var _this2 = this;
if (elementDtype !== this.elementDtype) {
throw new Error("Invalid data types; op elements " + elementDtype + ", but list elements " + this.elementDtype);
}
assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: '); // Indices beyond the size of the list are ignored.
indices = indices.slice(0, this.size());
var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
if (indices.length === 0) {
return tensor([], [0].concat(outputElementShape));
}
return tidy(function () {
var tensors = indices.map(function (i) {
return reshape(_this2.tensors[i], outputElementShape);
});
return stack(tensors, 0);
});
}
/**
* Return the values in the TensorList as a concatenated Tensor.
* @param elementDtype output tensor dtype
* @param elementShape output tensor element shape
*/
;
_proto.concat = function concat$1(elementDtype, elementShape) {
var _this3 = this;
if (!!elementDtype && elementDtype !== this.elementDtype) {
throw new Error("TensorList dtype is " + this.elementDtype + " but concat requested dtype " + elementDtype);
}
assertShapesMatchAllowUndefinedSize(this.elementShape, elementShape, 'TensorList shape mismatch: ');
var outputElementShape = inferElementShape(this.elementShape, this.tensors, elementShape);
if (this.size() === 0) {
return tensor([], [0].concat(outputElementShape));
}
return tidy(function () {
var tensors = _this3.tensors.map(function (t) {
return reshape(t, outputElementShape);
});
return concat(tensors, 0);
});
};
_createClass(TensorList, [{
key: "id",
get: function get() {
return this.idTensor.id;
}
}]);
return TensorList;
}();
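/*
 * Usage sketch for the TensorList class above (illustrative only, not part of
 * the library): the element shape, dtype and values are assumptions for the
 * example, and `tensor` refers to the module-local tensor factory.
 *
 * ```js
 * const list = new TensorList([], [2], 'float32', -1); // unbounded list of shape-[2] tensors
 * list.pushBack(tensor([1, 2]));
 * list.pushBack(tensor([3, 4]));
 * const stacked = list.stack([2], 'float32');  // shape [2, 2]
 * const last = list.popBack([2], 'float32');   // shape [2]; removed from the list
 * list.clearAndClose();                        // disposes the remaining tensors
 * ```
 */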
/**
* Creates a TensorList which, when stacked, has the value of tensor.
* @param tensor the source tensor
* @param elementShape output tensor element shape
*/
function fromTensor(tensor, elementShape, elementDtype) {
var dtype = tensor.dtype;
if (tensor.shape.length < 1) {
throw new Error("Tensor must be at least a vector, but saw shape: " + tensor.shape);
}
if (tensor.dtype !== elementDtype) {
throw new Error("Invalid data types; op elements " + tensor.dtype + ", but list elements " + elementDtype);
}
var tensorElementShape = tensor.shape.slice(1);
assertShapesMatchAllowUndefinedSize(tensorElementShape, elementShape, 'TensorList shape mismatch: ');
var tensorList = unstack(tensor);
return new TensorList(tensorList, elementShape, dtype);
}
/**
* Return a TensorList of the given size with empty elements.
* @param elementShape the shape of the future elements of the list
* @param elementDtype the desired type of elements in the list
* @param numElements the number of elements to reserve
*/
function reserve(elementShape, elementDtype, numElements) {
return new TensorList([], elementShape, elementDtype, numElements);
}
/**
* Scatter the rows of a stacked tensor into specific indices of a TensorList.
* @param indices list of indices at which to scatter the rows of the tensor.
* @param tensor input tensor.
* @param elementShape the shape of the future elements of the list
* @param numElements the number of elements to scatter
*/
function scatter(tensor, indices, elementShape, numElements) {
if (indices.length !== tensor.shape[0]) {
throw new Error("Expected len(indices) == tensor.shape[0], but saw: " + indices.length + " vs. " + tensor.shape[0]);
}
var maxIndex = Math.max.apply(Math, indices);
if (numElements != null && numElements !== -1 && maxIndex >= numElements) {
throw new Error("Max index must be < array size (" + maxIndex + " vs. " + numElements + ")");
}
var list = new TensorList([], elementShape, tensor.dtype, numElements);
var tensors = unstack(tensor, 0);
indices.forEach(function (value, index) {
list.setItem(value, tensors[index]);
});
return list;
}
/**
* Split the values of a Tensor into a TensorList.
* @param length the lengths to use when splitting value along
* its first dimension.
* @param tensor the tensor to split.
* @param elementShape the shape of the future elements of the list
*/
function split$3(tensor, length, elementShape) {
var totalLength = 0;
var cumulativeLengths = length.map(function (len) {
totalLength += len;
return totalLength;
});
if (totalLength !== tensor.shape[0]) {
throw new Error("Expected sum of lengths to be equal to\n tensor.shape[0], but sum of lengths is\n " + totalLength + ", and tensor's shape is: " + tensor.shape);
}
var shapeWithoutFirstDim = tensor.shape.slice(1);
var outputElementShape = mergeElementShape(shapeWithoutFirstDim, elementShape);
var elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
var tensors = tidy(function () {
var tensors = [];
tensor = reshape(tensor, [1, totalLength, elementPerRow]);
for (var i = 0; i < length.length; ++i) {
var previousLength = i === 0 ? 0 : cumulativeLengths[i - 1];
var indices = [0, previousLength, 0];
var sizes = [1, length[i], elementPerRow];
tensors[i] = reshape(slice$2(tensor, indices, sizes), outputElementShape);
}
tensor.dispose();
return tensors;
});
var list = new TensorList([], elementShape, tensor.dtype, length.length);
for (var i = 0; i < tensors.length; i++) {
list.setItem(i, tensors[i]);
}
return list;
}
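/*
 * Sketch of the helper functions above (illustrative assumptions only;
 * `split$3` is the bundler-renamed `split` helper of this module):
 *
 * ```js
 * const source = tensor([[1, 2], [3, 4], [5, 6]]);   // shape [3, 2], float32
 * const a = fromTensor(source, [2], 'float32');      // 3 list elements of shape [2]
 * const b = reserve([2], 'float32', 3);              // empty list with room for 3 elements
 * const c = scatter(source, [2, 0, 1], [2], 3);      // rows placed at indices 2, 0 and 1
 * const d = split$3(source, [1, 1, 1], [2]);         // three single-row pieces, each reshaped to [2]
 * ```
 */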
var executeOp$2 = /*#__PURE__*/function () {
var _ref = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(node, tensorMap, context) {
var thenFunc, elseFunc, cond, args, condValue, _ret, pred, _pred, data, inputName, _data, frameId, _data2, _data3, _data4, size, dtype, elementShape, dynamicSize, clearAfterRead, identicalElementShapes, name, tensorArray, id, index, writeTensor, writeTensorArray, readId, readIndex, readTensorArray, gatherId, gatherIndices, gatherDtype, gatherTensorArray, scatterId, scatterIndices, scatterTensor, scatterTensorArray, concatId, concatTensorArray, concatDtype, splitId, splitTensor, lengths, splitTensorArray, sizeId, sizeTensorArray, closeId, closeTensorArray, idTensor, _index, _writeTensor, tensorList, _idTensor, _readIndex, _elementShape, elementDType, _tensorList, _scatterIndices, _scatterTensor, _elementShape2, numElements, _tensorList2, _elementShape3, elementDtype, numElementsParam, _numElements, _tensorList3, _gatherId, _gatherIndices, _elementShape4, _elementDtype, _tensorList4, _idTensor2, _elementShape5, _elementDtype2, _numElements2, _tensorList5, tensor, _elementShape6, _elementDtype3, _tensorList6, _concatId, _tensorList7, _concatDtype, _elementShape7, _idTensor3, _writeTensor2, _tensorList8, _idTensor4, _elementShape8, _elementDType, _tensorList9, _splitTensor, _elementShape9, _lengths, _tensorList10;
return regeneratorRuntime.wrap(function _callee2$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
_context3.t0 = node.op;
_context3.next = _context3.t0 === 'If' ? 3 : _context3.t0 === 'StatelessIf' ? 3 : _context3.t0 === 'While' ? 15 : _context3.t0 === 'StatelessWhile' ? 15 : _context3.t0 === 'LoopCond' ? 19 : _context3.t0 === 'Switch' ? 21 : _context3.t0 === 'Merge' ? 32 : _context3.t0 === 'Enter' ? 37 : _context3.t0 === 'Exit' ? 41 : _context3.t0 === 'NextIteration' ? 44 : _context3.t0 === 'TensorArrayV3' ? 47 : _context3.t0 === 'TensorArrayWriteV3' ? 57 : _context3.t0 === 'TensorArrayReadV3' ? 63 : _context3.t0 === 'TensorArrayGatherV3' ? 67 : _context3.t0 === 'TensorArrayScatterV3' ? 72 : _context3.t0 === 'TensorArrayConcatV3' ? 78 : _context3.t0 === 'TensorArraySplitV3' ? 82 : _context3.t0 === 'TensorArraySizeV3' ? 88 : _context3.t0 === 'TensorArrayCloseV3' ? 91 : _context3.t0 === 'TensorListSetItem' ? 95 : _context3.t0 === 'TensorListGetItem' ? 101 : _context3.t0 === 'TensorListScatterV2' ? 107 : _context3.t0 === 'TensorListScatter' ? 107 : _context3.t0 === 'TensorListReserve' ? 114 : _context3.t0 === 'EmptyTensorList' ? 114 : _context3.t0 === 'TensorListGather' ? 121 : _context3.t0 === 'TensorListStack' ? 127 : _context3.t0 === 'TensorListFromTensor' ? 133 : _context3.t0 === 'TensorListConcat' ? 139 : _context3.t0 === 'TensorListPushBack' ? 144 : _context3.t0 === 'TensorListPopBack' ? 149 : _context3.t0 === 'TensorListSplit' ? 154 : 160;
break;
case 3:
thenFunc = getParamValue('thenBranch', node, tensorMap, context);
elseFunc = getParamValue('elseBranch', node, tensorMap, context);
cond = getParamValue('cond', node, tensorMap, context);
args = getParamValue('args', node, tensorMap, context);
_context3.next = 9;
return cond.data();
case 9:
condValue = _context3.sent;
if (!condValue[0]) {
_context3.next = 14;
break;
}
return _context3.abrupt("return", context.functionMap[thenFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap));
case 14:
return _context3.abrupt("return", context.functionMap[elseFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap));
case 15:
return _context3.delegateYield( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var bodyFunc, condFunc, args, condResult, argIds, condValue, result, _loop;
return regeneratorRuntime.wrap(function _callee$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
bodyFunc = getParamValue('body', node, tensorMap, context);
condFunc = getParamValue('cond', node, tensorMap, context);
args = getParamValue('args', node, tensorMap, context); // Calculate the condition of the loop
_context2.next = 5;
return context.functionMap[condFunc].executeFunctionAsync(args, context.tensorArrayMap, context.tensorListMap);
case 5:
condResult = _context2.sent;
argIds = args.map(function (tensor) {
return tensor.id;
});
_context2.next = 9;
return condResult[0].data();
case 9:
condValue = _context2.sent;
// Dispose the intermediate tensors for condition function
condResult.forEach(function (tensor) {
if (!tensor.kept && argIds.indexOf(tensor.id) === -1) {
tensor.dispose();
}
});
result = args;
_loop = /*#__PURE__*/regeneratorRuntime.mark(function _loop() {
var origResult, resultIds, condResult;
return regeneratorRuntime.wrap(function _loop$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
// Record the previous result for intermediate tensor tracking
origResult = result; // Execute the body of the loop
_context.next = 3;
return context.functionMap[bodyFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap);
case 3:
result = _context.sent;
resultIds = result.map(function (tensor) {
return tensor.id;
}); // Dispose intermediate tensors from the body function that are not
// globally kept and are not inputs/outputs of the body function.
origResult.forEach(function (tensor) {
if (!tensor.kept && argIds.indexOf(tensor.id) === -1 && resultIds.indexOf(tensor.id) === -1) {
tensor.dispose();
}
}); // Recalculate the condition of the loop using the latest results.
_context.next = 8;
return context.functionMap[condFunc].executeFunctionAsync(result, context.tensorArrayMap, context.tensorListMap);
case 8:
condResult = _context.sent;
_context.next = 11;
return condResult[0].data();
case 11:
condValue = _context.sent;
// Dispose the intermediate tensors for condition function
condResult.forEach(function (tensor) {
if (!tensor.kept && argIds.indexOf(tensor.id) === -1 && resultIds.indexOf(tensor.id) === -1) {
tensor.dispose();
}
});
case 13:
case "end":
return _context.stop();
}
}
}, _loop);
});
case 13:
if (!condValue[0]) {
_context2.next = 17;
break;
}
return _context2.delegateYield(_loop(), "t0", 15);
case 15:
_context2.next = 13;
break;
case 17:
return _context2.abrupt("return", {
v: result
});
case 18:
case "end":
return _context2.stop();
}
}
}, _callee);
})(), "t1", 16);
case 16:
_ret = _context3.t1;
if (!(typeof _ret === "object")) {
_context3.next = 19;
break;
}
return _context3.abrupt("return", _ret.v);
case 19:
pred = getParamValue('pred', node, tensorMap, context);
return _context3.abrupt("return", [cloneTensor(pred)]);
case 21:
_pred = getParamValue('pred', node, tensorMap, context);
data = getParamValue('data', node, tensorMap, context);
if (!data.kept) {
data = cloneTensor(data);
} // Output nodes: :0 => false, :1 => true
_context3.next = 26;
return _pred.data();
case 26:
if (!_context3.sent[0]) {
_context3.next = 30;
break;
}
_context3.t2 = [undefined, data];
_context3.next = 31;
break;
case 30:
_context3.t2 = [data, undefined];
case 31:
return _context3.abrupt("return", _context3.t2);
case 32:
inputName = node.inputNames.find(function (name) {
return getTensor(name, tensorMap, context) !== undefined;
});
if (!inputName) {
_context3.next = 36;
break;
}
_data = getTensor(inputName, tensorMap, context);
return _context3.abrupt("return", [cloneTensor(_data)]);
case 36:
return _context3.abrupt("return", undefined);
case 37:
frameId = getParamValue('frameName', node, tensorMap, context);
_data2 = getParamValue('tensor', node, tensorMap, context);
context.enterFrame(frameId);
return _context3.abrupt("return", [cloneTensor(_data2)]);
case 41:
_data3 = getParamValue('tensor', node, tensorMap, context);
context.exitFrame();
return _context3.abrupt("return", [cloneTensor(_data3)]);
case 44:
_data4 = getParamValue('tensor', node, tensorMap, context);
context.nextIteration();
return _context3.abrupt("return", [cloneTensor(_data4)]);
case 47:
size = getParamValue('size', node, tensorMap, context);
dtype = getParamValue('dtype', node, tensorMap, context);
elementShape = getParamValue('elementShape', node, tensorMap, context);
dynamicSize = getParamValue('dynamicSize', node, tensorMap, context);
clearAfterRead = getParamValue('clearAfterRead', node, tensorMap, context);
identicalElementShapes = getParamValue('identicalElementShapes', node, tensorMap, context);
name = getParamValue('name', node, tensorMap, context);
tensorArray = new TensorArray(name, dtype, size, elementShape, identicalElementShapes, dynamicSize, clearAfterRead);
context.addTensorArray(tensorArray);
return _context3.abrupt("return", [tensorArray.idTensor, scalar(1.0)]);
case 57:
id = getParamValue('tensorArrayId', node, tensorMap, context);
index = getParamValue('index', node, tensorMap, context);
writeTensor = getParamValue('tensor', node, tensorMap, context);
writeTensorArray = context.getTensorArray(id.id);
writeTensorArray.write(index, writeTensor);
return _context3.abrupt("return", [writeTensorArray.idTensor]);
case 63:
readId = getParamValue('tensorArrayId', node, tensorMap, context);
readIndex = getParamValue('index', node, tensorMap, context);
readTensorArray = context.getTensorArray(readId.id);
return _context3.abrupt("return", [readTensorArray.read(readIndex)]);
case 67:
gatherId = getParamValue('tensorArrayId', node, tensorMap, context);
gatherIndices = getParamValue('indices', node, tensorMap, context);
gatherDtype = getParamValue('dtype', node, tensorMap, context);
gatherTensorArray = context.getTensorArray(gatherId.id);
return _context3.abrupt("return", [gatherTensorArray.gather(gatherIndices, gatherDtype)]);
case 72:
scatterId = getParamValue('tensorArrayId', node, tensorMap, context);
scatterIndices = getParamValue('indices', node, tensorMap, context);
scatterTensor = getParamValue('tensor', node, tensorMap, context);
scatterTensorArray = context.getTensorArray(scatterId.id);
scatterTensorArray.scatter(scatterIndices, scatterTensor);
return _context3.abrupt("return", [scatterTensorArray.idTensor]);
case 78:
concatId = getParamValue('tensorArrayId', node, tensorMap, context);
concatTensorArray = context.getTensorArray(concatId.id);
concatDtype = getParamValue('dtype', node, tensorMap, context);
return _context3.abrupt("return", [concatTensorArray.concat(concatDtype)]);
case 82:
splitId = getParamValue('tensorArrayId', node, tensorMap, context);
splitTensor = getParamValue('tensor', node, tensorMap, context);
lengths = getParamValue('lengths', node, tensorMap, context);
splitTensorArray = context.getTensorArray(splitId.id);
splitTensorArray.split(lengths, splitTensor);
return _context3.abrupt("return", [splitTensorArray.idTensor]);
case 88:
sizeId = getParamValue('tensorArrayId', node, tensorMap, context);
sizeTensorArray = context.getTensorArray(sizeId.id);
return _context3.abrupt("return", [scalar(sizeTensorArray.size(), 'int32')]);
case 91:
closeId = getParamValue('tensorArrayId', node, tensorMap, context);
closeTensorArray = context.getTensorArray(closeId.id);
closeTensorArray.clearAndClose();
return _context3.abrupt("return", [closeTensorArray.idTensor]);
case 95:
idTensor = getParamValue('tensorListId', node, tensorMap, context);
_index = getParamValue('index', node, tensorMap, context);
_writeTensor = getParamValue('tensor', node, tensorMap, context);
tensorList = context.getTensorList(idTensor.id);
tensorList.setItem(_index, _writeTensor);
return _context3.abrupt("return", [tensorList.idTensor]);
case 101:
_idTensor = getParamValue('tensorListId', node, tensorMap, context);
_readIndex = getParamValue('index', node, tensorMap, context);
_elementShape = getParamValue('elementShape', node, tensorMap, context);
elementDType = getParamValue('elementDType', node, tensorMap, context);
_tensorList = context.getTensorList(_idTensor.id);
return _context3.abrupt("return", [_tensorList.getItem(_readIndex, _elementShape, elementDType)]);
case 107:
_scatterIndices = getParamValue('indices', node, tensorMap, context);
_scatterTensor = getParamValue('tensor', node, tensorMap, context);
_elementShape2 = getParamValue('elementShape', node, tensorMap, context);
numElements = getParamValue('numElements', node, tensorMap, context);
_tensorList2 = scatter(_scatterTensor, _scatterIndices, _elementShape2, numElements);
context.addTensorList(_tensorList2);
return _context3.abrupt("return", [_tensorList2.idTensor]);
case 114:
_elementShape3 = getParamValue('elementShape', node, tensorMap, context);
elementDtype = getParamValue('elementDType', node, tensorMap, context);
if (node.op === 'TensorListReserve') {
numElementsParam = 'numElements';
} else {
numElementsParam = 'maxNumElements';
}
_numElements = getParamValue(numElementsParam, node, tensorMap, context);
_tensorList3 = reserve(_elementShape3, elementDtype, _numElements);
context.addTensorList(_tensorList3);
return _context3.abrupt("return", [_tensorList3.idTensor]);
case 121:
_gatherId = getParamValue('tensorListId', node, tensorMap, context);
_gatherIndices = getParamValue('indices', node, tensorMap, context);
_elementShape4 = getParamValue('elementShape', node, tensorMap, context);
_elementDtype = getParamValue('elementDType', node, tensorMap, context);
_tensorList4 = context.getTensorList(_gatherId.id);
return _context3.abrupt("return", [_tensorList4.gather(_gatherIndices, _elementDtype, _elementShape4)]);
case 127:
_idTensor2 = getParamValue('tensorListId', node, tensorMap, context);
_elementShape5 = getParamValue('elementShape', node, tensorMap, context);
_elementDtype2 = getParamValue('elementDType', node, tensorMap, context);
_numElements2 = getParamValue('numElements', node, tensorMap, context);
_tensorList5 = context.getTensorList(_idTensor2.id);
return _context3.abrupt("return", [_tensorList5.stack(_elementShape5, _elementDtype2, _numElements2)]);
case 133:
tensor = getParamValue('tensor', node, tensorMap, context);
_elementShape6 = getParamValue('elementShape', node, tensorMap, context);
_elementDtype3 = getParamValue('elementDType', node, tensorMap, context);
_tensorList6 = fromTensor(tensor, _elementShape6, _elementDtype3);
context.addTensorList(_tensorList6);
return _context3.abrupt("return", [_tensorList6.idTensor]);
case 139:
_concatId = getParamValue('tensorListId', node, tensorMap, context);
_tensorList7 = context.getTensorList(_concatId.id);
_concatDtype = getParamValue('dtype', node, tensorMap, context);
_elementShape7 = getParamValue('elementShape', node, tensorMap, context);
return _context3.abrupt("return", [_tensorList7.concat(_concatDtype, _elementShape7)]);
case 144:
_idTensor3 = getParamValue('tensorListId', node, tensorMap, context);
_writeTensor2 = getParamValue('tensor', node, tensorMap, context);
_tensorList8 = context.getTensorList(_idTensor3.id);
_tensorList8.pushBack(_writeTensor2);
return _context3.abrupt("return", [_tensorList8.idTensor]);
case 149:
_idTensor4 = getParamValue('tensorListId', node, tensorMap, context);
_elementShape8 = getParamValue('elementShape', node, tensorMap, context);
_elementDType = getParamValue('elementDType', node, tensorMap, context);
_tensorList9 = context.getTensorList(_idTensor4.id);
return _context3.abrupt("return", [_tensorList9.popBack(_elementShape8, _elementDType)]);
case 154:
_splitTensor = getParamValue('tensor', node, tensorMap, context);
_elementShape9 = getParamValue('elementShape', node, tensorMap, context);
_lengths = getParamValue('lengths', node, tensorMap, context);
_tensorList10 = split$3(_splitTensor, _lengths, _elementShape9);
context.addTensorList(_tensorList10);
return _context3.abrupt("return", [_tensorList10.idTensor]);
case 160:
throw TypeError("Node type " + node.op + " is not implemented");
case 161:
case "end":
return _context3.stop();
}
}
}, _callee2);
}));
return function executeOp(_x, _x2, _x3) {
return _ref.apply(this, arguments);
};
}();
var CATEGORY$2 = 'control';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fusedConvAndDepthWiseParams(node, tensorMap, context) {
var _getParamValue = getParamValue('fusedOps', node, tensorMap, context),
extraOp = _getParamValue[0],
activationFunc = _getParamValue[1];
var isBiasAdd = extraOp === 'biasadd';
var noBiasAdd = !isBiasAdd;
var isPrelu = activationFunc === 'prelu';
var isBatchNorm = extraOp === 'fusedbatchnorm';
var numArgs = getParamValue('numArgs', node, tensorMap, context);
if (isBiasAdd) {
if (isPrelu && numArgs !== 2) {
throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd and Prelu ' + 'must have two extra arguments: bias and alpha.');
}
if (!isPrelu && isBiasAdd && numArgs !== 1) {
throw new Error('FusedConv2d and DepthwiseConv2d with BiasAdd must have ' + 'one extra argument: bias.');
}
}
if (isBatchNorm) {
throw new Error('FusedConv2d and DepthwiseConv2d with FusedBatchNorm is not supported');
}
var stride = getParamValue('strides', node, tensorMap, context);
var pad = getPadding(node, tensorMap, context);
var dataFormat = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
var dilations = getParamValue('dilations', node, tensorMap, context);
var _getParamValue2 = getParamValue('args', node, tensorMap, context),
biasArg = _getParamValue2[0],
preluArg = _getParamValue2[1];
if (noBiasAdd) {
preluArg = biasArg;
biasArg = undefined;
}
var leakyreluAlpha = getParamValue('leakyreluAlpha', node, tensorMap, context);
return {
stride: stride,
pad: pad,
dataFormat: dataFormat,
dilations: dilations,
biasArg: biasArg,
preluArg: preluArg,
activationFunc: activationFunc,
leakyreluAlpha: leakyreluAlpha
};
}
var executeOp$3 = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'Conv1D':
{
var stride = getParamValue('stride', node, tensorMap, context);
var pad = getParamValue('pad', node, tensorMap, context);
var dataFormat = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
var dilation = getParamValue('dilation', node, tensorMap, context);
return [conv1d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), stride, pad, dataFormat, dilation)];
}
case 'Conv2D':
{
var _stride = getParamValue('strides', node, tensorMap, context);
var _pad = getPadding(node, tensorMap, context);
var _dataFormat = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
var dilations = getParamValue('dilations', node, tensorMap, context);
return [conv2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [_stride[1], _stride[2]], _pad, _dataFormat, [dilations[1], dilations[2]])];
}
case '_FusedConv2D':
{
var _fusedConvAndDepthWis = fusedConvAndDepthWiseParams(node, tensorMap, context),
_stride2 = _fusedConvAndDepthWis.stride,
_pad2 = _fusedConvAndDepthWis.pad,
_dataFormat2 = _fusedConvAndDepthWis.dataFormat,
_dilations = _fusedConvAndDepthWis.dilations,
biasArg = _fusedConvAndDepthWis.biasArg,
preluArg = _fusedConvAndDepthWis.preluArg,
activationFunc = _fusedConvAndDepthWis.activationFunc,
leakyreluAlpha = _fusedConvAndDepthWis.leakyreluAlpha;
return [conv2d$1({
x: getParamValue('x', node, tensorMap, context),
filter: getParamValue('filter', node, tensorMap, context),
strides: [_stride2[1], _stride2[2]],
pad: _pad2,
dataFormat: _dataFormat2,
dilations: [_dilations[1], _dilations[2]],
bias: biasArg,
activation: activationFunc,
preluActivationWeights: preluArg,
leakyreluAlpha: leakyreluAlpha
})];
}
case 'FusedDepthwiseConv2dNative':
{
var _fusedConvAndDepthWis2 = fusedConvAndDepthWiseParams(node, tensorMap, context),
_stride3 = _fusedConvAndDepthWis2.stride,
_pad3 = _fusedConvAndDepthWis2.pad,
_dataFormat3 = _fusedConvAndDepthWis2.dataFormat,
_dilations2 = _fusedConvAndDepthWis2.dilations,
_biasArg = _fusedConvAndDepthWis2.biasArg,
_preluArg = _fusedConvAndDepthWis2.preluArg,
_activationFunc = _fusedConvAndDepthWis2.activationFunc,
_leakyreluAlpha = _fusedConvAndDepthWis2.leakyreluAlpha;
return [depthwiseConv2d$1({
x: getParamValue('x', node, tensorMap, context),
filter: getParamValue('filter', node, tensorMap, context),
strides: [_stride3[1], _stride3[2]],
pad: _pad3,
dataFormat: _dataFormat3,
dilations: [_dilations2[1], _dilations2[2]],
bias: _biasArg,
activation: _activationFunc,
preluActivationWeights: _preluArg,
leakyreluAlpha: _leakyreluAlpha
})];
}
case 'Conv2DBackpropInput':
case 'Conv2dTranspose':
{
var shape = getParamValue('outputShape', node, tensorMap, context);
var _stride4 = getParamValue('strides', node, tensorMap, context);
var _pad4 = getPadding(node, tensorMap, context);
return [conv2dTranspose(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), shape, [_stride4[1], _stride4[2]], _pad4)];
}
case 'DepthwiseConv2dNative':
case 'DepthwiseConv2d':
{
var _stride5 = getParamValue('strides', node, tensorMap, context);
var _pad5 = getPadding(node, tensorMap, context);
var _dilations3 = getParamValue('dilations', node, tensorMap, context);
var _dataFormat4 = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
return [depthwiseConv2d(getParamValue('input', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [_stride5[1], _stride5[2]], _pad5, _dataFormat4, [_dilations3[1], _dilations3[2]])];
}
case 'Conv3D':
{
var _stride6 = getParamValue('strides', node, tensorMap, context);
var _pad6 = getParamValue('pad', node, tensorMap, context);
var _dataFormat5 = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
var _dilations4 = getParamValue('dilations', node, tensorMap, context);
return [conv3d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [_stride6[1], _stride6[2], _stride6[3]], _pad6, _dataFormat5, [_dilations4[1], _dilations4[2], _dilations4[3]])];
}
case 'AvgPool':
{
var _stride7 = getParamValue('strides', node, tensorMap, context);
var _pad7 = getParamValue('pad', node, tensorMap, context);
var kernelSize = getParamValue('kernelSize', node, tensorMap, context);
return [avgPool(getParamValue('x', node, tensorMap, context), [kernelSize[1], kernelSize[2]], [_stride7[1], _stride7[2]], _pad7)];
}
case 'MaxPool':
{
var _stride8 = getParamValue('strides', node, tensorMap, context);
var _pad8 = getParamValue('pad', node, tensorMap, context);
var _kernelSize = getParamValue('kernelSize', node, tensorMap, context);
return [maxPool(getParamValue('x', node, tensorMap, context), [_kernelSize[1], _kernelSize[2]], [_stride8[1], _stride8[2]], _pad8)];
}
case 'MaxPoolWithArgmax':
{
var _stride9 = getParamValue('strides', node, tensorMap, context);
var _pad9 = getParamValue('pad', node, tensorMap, context);
var _kernelSize2 = getParamValue('kernelSize', node, tensorMap, context);
var includeBatchInIndex = getParamValue('includeBatchInIndex', node, tensorMap, context);
var _tfOps$maxPoolWithArg = maxPoolWithArgmax(getParamValue('x', node, tensorMap, context), [_kernelSize2[1], _kernelSize2[2]], [_stride9[1], _stride9[2]], _pad9, includeBatchInIndex),
result = _tfOps$maxPoolWithArg.result,
indexes = _tfOps$maxPoolWithArg.indexes;
return [result, indexes];
}
case 'AvgPool3D':
{
var _stride10 = getParamValue('strides', node, tensorMap, context);
var _pad10 = getParamValue('pad', node, tensorMap, context);
var _kernelSize3 = getParamValue('kernelSize', node, tensorMap, context);
return [avgPool3d(getParamValue('x', node, tensorMap, context), [_kernelSize3[1], _kernelSize3[2], _kernelSize3[3]], [_stride10[1], _stride10[2], _stride10[3]], _pad10)];
}
case 'MaxPool3D':
{
var _stride11 = getParamValue('strides', node, tensorMap, context);
var _pad11 = getParamValue('pad', node, tensorMap, context);
var _kernelSize4 = getParamValue('kernelSize', node, tensorMap, context);
return [maxPool3d(getParamValue('x', node, tensorMap, context), [_kernelSize4[1], _kernelSize4[2], _kernelSize4[3]], [_stride11[1], _stride11[2], _stride11[3]], _pad11)];
}
case 'Dilation2D':
{
var strides = getParamValue('strides', node, tensorMap, context);
var _pad12 = getParamValue('pad', node, tensorMap, context);
var _dilations5 = getParamValue('dilations', node, tensorMap, context); // strides: [1, stride_height, stride_width, 1].
var strideHeight = strides[1];
var strideWidth = strides[2]; // dilations: [1, dilation_height, dilation_width, 1].
var dilationHeight = _dilations5[1];
var dilationWidth = _dilations5[2];
return [dilation2d(getParamValue('x', node, tensorMap, context), getParamValue('filter', node, tensorMap, context), [strideHeight, strideWidth], _pad12, [dilationHeight, dilationWidth], 'NHWC'
/* dataFormat */
)];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$3 = 'convolution';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$4 = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'Fill':
{
var shape = getParamValue('shape', node, tensorMap, context);
var dtype = getParamValue('dtype', node, tensorMap, context);
var value = getParamValue('value', node, tensorMap, context);
return [fill(shape, value, dtype)];
}
case 'LinSpace':
{
var start = getParamValue('start', node, tensorMap, context);
var stop = getParamValue('stop', node, tensorMap, context);
var num = getParamValue('num', node, tensorMap, context);
return [linspace(start, stop, num)];
}
case 'Multinomial':
{
var logits = getParamValue('logits', node, tensorMap, context);
var numSamples = getParamValue('numSamples', node, tensorMap, context);
var seed = getParamValue('seed', node, tensorMap, context);
return [multinomial(logits, numSamples, seed)];
}
case 'OneHot':
{
var indices = getParamValue('indices', node, tensorMap, context);
var depth = getParamValue('depth', node, tensorMap, context);
var onValue = getParamValue('onValue', node, tensorMap, context);
var offValue = getParamValue('offValue', node, tensorMap, context);
return [oneHot(indices, depth, onValue, offValue)];
}
case 'Ones':
{
return [ones$1(getParamValue('shape', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
}
case 'OnesLike':
{
return [onesLike(getParamValue('x', node, tensorMap, context))];
}
case 'RandomUniform':
{
return [randomUniform( // tslint:disable-next-line:no-any
getParamValue('shape', node, tensorMap, context), getParamValue('minval', node, tensorMap, context), getParamValue('maxval', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
}
case 'Range':
{
var _start = getParamValue('start', node, tensorMap, context);
var _stop = getParamValue('stop', node, tensorMap, context);
var step = getParamValue('step', node, tensorMap, context);
return [range(_start, _stop, step, getParamValue('dtype', node, tensorMap, context))];
}
case 'TruncatedNormal':
{
var _shape = getParamValue('shape', node, tensorMap, context);
var mean = getParamValue('mean', node, tensorMap, context);
var stdDev = getParamValue('stdDev', node, tensorMap, context);
var _seed = getParamValue('seed', node, tensorMap, context);
return [truncatedNormal(_shape, mean, stdDev, getParamValue('dtype', node, tensorMap, context), _seed)];
}
case 'Zeros':
{
return [zeros(getParamValue('shape', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
}
case 'ZerosLike':
{
return [zerosLike(getParamValue('x', node, tensorMap, context))];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$4 = 'creation';
function nmsParams(node, tensorMap, context) {
var boxes = getParamValue('boxes', node, tensorMap, context);
var scores = getParamValue('scores', node, tensorMap, context);
var maxOutputSize = getParamValue('maxOutputSize', node, tensorMap, context);
var iouThreshold = getParamValue('iouThreshold', node, tensorMap, context);
var scoreThreshold = getParamValue('scoreThreshold', node, tensorMap, context);
var softNmsSigma = getParamValue('softNmsSigma', node, tensorMap, context);
return {
boxes: boxes,
scores: scores,
maxOutputSize: maxOutputSize,
iouThreshold: iouThreshold,
scoreThreshold: scoreThreshold,
softNmsSigma: softNmsSigma
};
}
var executeOp$5 = /*#__PURE__*/function () {
var _ref = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(node, tensorMap, context) {
var _nmsParams, boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, result, _nmsParams2, _boxes, _scores, _maxOutputSize, _iouThreshold, _scoreThreshold, padToMaxOutputSize, _result, _nmsParams3, _boxes2, _scores2, _maxOutputSize2, _iouThreshold2, _scoreThreshold2, condition, _result2;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.t0 = node.op;
_context.next = _context.t0 === 'NonMaxSuppressionV5' ? 3 : _context.t0 === 'NonMaxSuppressionV4' ? 8 : _context.t0 === 'NonMaxSuppressionV3' ? 14 : _context.t0 === 'NonMaxSuppressionV2' ? 14 : _context.t0 === 'Where' ? 19 : _context.t0 === 'ListDiff' ? 26 : 27;
break;
case 3:
_nmsParams = nmsParams(node, tensorMap, context), boxes = _nmsParams.boxes, scores = _nmsParams.scores, maxOutputSize = _nmsParams.maxOutputSize, iouThreshold = _nmsParams.iouThreshold, scoreThreshold = _nmsParams.scoreThreshold, softNmsSigma = _nmsParams.softNmsSigma;
_context.next = 6;
return image.nonMaxSuppressionWithScoreAsync(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma);
case 6:
result = _context.sent;
return _context.abrupt("return", [result.selectedIndices, result.selectedScores]);
case 8:
_nmsParams2 = nmsParams(node, tensorMap, context), _boxes = _nmsParams2.boxes, _scores = _nmsParams2.scores, _maxOutputSize = _nmsParams2.maxOutputSize, _iouThreshold = _nmsParams2.iouThreshold, _scoreThreshold = _nmsParams2.scoreThreshold;
padToMaxOutputSize = getParamValue('padToMaxOutputSize', node, tensorMap, context);
_context.next = 12;
return image.nonMaxSuppressionPaddedAsync(_boxes, _scores, _maxOutputSize, _iouThreshold, _scoreThreshold, padToMaxOutputSize);
case 12:
_result = _context.sent;
return _context.abrupt("return", [_result.selectedIndices, _result.validOutputs]);
case 14:
_nmsParams3 = nmsParams(node, tensorMap, context), _boxes2 = _nmsParams3.boxes, _scores2 = _nmsParams3.scores, _maxOutputSize2 = _nmsParams3.maxOutputSize, _iouThreshold2 = _nmsParams3.iouThreshold, _scoreThreshold2 = _nmsParams3.scoreThreshold;
_context.next = 17;
return image.nonMaxSuppressionAsync(_boxes2, _scores2, _maxOutputSize2, _iouThreshold2, _scoreThreshold2);
case 17:
_context.t1 = _context.sent;
return _context.abrupt("return", [_context.t1]);
case 19:
condition = cast(getParamValue('condition', node, tensorMap, context), 'bool');
_context.next = 22;
return whereAsync(condition);
case 22:
_context.t2 = _context.sent;
_result2 = [_context.t2];
condition.dispose();
return _context.abrupt("return", _result2);
case 26:
return _context.abrupt("return", setdiff1dAsync(getParamValue('x', node, tensorMap, context), getParamValue('y', node, tensorMap, context)));
case 27:
throw TypeError("Node type " + node.op + " is not implemented");
case 28:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return function executeOp(_x, _x2, _x3) {
return _ref.apply(this, arguments);
};
}();
var CATEGORY$5 = 'dynamic';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$6 = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'TopKV2':
{
var x = getParamValue('x', node, tensorMap, context);
var k = getParamValue('k', node, tensorMap, context);
var sorted = getParamValue('sorted', node, tensorMap, context);
var result = topk(x, k, sorted);
return [result.values, result.indices];
}
case 'Unique':
{
var _x = getParamValue('x', node, tensorMap, context);
var _result = unique(_x);
return [_result.values, _result.indices];
}
case 'UniqueV2':
{
var _x2 = getParamValue('x', node, tensorMap, context);
var axis = getParamValue('axis', node, tensorMap, context);
var _result2 = unique(_x2, axis);
return [_result2.values, _result2.indices];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$6 = 'evaluation';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$7 = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'Const':
{
return tensorMap[node.name];
}
case 'PlaceholderWithDefault':
var def = getParamValue('default', node, tensorMap, context);
return [getTensor(node.name, tensorMap, context) || def];
case 'Placeholder':
return [getTensor(node.name, tensorMap, context)];
case 'Identity':
case 'StopGradient':
case 'FakeQuantWithMinMaxVars':
{
// This op is currently ignored.
var _data = getParamValue('x', node, tensorMap, context);
return [cloneTensor(_data)];
}
case 'IdentityN':
return getParamValue('x', node, tensorMap, context).map(function (t) {
return cloneTensor(t);
});
case 'Snapshot':
var snapshot = getParamValue('x', node, tensorMap, context);
return [cloneTensor(snapshot)];
case 'Shape':
return [tensor1d(getParamValue('x', node, tensorMap, context).shape, 'int32')];
case 'ShapeN':
return getParamValue('x', node, tensorMap, context).map(function (t) {
return tensor1d(t.shape);
});
case 'Size':
return [scalar(getParamValue('x', node, tensorMap, context).size, 'int32')];
case 'Rank':
return [scalar(getParamValue('x', node, tensorMap, context).rank, 'int32')];
case 'NoOp':
return [scalar(1)];
case 'Print':
var input = getParamValue('x', node, tensorMap, context);
var data = getParamValue('data', node, tensorMap, context);
var message = getParamValue('message', node, tensorMap, context);
var summarize = getParamValue('summarize', node, tensorMap, context);
console.warn('The graph has a tf.print() operation, ' + 'usually used for debugging, which slows down performance.');
console.log(message);
for (var i = 0; i < data.length; i++) {
console.log(Array.prototype.slice.call(data[i].dataSync()).slice(0, summarize));
}
return [input];
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$7 = 'graph';
/**
* Hashtable contains a set of tensors, which can be accessed by key.
*/
var HashTable = /*#__PURE__*/function () {
/**
* Constructor of HashTable. Creates a hash table.
*
* @param keyDType `dtype` of the table keys.
* @param valueDType `dtype` of the table values.
*/
function HashTable(keyDType, valueDType) {
this.keyDType = keyDType;
this.valueDType = valueDType;
this.handle = scalar(0); // tslint:disable-next-line: no-any
this.tensorMap = new Map();
keep(this.handle);
}
var _proto = HashTable.prototype;
/**
* Dispose the tensors and handle and clear the hashtable.
*/
_proto.clearAndClose = function clearAndClose() {
this.tensorMap.forEach(function (value) {
return value.dispose();
});
this.tensorMap.clear();
this.handle.dispose();
}
/**
* The number of items in the hash table.
*/
;
_proto.size = function size() {
return this.tensorMap.size;
}
/**
* The number of items in the hash table as a rank-0 tensor.
*/
;
_proto.tensorSize = function tensorSize() {
return scalar(this.size(), 'int32');
}
/**
* Replaces the contents of the table with the specified keys and values.
* @param keys Keys to store in the hashtable.
* @param values Values to store in the hashtable.
*/
;
_proto.import =
/*#__PURE__*/
function () {
var _import2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(keys, values) {
var _this = this;
var $keys;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
this.checkKeyAndValueTensor(keys, values); // We only store the primitive values of the keys; this allows lookup
// to be O(1).
_context.next = 3;
return keys.data();
case 3:
$keys = _context.sent;
// Clear the hashTable before inserting new values.
this.tensorMap.forEach(function (value) {
return value.dispose();
});
this.tensorMap.clear();
return _context.abrupt("return", tidy(function () {
var $values = unstack(values);
var keysLength = $keys.length;
var valuesLength = $values.length;
assert(keysLength === valuesLength, function () {
return "The number of elements doesn't match, keys has " + (keysLength + " elements, the values has " + valuesLength + " ") + "elements.";
});
for (var i = 0; i < keysLength; i++) {
var key = $keys[i];
var value = $values[i];
keep(value);
_this.tensorMap.set(key, value);
}
return _this.handle;
}));
case 7:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function _import(_x, _x2) {
return _import2.apply(this, arguments);
}
return _import;
}()
/**
* Looks up keys in a hash table, outputs the corresponding values.
*
* Performs batch lookups, for every element in the key tensor, `find`
* stacks the corresponding value into the return tensor.
*
* If an element is not present in the table, the given `defaultValue` is
* used.
*
* @param keys Keys to look up. Must have the same type as the keys of the
* table.
* @param defaultValue The scalar `defaultValue` is the value output for keys
* not present in the table. It must also be of the same type as the
* table values.
*/
;
_proto.find =
/*#__PURE__*/
function () {
var _find = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(keys, defaultValue) {
var _this2 = this;
var $keys;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
this.checkKeyAndValueTensor(keys, defaultValue);
_context2.next = 3;
return keys.data();
case 3:
$keys = _context2.sent;
return _context2.abrupt("return", tidy(function () {
var result = [];
for (var i = 0; i < $keys.length; i++) {
var key = $keys[i];
var value = _this2.findWithDefault(key, defaultValue);
result.push(value);
}
return stack(result);
}));
case 5:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function find(_x3, _x4) {
return _find.apply(this, arguments);
}
return find;
}() // tslint:disable-next-line: no-any
;
_proto.findWithDefault = function findWithDefault(key, defaultValue) {
var result = this.tensorMap.get(key);
return result != null ? result : defaultValue;
};
_proto.checkKeyAndValueTensor = function checkKeyAndValueTensor(key, value) {
if (key.dtype !== this.keyDType) {
throw new Error("Expect key dtype " + this.keyDType + ", but got " + ("" + key.dtype));
}
if (value.dtype !== this.valueDType) {
throw new Error("Expect value dtype " + this.valueDType + ", but got " + ("" + value.dtype));
}
};
_createClass(HashTable, [{
key: "id",
get: function get() {
return this.handle.id;
}
}]);
return HashTable;
}();
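/*
 * Usage sketch for the HashTable class above (illustrative only, not part of
 * the library): the key/value dtypes and literals are assumptions for the
 * example, and `tensor1d`/`scalar` refer to the module-local factories.
 *
 * ```js
 * const table = new HashTable('int32', 'float32');
 * await table.import(tensor1d([1, 2], 'int32'), tensor1d([10, 20]));
 * const found = await table.find(tensor1d([2, 3], 'int32'), scalar(-1));
 * // found holds [20, -1]: key 3 is missing, so defaultValue is used.
 * table.clearAndClose();
 * ```
 */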
var executeOp$8 = /*#__PURE__*/function () {
var _ref = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(node, tensorMap, context, resourceManager) {
var keyDType, valueDType, hashTable, handle, keys, values, _hashTable, _handle, _keys, defaultValue, _hashTable2, _handle2, _hashTable3;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.t0 = node.op;
_context.next = _context.t0 === 'HashTable' ? 3 : _context.t0 === 'HashTableV2' ? 3 : _context.t0 === 'LookupTableImport' ? 8 : _context.t0 === 'LookupTableImportV2' ? 8 : _context.t0 === 'LookupTableFind' ? 16 : _context.t0 === 'LookupTableFindV2' ? 16 : _context.t0 === 'LookupTableSize' ? 24 : _context.t0 === 'LookupTableSizeV2' ? 24 : 27;
break;
case 3:
keyDType = getParamValue('keyDType', node, tensorMap, context);
valueDType = getParamValue('valueDType', node, tensorMap, context);
hashTable = new HashTable(keyDType, valueDType);
resourceManager.addHashTable(node.name, hashTable);
return _context.abrupt("return", [hashTable.handle]);
case 8:
handle = getParamValue('tableHandle', node, tensorMap, context, resourceManager);
keys = getParamValue('keys', node, tensorMap, context);
values = getParamValue('values', node, tensorMap, context);
_hashTable = resourceManager.getHashTableById(handle.id);
_context.next = 14;
return _hashTable.import(keys, values);
case 14:
_context.t1 = _context.sent;
return _context.abrupt("return", [_context.t1]);
case 16:
_handle = getParamValue('tableHandle', node, tensorMap, context, resourceManager);
_keys = getParamValue('keys', node, tensorMap, context);
defaultValue = getParamValue('defaultValue', node, tensorMap, context);
_hashTable2 = resourceManager.getHashTableById(_handle.id);
_context.next = 22;
return _hashTable2.find(_keys, defaultValue);
case 22:
_context.t2 = _context.sent;
return _context.abrupt("return", [_context.t2]);
case 24:
_handle2 = getParamValue('tableHandle', node, tensorMap, context, resourceManager);
_hashTable3 = resourceManager.getHashTableById(_handle2.id);
return _context.abrupt("return", [_hashTable3.tensorSize()]);
case 27:
throw TypeError("Node type " + node.op + " is not implemented");
case 28:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return function executeOp(_x, _x2, _x3, _x4) {
return _ref.apply(this, arguments);
};
}();
var CATEGORY$8 = 'hash_table';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$9 = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'ResizeBilinear':
{
var images = getParamValue('images', node, tensorMap, context);
var size = getParamValue('size', node, tensorMap, context);
var alignCorners = getParamValue('alignCorners', node, tensorMap, context);
var halfPixelCenters = getParamValue('halfPixelCenters', node, tensorMap, context);
return [image.resizeBilinear(images, [size[0], size[1]], alignCorners, halfPixelCenters)];
}
case 'ResizeNearestNeighbor':
{
var _images = getParamValue('images', node, tensorMap, context);
var _size = getParamValue('size', node, tensorMap, context);
var _alignCorners = getParamValue('alignCorners', node, tensorMap, context);
var _halfPixelCenters = getParamValue('halfPixelCenters', node, tensorMap, context);
return [image.resizeNearestNeighbor(_images, [_size[0], _size[1]], _alignCorners, _halfPixelCenters)];
}
case 'CropAndResize':
{
var image$1 = getParamValue('image', node, tensorMap, context);
var boxes = getParamValue('boxes', node, tensorMap, context);
var boxInd = getParamValue('boxInd', node, tensorMap, context);
var cropSize = getParamValue('cropSize', node, tensorMap, context);
var method = getParamValue('method', node, tensorMap, context);
var extrapolationValue = getParamValue('extrapolationValue', node, tensorMap, context);
return [image.cropAndResize(image$1, boxes, boxInd, cropSize, method, extrapolationValue)];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$9 = 'image';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$a = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'Equal':
{
return [equal(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'NotEqual':
{
return [notEqual(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'Greater':
{
return [greater(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'GreaterEqual':
{
return [greaterEqual(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'Less':
{
return [less(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'LessEqual':
{
return [lessEqual(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'LogicalAnd':
{
return [logicalAnd(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'LogicalNot':
{
return [logicalNot(getParamValue('a', node, tensorMap, context))];
}
case 'LogicalOr':
{
return [logicalOr(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
case 'Select':
case 'SelectV2':
{
return [where(getParamValue('condition', node, tensorMap, context), getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context))];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$a = 'logical';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$b = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'BatchMatMul':
case 'BatchMatMulV2':
case 'MatMul':
return [matMul(getParamValue('a', node, tensorMap, context), getParamValue('b', node, tensorMap, context), getParamValue('transposeA', node, tensorMap, context), getParamValue('transposeB', node, tensorMap, context))];
case 'Einsum':
return [einsum.apply(tfOps, [getParamValue('equation', node, tensorMap, context)].concat(getParamValue('tensors', node, tensorMap, context)))];
case 'Transpose':
return [transpose(getParamValue('x', node, tensorMap, context), getParamValue('perm', node, tensorMap, context))];
case '_FusedMatMul':
var _getParamValue = getParamValue('fusedOps', node, tensorMap, context),
extraOp = _getParamValue[0],
activationFunc = _getParamValue[1];
var isBiasAdd = extraOp === 'biasadd';
var isPrelu = activationFunc === 'prelu';
var numArgs = getParamValue('numArgs', node, tensorMap, context);
var leakyreluAlpha = getParamValue('leakyreluAlpha', node, tensorMap, context);
if (isBiasAdd) {
if (isPrelu && numArgs !== 2) {
throw new Error('Fused MatMul with BiasAdd and Prelu must have two ' + 'extra arguments: bias and alpha.');
}
if (!isPrelu && numArgs !== 1) {
throw new Error('Fused MatMul with BiasAdd must have one extra argument: bias.');
}
}
var _getParamValue2 = getParamValue('args', node, tensorMap, context),
biasArg = _getParamValue2[0],
preluArg = _getParamValue2[1];
return [matMul$1({
a: getParamValue('a', node, tensorMap, context),
b: getParamValue('b', node, tensorMap, context),
transposeA: getParamValue('transposeA', node, tensorMap, context),
transposeB: getParamValue('transposeB', node, tensorMap, context),
bias: biasArg,
activation: activationFunc,
preluActivationWeights: preluArg,
leakyreluAlpha: leakyreluAlpha
})];
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$b = 'matrices';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$c = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'FusedBatchNorm':
case 'FusedBatchNormV2':
{
return [batchNorm(getParamValue('x', node, tensorMap, context), getParamValue('mean', node, tensorMap, context), getParamValue('variance', node, tensorMap, context), getParamValue('offset', node, tensorMap, context), getParamValue('scale', node, tensorMap, context), getParamValue('epsilon', node, tensorMap, context))];
}
case 'FusedBatchNormV3':
{
return [batchNorm(getParamValue('x', node, tensorMap, context), getParamValue('mean', node, tensorMap, context), getParamValue('variance', node, tensorMap, context), getParamValue('offset', node, tensorMap, context), getParamValue('scale', node, tensorMap, context), getParamValue('epsilon', node, tensorMap, context))];
}
case 'LRN':
{
return [localResponseNormalization(getParamValue('x', node, tensorMap, context), getParamValue('radius', node, tensorMap, context), getParamValue('bias', node, tensorMap, context), getParamValue('alpha', node, tensorMap, context), getParamValue('beta', node, tensorMap, context))];
}
case 'Softmax':
{
return [softmax(getParamValue('x', node, tensorMap, context))];
}
case 'LogSoftmax':
{
return [logSoftmax(getParamValue('x', node, tensorMap, context))];
}
case 'SparseToDense':
{
return [sparseToDense(getParamValue('sparseIndices', node, tensorMap, context), getParamValue('outputShape', node, tensorMap, context), getParamValue('sparseValues', node, tensorMap, context), getParamValue('defaultValue', node, tensorMap, context))];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$c = 'normalization';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$d = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'Max':
{
var axis = getParamValue('axis', node, tensorMap, context);
var keepDims = getParamValue('keepDims', node, tensorMap, context);
return [max$5(getParamValue('x', node, tensorMap, context), axis, keepDims)];
}
case 'Mean':
{
var _axis = getParamValue('axis', node, tensorMap, context);
var _keepDims = getParamValue('keepDims', node, tensorMap, context);
return [mean(getParamValue('x', node, tensorMap, context), _axis, _keepDims)];
}
case 'Min':
{
var _axis2 = getParamValue('axis', node, tensorMap, context);
var _keepDims2 = getParamValue('keepDims', node, tensorMap, context);
return [min$9(getParamValue('x', node, tensorMap, context), _axis2, _keepDims2)];
}
case 'Sum':
{
var _axis3 = getParamValue('axis', node, tensorMap, context);
var _keepDims3 = getParamValue('keepDims', node, tensorMap, context);
return [sum$1(getParamValue('x', node, tensorMap, context), _axis3, _keepDims3)];
}
case 'All':
{
var _axis4 = getParamValue('axis', node, tensorMap, context);
var _keepDims4 = getParamValue('keepDims', node, tensorMap, context);
return [all(getParamValue('x', node, tensorMap, context), _axis4, _keepDims4)];
}
case 'Any':
{
var _axis5 = getParamValue('axis', node, tensorMap, context);
var _keepDims5 = getParamValue('keepDims', node, tensorMap, context);
return [any(getParamValue('x', node, tensorMap, context), _axis5, _keepDims5)];
}
case 'ArgMax':
{
var _axis6 = getParamValue('axis', node, tensorMap, context);
return [argMax(getParamValue('x', node, tensorMap, context), _axis6)];
}
case 'ArgMin':
{
var _axis7 = getParamValue('axis', node, tensorMap, context);
return [argMin(getParamValue('x', node, tensorMap, context), _axis7)];
}
case 'Prod':
{
var _axis8 = getParamValue('axis', node, tensorMap, context);
var _keepDims6 = getParamValue('keepDims', node, tensorMap, context);
return [prod(getParamValue('x', node, tensorMap, context), _axis8, _keepDims6)];
}
case 'Cumsum':
{
var _axis9 = getParamValue('axis', node, tensorMap, context);
var exclusive = getParamValue('exclusive', node, tensorMap, context);
var reverse = getParamValue('reverse', node, tensorMap, context);
return [cumsum(getParamValue('x', node, tensorMap, context), _axis9, exclusive, reverse)];
}
case 'Bincount':
var x = getParamValue('x', node, tensorMap, context);
var weights = getParamValue('weights', node, tensorMap, context);
var size = getParamValue('size', node, tensorMap, context);
return [bincount(x, weights, size)];
case 'DenseBincount':
{
var _x = getParamValue('x', node, tensorMap, context);
var _weights = getParamValue('weights', node, tensorMap, context);
var _size = getParamValue('size', node, tensorMap, context);
var binaryOutput = getParamValue('binaryOutput', node, tensorMap, context);
return [denseBincount(_x, _weights, _size, binaryOutput)];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$d = 'reduction';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$e = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'ConcatV2':
case 'Concat':
{
var n = getParamValue('n', node, tensorMap, context);
var axis = getParamValue('axis', node, tensorMap, context);
var inputs = getParamValue('tensors', node, tensorMap, context);
inputs = inputs.slice(0, n);
return [concat(inputs, axis)];
}
case 'Gather':
{
var input = getParamValue('x', node, tensorMap, context);
var indices = getParamValue('indices', node, tensorMap, context);
return [gather(input, cast(indices, 'int32'), 0)];
}
case 'GatherV2':
{
var _axis = getParamValue('axis', node, tensorMap, context);
var batchDims = getParamValue('batchDims', node, tensorMap, context);
var _input = getParamValue('x', node, tensorMap, context);
var _indices = getParamValue('indices', node, tensorMap, context);
return [gather(_input, cast(_indices, 'int32'), _axis, batchDims)];
}
case 'Reverse':
{
var dims = getParamValue('dims', node, tensorMap, context);
var _axis2 = [];
for (var i = 0; i < dims.length; i++) {
if (dims[i]) {
_axis2.push(i);
}
}
var _input2 = getParamValue('x', node, tensorMap, context);
return [reverse(_input2, _axis2)];
}
case 'ReverseV2':
{
var _axis3 = getParamValue('axis', node, tensorMap, context);
var _input3 = getParamValue('x', node, tensorMap, context);
return [reverse(_input3, _axis3)];
}
case 'Slice':
{
// tslint:disable-next-line:no-any
var begin = getParamValue('begin', node, tensorMap, context); // tslint:disable-next-line:no-any
var size = getParamValue('size', node, tensorMap, context);
return [slice$2(getParamValue('x', node, tensorMap, context), begin, size)];
}
case 'StridedSlice':
{
var _begin = getParamValue('begin', node, tensorMap, context);
var end = getParamValue('end', node, tensorMap, context);
var strides = getParamValue('strides', node, tensorMap, context);
var beginMask = getParamValue('beginMask', node, tensorMap, context);
var endMask = getParamValue('endMask', node, tensorMap, context);
var ellipsisMask = getParamValue('ellipsisMask', node, tensorMap, context);
var newAxisMask = getParamValue('newAxisMask', node, tensorMap, context);
var shrinkAxisMask = getParamValue('shrinkAxisMask', node, tensorMap, context);
var tensor = getParamValue('x', node, tensorMap, context);
return [stridedSlice(tensor, _begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)];
}
case 'Pack':
{
return tidy(function () {
var axis = getParamValue('axis', node, tensorMap, context);
var tensors = getParamValue('tensors', node, tensorMap, context); // Reshape the tensors to the first tensor's shape if they don't
// match.
var shape = tensors[0].shape;
var squeezedShape = squeeze(tensors[0]).shape;
var mapped = tensors.map(function (tensor) {
var sameShape = arraysEqual(tensor.shape, shape);
if (!sameShape && !arraysEqual(squeeze(tensor).shape, squeezedShape)) {
throw new Error('the input tensors shape does not match');
}
return sameShape ? tensor : reshape(tensor, shape);
});
return [stack(mapped, axis)];
});
}
case 'Unpack':
{
var _axis4 = getParamValue('axis', node, tensorMap, context);
var _tensor = getParamValue('tensor', node, tensorMap, context);
return unstack(_tensor, _axis4);
}
case 'Tile':
{
var reps = getParamValue('reps', node, tensorMap, context);
return [tile(getParamValue('x', node, tensorMap, context), reps)];
}
case 'Split':
case 'SplitV':
{
var _axis5 = getParamValue('axis', node, tensorMap, context);
var numOrSizeSplits = getParamValue('numOrSizeSplits', node, tensorMap, context);
var _tensor2 = getParamValue('x', node, tensorMap, context);
return split$1(_tensor2, numOrSizeSplits, _axis5);
}
case 'ScatterNd':
{
var _indices2 = getParamValue('indices', node, tensorMap, context);
var values = getParamValue('values', node, tensorMap, context);
var shape = getParamValue('shape', node, tensorMap, context);
return [scatterND(_indices2, values, shape)];
}
case 'GatherNd':
{
var x = getParamValue('x', node, tensorMap, context);
var _indices3 = getParamValue('indices', node, tensorMap, context);
return [gatherND(x, _indices3)];
}
case 'SparseToDense':
{
var _indices4 = getParamValue('sparseIndices', node, tensorMap, context);
var _shape = getParamValue('outputShape', node, tensorMap, context);
var sparseValues = getParamValue('sparseValues', node, tensorMap, context);
var defaultValue = getParamValue('defaultValue', node, tensorMap, context);
return [sparseToDense(_indices4, sparseValues, _shape, sparseValues.dtype === defaultValue.dtype ? defaultValue : cast(defaultValue, sparseValues.dtype))];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$e = 'slice_join';
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$f = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'SparseFillEmptyRows':
{
var _tfOps$sparse$sparseF = sparse.sparseFillEmptyRows(getParamValue('indices', node, tensorMap, context), getParamValue('values', node, tensorMap, context), getParamValue('denseShape', node, tensorMap, context), getParamValue('defaultValue', node, tensorMap, context)),
outputIndices = _tfOps$sparse$sparseF.outputIndices,
outputValues = _tfOps$sparse$sparseF.outputValues,
emptyRowIndicator = _tfOps$sparse$sparseF.emptyRowIndicator,
reverseIndexMap = _tfOps$sparse$sparseF.reverseIndexMap;
return [outputIndices, outputValues, emptyRowIndicator, reverseIndexMap];
}
case 'SparseReshape':
{
var _tfOps$sparse$sparseR = sparse.sparseReshape(getParamValue('inputIndices', node, tensorMap, context), getParamValue('inputShape', node, tensorMap, context), getParamValue('newShape', node, tensorMap, context)),
_outputIndices = _tfOps$sparse$sparseR.outputIndices,
outputShape = _tfOps$sparse$sparseR.outputShape;
return [_outputIndices, outputShape];
}
case 'SparseSegmentMean':
{
var outputData = sparse.sparseSegmentMean(getParamValue('data', node, tensorMap, context), getParamValue('indices', node, tensorMap, context), getParamValue('segmentIds', node, tensorMap, context));
return [outputData];
}
case 'SparseSegmentSum':
{
var _outputData = sparse.sparseSegmentSum(getParamValue('data', node, tensorMap, context), getParamValue('indices', node, tensorMap, context), getParamValue('segmentIds', node, tensorMap, context));
return [_outputData];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$f = 'sparse';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$g = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'FFT':
{
return [fft(getParamValue('x', node, tensorMap, context))];
}
case 'IFFT':
{
return [ifft(getParamValue('x', node, tensorMap, context))];
}
case 'RFFT':
{
return [rfft(getParamValue('x', node, tensorMap, context))];
}
case 'IRFFT':
{
return [irfft(getParamValue('x', node, tensorMap, context))];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$g = 'spectral';
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$h = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'StringNGrams':
{
var _tfOps$string$stringN = string.stringNGrams(getParamValue('data', node, tensorMap, context), getParamValue('dataSplits', node, tensorMap, context), getParamValue('separator', node, tensorMap, context), getParamValue('nGramWidths', node, tensorMap, context), getParamValue('leftPad', node, tensorMap, context), getParamValue('rightPad', node, tensorMap, context), getParamValue('padWidth', node, tensorMap, context), getParamValue('preserveShortSequences', node, tensorMap, context)),
nGrams = _tfOps$string$stringN.nGrams,
nGramsSplits = _tfOps$string$stringN.nGramsSplits;
return [nGrams, nGramsSplits];
}
case 'StringSplit':
{
var _tfOps$string$stringS = string.stringSplit(getParamValue('input', node, tensorMap, context), getParamValue('delimiter', node, tensorMap, context), getParamValue('skipEmpty', node, tensorMap, context)),
indices = _tfOps$string$stringS.indices,
values = _tfOps$string$stringS.values,
shape = _tfOps$string$stringS.shape;
return [indices, values, shape];
}
case 'StringToHashBucketFast':
{
var output = string.stringToHashBucketFast(getParamValue('input', node, tensorMap, context), getParamValue('numBuckets', node, tensorMap, context));
return [output];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$h = 'string';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var executeOp$i = function executeOp(node, tensorMap, context) {
switch (node.op) {
case 'Cast':
{
return [cast(getParamValue('x', node, tensorMap, context), getParamValue('dtype', node, tensorMap, context))];
}
case 'ExpandDims':
{
var axis = getParamValue('axis', node, tensorMap, context);
return [expandDims(getParamValue('x', node, tensorMap, context), axis)];
}
case 'Squeeze':
{
var _axis = getParamValue('axis', node, tensorMap, context);
return [squeeze(getParamValue('x', node, tensorMap, context), _axis)];
}
case 'Reshape':
{
return [reshape(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))];
}
case 'MirrorPad':
{
return [mirrorPad(getParamValue('x', node, tensorMap, context), getParamValue('padding', node, tensorMap, context), getParamValue('mode', node, tensorMap, context))];
}
case 'PadV2':
case 'Pad':
{
return [pad(getParamValue('x', node, tensorMap, context), getParamValue('padding', node, tensorMap, context), getParamValue('constantValue', node, tensorMap, context))];
}
case 'SpaceToBatchND':
{
var blockShape = getParamValue('blockShape', node, tensorMap, context);
var paddings = getParamValue('paddings', node, tensorMap, context);
return [spaceToBatchND(getParamValue('x', node, tensorMap, context), blockShape, paddings)];
}
case 'BatchToSpaceND':
{
var _blockShape = getParamValue('blockShape', node, tensorMap, context);
var crops = getParamValue('crops', node, tensorMap, context);
return [batchToSpaceND(getParamValue('x', node, tensorMap, context), _blockShape, crops)];
}
case 'DepthToSpace':
{
var blockSize = getParamValue('blockSize', node, tensorMap, context);
var dataFormat = getParamValue('dataFormat', node, tensorMap, context).toUpperCase();
return [depthToSpace(getParamValue('x', node, tensorMap, context), blockSize, dataFormat)];
}
case 'BroadcastTo':
{
return [broadcastTo(getParamValue('x', node, tensorMap, context), getParamValue('shape', node, tensorMap, context))];
}
case 'BroadcastArgs':
{
return [broadcastArgs(getParamValue('s0', node, tensorMap, context), getParamValue('s1', node, tensorMap, context))];
}
default:
throw TypeError("Node type " + node.op + " is not implemented");
}
};
var CATEGORY$i = 'transformation';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Executes the op defined by the node object.
* @param node
* @param tensorMap contains tensors for executed nodes and weights
* @param context contains tensors and information for running the current node.
* @param resourceManager Optional. Contains global resources of the model.
*/
function executeOp$j(node, tensorMap, context, resourceManager) {
var value = function (node, tensorMap, context) {
switch (node.category) {
case 'arithmetic':
return tidy(function () {
return executeOp(node, tensorMap, context);
});
case 'basic_math':
return tidy(function () {
return executeOp$1(node, tensorMap, context);
});
case 'control':
return executeOp$2(node, tensorMap, context);
case 'convolution':
return tidy(function () {
return executeOp$3(node, tensorMap, context);
});
case 'creation':
return tidy(function () {
return executeOp$4(node, tensorMap, context);
});
case 'dynamic':
return executeOp$5(node, tensorMap, context);
case 'evaluation':
return tidy(function () {
return executeOp$6(node, tensorMap, context);
});
case 'image':
return tidy(function () {
return executeOp$9(node, tensorMap, context);
});
case 'graph':
return tidy(function () {
return executeOp$7(node, tensorMap, context);
});
case 'logical':
return tidy(function () {
return executeOp$a(node, tensorMap, context);
});
case 'matrices':
return tidy(function () {
return executeOp$b(node, tensorMap, context);
});
case 'normalization':
return tidy(function () {
return executeOp$c(node, tensorMap, context);
});
case 'reduction':
return tidy(function () {
return executeOp$d(node, tensorMap, context);
});
case 'slice_join':
return tidy(function () {
return executeOp$e(node, tensorMap, context);
});
case 'sparse':
return tidy(function () {
return executeOp$f(node, tensorMap, context);
});
case 'spectral':
return tidy(function () {
return executeOp$g(node, tensorMap, context);
});
case 'string':
return tidy(function () {
return executeOp$h(node, tensorMap, context);
});
case 'transformation':
return tidy(function () {
return executeOp$i(node, tensorMap, context);
});
case 'hash_table':
return executeOp$8(node, tensorMap, context, resourceManager);
case 'custom':
var opMapper = getRegisteredOp(node.op);
if (opMapper && opMapper.customExecutor) {
return opMapper.customExecutor(new NodeValueImpl(node, tensorMap, context));
} else {
throw TypeError("Custom op " + node.op + " is not registered.");
}
default:
throw TypeError("Unknown op '" + node.op + "'. File an issue at " + "https://github.com/tensorflow/tfjs/issues so we can add it" + ", or register a custom execution with tf.registerOp()");
}
}(node, tensorMap, context);
if (isPromise(value)) {
return value.then(function (data) {
return [].concat(data);
});
}
return [].concat(value);
}
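  /*
   * A minimal sketch of how the 'custom' branch above gets exercised: ops that
   * are not covered by the built-in executors can be registered through the
   * public `tf.registerOp()` API. The op name 'MyCustomOp' and the doubling
   * logic below are hypothetical, purely for illustration.
   *
   *   tf.registerOp('MyCustomOp', (node) => {
   *     // node.inputs holds the resolved input tensors, node.attrs the attrs.
   *     return tf.mul(node.inputs[0], tf.scalar(2));
   *   });
   *
   * After registration, executeOp$j routes nodes of that op through
   * getRegisteredOp(node.op).customExecutor.
   */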
/**
* ExecutionContext captures the runtime environment of the node. It keeps
* track of the current frame and iteration for the control flow ops.
*
 * For example, a typical dynamic RNN model may contain loops, for which
 * TensorFlow will generate graphs with Enter/Exit nodes to control the
 * current execution frame, and NextIteration nodes to increment the iteration
 * id. For models with branch logic, TensorFlow will generate Switch/Merge ops.
*/
var ExecutionContext = /*#__PURE__*/function () {
function ExecutionContext(weightMap, tensorArrayMap, tensorListMap, functionMap) {
if (weightMap === void 0) {
weightMap = {};
}
if (tensorArrayMap === void 0) {
tensorArrayMap = {};
}
if (tensorListMap === void 0) {
tensorListMap = {};
}
if (functionMap === void 0) {
functionMap = {};
}
this.weightMap = weightMap;
this.tensorArrayMap = tensorArrayMap;
this.tensorListMap = tensorListMap;
this.functionMap = functionMap;
this.rootContext = {
id: 0,
frameName: '',
iterationId: 0
};
this.contexts = [this.rootContext];
this.lastId = 0;
this.generateCurrentContextIds();
}
var _proto = ExecutionContext.prototype;
_proto.newFrame = function newFrame(id, frameName) {
return {
id: id,
frameName: frameName,
iterationId: 0
};
}
/**
* Set the current context
* @param contexts: ExecutionContextInfo[] the current path of execution
* frames
*/
;
_proto.generateCurrentContextIds = function generateCurrentContextIds() {
var names = [];
for (var i = 0; i < this.contexts.length - 1; i++) {
var contexts = this.contexts.slice(0, this.contexts.length - i);
names.push(this.contextIdforContexts(contexts));
}
names.push('');
this._currentContextIds = names;
};
_proto.contextIdforContexts = function contextIdforContexts(contexts) {
return contexts ? contexts.map(function (context) {
return context.id === 0 && context.iterationId === 0 ? '' : context.frameName + "-" + context.iterationId;
}).join('/') : '';
}
/**
   * Enter a new frame; a new context is pushed onto the current context list.
* @param frameId new frame id
*/
;
_proto.enterFrame = function enterFrame(frameId) {
if (this.contexts) {
this.lastId++;
this.contexts = this.contexts.slice();
this.contexts.push(this.newFrame(this.lastId, frameId));
this._currentContextIds.unshift(this.contextIdforContexts(this.contexts));
}
}
/**
   * Exit the current frame; the last context is removed from the current
   * context list.
*/
;
_proto.exitFrame = function exitFrame() {
if (this.contexts && this.contexts.length > 1) {
this.contexts = this.contexts.slice();
this.contexts.splice(-1);
this.currentContextIds.shift();
} else {
throw new Error('Cannot exit frame, the context is empty');
}
}
/**
   * Enter the next iteration of a loop; the iteration id of the last context
   * is increased.
*/
;
_proto.nextIteration = function nextIteration() {
if (this.contexts && this.contexts.length > 0) {
this.contexts = this.contexts.slice();
this.lastId++;
var context = Object.assign({}, this.contexts[this.contexts.length - 1]);
context.iterationId += 1;
context.id = this.lastId;
this.contexts.splice(-1, 1, context);
this._currentContextIds.splice(0, 1, this.contextIdforContexts(this.contexts));
} else {
throw new Error('Cannot increase frame iteration, the context is empty');
}
};
_proto.getWeight = function getWeight(name) {
return this.weightMap[name];
};
_proto.addTensorArray = function addTensorArray(tensorArray) {
this.tensorArrayMap[tensorArray.id] = tensorArray;
};
_proto.getTensorArray = function getTensorArray(id) {
return this.tensorArrayMap[id];
};
_proto.addTensorList = function addTensorList(tensorList) {
this.tensorListMap[tensorList.id] = tensorList;
};
_proto.getTensorList = function getTensorList(id) {
return this.tensorListMap[id];
};
_proto.dispose = function dispose(keepIds) {
for (var key in this.tensorArrayMap) {
this.tensorArrayMap[key].clearAndClose(keepIds);
}
for (var _key in this.tensorListMap) {
this.tensorListMap[_key].clearAndClose(keepIds);
}
};
_createClass(ExecutionContext, [{
key: "currentContext",
get: function get() {
return this.contexts;
}
/**
* Returns the current context in string format.
*/
,
set: function set(contexts) {
if (this.contexts !== contexts) {
this.contexts = contexts;
this.generateCurrentContextIds();
}
}
}, {
key: "currentContextId",
get: function get() {
return this._currentContextIds[0];
}
/**
* Returns the current context and all parent contexts in string format.
     * This allows access to the nodes in the current and parent frames.
*/
}, {
key: "currentContextIds",
get: function get() {
return this._currentContextIds;
}
}]);
return ExecutionContext;
}();
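  /*
   * An illustrative trace of the frame bookkeeping above, assuming a single
   * loop frame named 'while' (names and values are for orientation only):
   *
   *   const ctx = new ExecutionContext({}, {}, {}, {});
   *   ctx.currentContextId;     // ''  (root frame, iteration 0)
   *   ctx.enterFrame('while');  // push a new frame
   *   ctx.currentContextId;     // '/while-0'
   *   ctx.nextIteration();
   *   ctx.currentContextId;     // '/while-1'
   *   ctx.exitFrame();
   *   ctx.currentContextId;     // ''  (back at the root frame)
   */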
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Given graph inputs and desired outputs, find the minimal set of nodes
 * to execute in order to compute the outputs. In addition, return other useful
 * info such as:
* - Missing inputs needed to compute the output.
* - Whether the subgraph contains dynamic ops (control flow, dynamic shape).
* - Alternative inputs in order to avoid async (dynamic op) execution.
*/
function getExecutionSubgraph(inputs, outputs, weightMap, initNodes) {
var usedNodes = new Set();
var missingInputs = [];
var dynamicNode = null;
  var syncInputs = null; // Start with the outputs, going backwards to find all the nodes that are
// needed to compute those outputs.
var seen = new Set();
var inputNodeNames = Object.keys(inputs).map(function (name) {
return parseNodeName(name)[0];
});
var initNodeNames = [];
if (initNodes != null) {
initNodeNames = initNodes.map(function (node) {
return parseNodeName(node.name)[0];
});
}
var frontier = [].concat(outputs);
while (frontier.length > 0) {
var node = frontier.pop();
if (isControlFlow(node) || isDynamicShape(node) || isHashTable(node)) {
if (dynamicNode == null) {
dynamicNode = node;
syncInputs = dynamicNode.children.map(function (child) {
return child.name;
}).filter(function (name) {
return usedNodes.has(name);
});
}
}
    usedNodes.add(node.name); // Weights are dead ends since we already have their values.
if (weightMap[node.name] != null) {
continue;
} // This node is a dead end since it's one of the user-provided inputs.
if (inputNodeNames.indexOf(node.name) !== -1) {
continue;
} // This node is a dead end since it doesn't have any inputs.
if (initNodeNames.indexOf(node.name) !== -1) {
continue;
}
if (node.inputs.length === 0) {
missingInputs.push(node.name);
continue;
}
node.inputs.forEach(function (input) {
// Don't add to the frontier if it is already there.
if (seen.has(input.name)) {
return;
}
seen.add(input.name);
frontier.push(input);
});
}
return {
inputs: inputs,
outputs: outputs,
usedNodes: usedNodes,
missingInputs: missingInputs,
dynamicNode: dynamicNode,
syncInputs: syncInputs
};
}
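  /*
   * A minimal sketch of how the result is typically consumed (this mirrors what
   * GraphExecutor.compile does further below with its own weightMap/initNodes):
   *
   *   const {missingInputs, dynamicNode, syncInputs} =
   *       getExecutionSubgraph(inputs, outputNodes, weightMap, initNodes);
   *   if (dynamicNode != null) {
   *     // control flow / dynamic-shape ops are present; the graph cannot be
   *     // executed synchronously, so fall back to the async path.
   *   }
   *   if (missingInputs.length > 0) {
   *     // placeholders reachable from the outputs were not fed via `inputs`.
   *   }
   */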
/**
* Given the execution info, return a list of nodes in topological order that
* need to be executed to compute the output.
*/
function getNodesInTopologicalOrder(graph, weightMap, executionInfo) {
var usedNodes = executionInfo.usedNodes,
inputs = executionInfo.inputs;
var frontier = [];
var inputNodes = Object.keys(inputs).map(function (name) {
return parseNodeName(name)[0];
}).map(function (name) {
return graph.nodes[name];
});
var initNodes = graph.initNodes;
inputNodes.forEach(function (input) {
if (usedNodes.has(input.name)) {
frontier.push(input);
}
});
graph.weights.forEach(function (weight) {
if (usedNodes.has(weight.name)) {
frontier.push(weight);
}
});
if (initNodes != null) {
initNodes.forEach(function (node) {
if (usedNodes.has(node.name)) {
frontier.push(node);
}
});
}
var seen = new Set();
var orderedNodes = [];
while (frontier.length > 0) {
var node = frontier.pop();
seen.add(node.name);
if (!weightMap[node.name]) {
orderedNodes.push(node);
}
node.children.forEach(function (child) {
if (!seen.has(child.name) && usedNodes.has(child.name) && child.inputs.every(function (input) {
return seen.has(input.name);
})) {
frontier.push(child);
}
});
}
return orderedNodes;
}
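  /*
   * Rough example, assuming a chain  input -> Add -> Relu  plus a weight
   * w -> Add: the frontier is seeded with [input, w] and the returned order is
   * [input, Add, Relu]. `w` is skipped because its value already lives in
   * weightMap, and Relu is only pushed once its sole input (Add) has been seen.
   */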
var CONTROL_FLOW_OPS = ['Switch', 'Merge', 'Enter', 'Exit', 'NextIteration', 'StatelessIf', 'StatelessWhile', 'if', 'While'];
var DYNAMIC_SHAPE_OPS = ['NonMaxSuppressionV2', 'NonMaxSuppressionV3', 'NonMaxSuppressionV5', 'Where'];
var HASH_TABLE_OPS = ['HashTable', 'HashTableV2', 'LookupTableImport', 'LookupTableImportV2', 'LookupTableFind', 'LookupTableFindV2', 'LookupTableSize', 'LookupTableSizeV2'];
function isControlFlow(node) {
return CONTROL_FLOW_OPS.indexOf(node.op) >= 0;
}
function isDynamicShape(node) {
return DYNAMIC_SHAPE_OPS.indexOf(node.op) >= 0;
}
function isHashTable(node) {
return HASH_TABLE_OPS.indexOf(node.op) >= 0;
}
var GraphExecutor = /*#__PURE__*/function () {
/**
*
* @param graph Graph the model or function graph to be executed.
   * @param parent When building a function executor you need to set the parent
   * executor. Since the weights and function executor maps are set at the
   * parent level, that function executor can access the function maps and weight maps
* through the parent.
*/
function GraphExecutor(graph, parent) {
var _this = this;
this.graph = graph;
this.parent = parent;
this.compiledMap = new Map();
this._weightMap = {};
this.SEPERATOR = ',';
this._functions = {};
this._functionExecutorMap = {};
this._outputs = graph.outputs;
this._inputs = graph.inputs;
this._initNodes = graph.initNodes;
this._signature = graph.signature;
this._functions = graph.functions; // create sub-graph executors
if (graph.functions != null) {
Object.keys(graph.functions).forEach(function (name) {
_this._functionExecutorMap[name] = new GraphExecutor(graph.functions[name], _this);
});
}
}
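  /*
   * Sketch of the parent/child relationship, assuming the graph carries a
   * function body (the name 'while_body_1' is hypothetical):
   *
   *   const root = new GraphExecutor(graph);            // parent is undefined
   *   const child = root.functionExecutorMap['while_body_1'];
   *   // `child.parent === root`, so the child resolves weightIds and
   *   // functionExecutorMap through the root executor.
   */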
var _proto = GraphExecutor.prototype;
_proto.getCompilationKey = function getCompilationKey(inputs, outputs) {
var sortedInputs = inputs.map(function (node) {
return node.name;
}).sort();
var sortedOutputs = outputs.map(function (node) {
return node.name;
}).sort();
return sortedInputs.join(this.SEPERATOR) + '--' + sortedOutputs.join(this.SEPERATOR);
}
/**
* Compiles the inference graph and returns the minimal set of nodes that are
* required for execution, in the correct execution order.
*/
;
_proto.compile = function compile(inputs, outputs) {
var executionInfo = getExecutionSubgraph(inputs, outputs, this.weightMap, this._initNodes);
var missingInputs = executionInfo.missingInputs,
dynamicNode = executionInfo.dynamicNode,
syncInputs = executionInfo.syncInputs;
if (dynamicNode != null) {
throw new Error("This execution contains the node '" + dynamicNode.name + "', which has " + ("the dynamic op '" + dynamicNode.op + "'. Please use ") + "model.executeAsync() instead. Alternatively, to avoid the " + ("dynamic ops, specify the inputs [" + syncInputs + "]"));
}
if (missingInputs.length > 0) {
var outNames = outputs.map(function (n) {
return n.name;
});
var inNames = Object.keys(inputs);
throw new Error("Cannot compute the outputs [" + outNames + "] from the provided inputs " + ("[" + inNames + "]. Missing the following inputs: [" + missingInputs + "]"));
}
return getNodesInTopologicalOrder(this.graph, this.weightMap, executionInfo);
}
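  /*
   * The compiled node order is memoized per input/output combination; execute()
   * below does roughly this:
   *
   *   let plan = this.compiledMap.get(compilationKey);
   *   if (plan == null) {
   *     plan = this.compile(inputs, outputNodes);    // throws on dynamic ops or
   *     this.compiledMap.set(compilationKey, plan);  // missing placeholder inputs
   *   }
   */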
/**
* Executes the inference for given input tensors.
* @param inputs Tensor map for the model inputs, keyed by the input node
* names.
   * @param outputs Optional. Output node names from the TensorFlow model. If
   * no outputs are specified, the default outputs of the model are used.
* You can inspect intermediate nodes of the model by adding them to the
* outputs array.
*/
;
_proto.execute = function execute(inputs, outputs) {
var _this2 = this;
inputs = this.mapInputs(inputs);
var names = Object.keys(inputs).sort();
this.checkInputs(inputs);
this.checkInputShapeAndType(inputs);
outputs = this.mapOutputs(outputs);
this.checkOutputs(outputs);
var inputNodes = names.map(function (name) {
return _this2.graph.nodes[parseNodeName(name)[0]];
});
var outputNodeNames = outputs.map(function (name) {
return parseNodeName(name)[0];
});
var outputNodes = outputNodeNames.map(function (name) {
return _this2.graph.nodes[name];
}); // If no outputs are specified, then use the default outputs of the model.
if (outputNodes.length === 0) {
outputNodes = this._outputs;
}
    var compilationKey = this.getCompilationKey(inputNodes, outputNodes); // Reuse a previously compiled node order if this input/output combination was seen before.
var orderedNodes = this.compiledMap.get(compilationKey);
if (orderedNodes == null) {
orderedNodes = this.compile(inputs, outputNodes);
this.compiledMap.set(compilationKey, orderedNodes);
}
var tensorArrayMap = {};
var tensorListMap = {};
return tidy(function () {
var context = new ExecutionContext(_this2.weightMap, tensorArrayMap, tensorListMap, _this2.functionExecutorMap);
var tensorsMap = Object.assign({}, _this2.weightMap);
Object.keys(inputs).forEach(function (name) {
var _parseNodeName = parseNodeName(name),
nodeName = _parseNodeName[0],
index = _parseNodeName[1];
var tensors = [];
tensors[index] = inputs[name];
tensorsMap[nodeName] = tensors;
});
var tensorsToKeep = _this2.getFrozenTensorIds(tensorsMap);
var intermediateTensorConsumerCount = {};
for (var i = 0; i < orderedNodes.length; i++) {
var node = orderedNodes[i];
if (!tensorsMap[node.name]) {
var tensors = executeOp$j(node, tensorsMap, context, _this2._resourceManager);
if (isPromise(tensors)) {
throw new Error("The execution of the op '" + node.op + "' returned a promise. " + "Please use model.executeAsync() instead.");
}
tensorsMap[node.name] = tensors;
_this2.checkTensorForDisposal(node.name, node, tensorsMap, context, tensorsToKeep, outputNodeNames, intermediateTensorConsumerCount);
}
} // dispose the context for the root executor
if (_this2.parent == null) {
context.dispose(tensorsToKeep);
}
return outputs.map(function (name) {
return getTensor(name, tensorsMap, context);
});
});
};
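  /*
   * A minimal usage sketch: this method backs the public GraphModel.execute()
   * API. The model URL, the input name 'input' and the output node name are all
   * hypothetical placeholders.
   *
   *   const model = await tf.loadGraphModel('https://example.com/model.json');
   *   const logits = model.execute(
   *       {'input': tf.zeros([1, 224, 224, 3])},     // keyed by input node name
   *       'MobilenetV2/Predictions/Softmax');        // optional output node name
   */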
_proto.getFrozenTensorIds = function getFrozenTensorIds(tensorMap) {
var ids = [].concat.apply([], Object.keys(tensorMap).map(function (key) {
return tensorMap[key];
}).map(function (tensors) {
return tensors.map(function (tensor) {
return tensor.id;
});
}));
return new Set(ids);
};
_proto.checkTensorForDisposal = function checkTensorForDisposal(nodeName, node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount) {
    // Skip output nodes and any control flow nodes, since their dependencies
    // are tricky to track correctly.
if (node.category === 'control' || outputNames.indexOf(nodeName) !== -1) {
return;
}
tensorMap[nodeName].forEach(function (tensor) {
if (tensor != null) {
intermediateTensorConsumerCount[tensor.id] = (intermediateTensorConsumerCount[tensor.id] || 0) + node.children.length;
}
});
node.inputs.forEach(function (input) {
      // Skip any control flow nodes, since their dependencies are tricky to
      // track correctly.
if (input.category !== 'control') {
var tensors = getTensorsForCurrentContenxt(input.name, tensorMap, context);
if (tensors != null) {
tensors.forEach(function (tensor) {
if (tensor && !tensor.kept && !tensorsToKeep.has(tensor.id)) {
var count = intermediateTensorConsumerCount[tensor.id];
if (count === 1) {
tensor.dispose();
delete intermediateTensorConsumerCount[tensor.id];
} else if (count != null) {
                // only intermediate nodes have counts set; inputs and weights
                // do not.
intermediateTensorConsumerCount[tensor.id]--;
}
}
});
}
}
});
}
/**
* Executes the inference for given input tensors in Async fashion.
* @param inputs Tensor map for the model inputs, keyed by the input node
* names.
   * @param outputs Output node names from the TensorFlow model. If no outputs
   * are specified, the default outputs of the model are used. You can
* inspect intermediate nodes of the model by adding them to the outputs
* array.
*/
;
_proto.executeAsync =
/*#__PURE__*/
function () {
var _executeAsync2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(inputs, outputs) {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
return _context.abrupt("return", this._executeAsync(inputs, outputs));
case 1:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function executeAsync(_x, _x2) {
return _executeAsync2.apply(this, arguments);
}
return executeAsync;
}()
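  /*
   * A minimal usage sketch: graphs containing control flow or dynamic-shape ops
   * (Where, NonMaxSuppression, loops) must go through this async path, which
   * backs the public GraphModel.executeAsync() API. The URL, the input name
   * 'image_tensor' and the `imgTensor` variable are hypothetical.
   *
   *   const model = await tf.loadGraphModel('https://example.com/detector/model.json');
   *   const outputs = await model.executeAsync({'image_tensor': imgTensor});
   */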
/**
* Executes the inference for given input tensors in Async fashion.
* @param inputs Tensor map for the model inputs, keyed by the input node
* names.
   * @param outputs Optional. Output node names from the TensorFlow model.
   * If no outputs are specified, the default outputs of the model are
* used. You can inspect intermediate nodes of the model by adding them to the
* outputs array.
* @param isFunctionExecution Optional. Flag for executing a function.
   * @param tensorArrayMap Optional. Global TensorArray map by id. Used for
   * function execution.
   * @param tensorListMap Optional. Global TensorList map by id. Used for
* function execution.
*/
;
_proto._executeAsync =
/*#__PURE__*/
function () {
var _executeAsync3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(inputs, outputs, isFunctionExecution, tensorArrayMap, tensorListMap) {
var context, tensorMap, results, outputIds, inputIds, keepIds;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
if (isFunctionExecution === void 0) {
isFunctionExecution = false;
}
if (tensorArrayMap === void 0) {
tensorArrayMap = {};
}
if (tensorListMap === void 0) {
tensorListMap = {};
}
if (!isFunctionExecution) {
inputs = this.mapInputs(inputs);
this.checkInputs(inputs);
this.checkInputShapeAndType(inputs);
outputs = this.mapOutputs(outputs);
this.checkOutputs(outputs);
}
context = new ExecutionContext(this.weightMap, tensorArrayMap, tensorListMap, this.functionExecutorMap); // Graph with control flow op requires runtime evaluation of the execution
// order, while without control flow the execution order is pre-determined
// in the compile method.
_context2.next = 7;
return this.executeWithControlFlow(inputs, context, outputs, isFunctionExecution);
case 7:
tensorMap = _context2.sent;
results = outputs.map(function (name) {
return getTensor(name, tensorMap, context);
}); // dispose all the intermediate tensors
outputIds = results.map(function (t) {
return t.id;
});
inputIds = Object.keys(inputs).map(function (name) {
return inputs[name].id;
});
keepIds = new Set([].concat(outputIds, inputIds, this.weightIds));
Object.keys(tensorMap).forEach(function (key) {
var tensorArray = tensorMap[key];
tensorArray.forEach(function (tensor) {
if (tensor && !tensor.kept && !tensor.isDisposed && !keepIds.has(tensor.id)) {
tensor.dispose();
}
});
}); // dispose the context for the root executor
if (this.parent == null) {
context.dispose(keepIds);
}
return _context2.abrupt("return", results);
case 15:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function _executeAsync(_x3, _x4, _x5, _x6, _x7) {
return _executeAsync3.apply(this, arguments);
}
return _executeAsync;
}();
_proto.executeFunctionAsync = /*#__PURE__*/function () {
var _executeFunctionAsync = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(inputs, tensorArrayMap, tensorListMap) {
var _this3 = this;
var mappedInputs;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
mappedInputs = inputs.reduce(function (map, tensor, index) {
map[_this3.inputs[index].name] = tensor;
return map;
}, {});
return _context3.abrupt("return", this._executeAsync(mappedInputs, this.outputNodes, true, tensorArrayMap, tensorListMap));
case 2:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function executeFunctionAsync(_x8, _x9, _x10) {
return _executeFunctionAsync.apply(this, arguments);
}
return executeFunctionAsync;
}()
/**
* When there are control flow nodes in the graph, the graph execution use
* ExecutionContext to keep track of the frames and loop iterators.
* @param inputs placeholder tensors for the graph.
* @param context the execution context object for current execution.
   * @param outputNames Optional. Output node names from the TensorFlow model.
   * If no outputs are specified, the default outputs of the model are
* used. You can inspect intermediate nodes of the model by adding them to the
* outputs array.
* @param isFunctionExecution Flag for executing a function.
*/
;
_proto.executeWithControlFlow =
/*#__PURE__*/
function () {
var _executeWithControlFlow = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(inputs, context, outputNames, isFunctionExecution) {
var _this4 = this;
var names, inputNodes, outputNodeNames, outputNodes, _getExecutionSubgraph, usedNodes, missingInputs, dynamicNode, syncInputs, stack, tensorsMap, intermediateTensorConsumerCount, tensorsToKeep, added, promises, missingOutputs, alternativeMsg;
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
names = Object.keys(inputs);
inputNodes = names.map(function (name) {
return _this4.graph.nodes[parseNodeName(name)[0]];
});
outputNodeNames = outputNames.map(function (name) {
return parseNodeName(name)[0];
});
outputNodes = outputNodeNames.map(function (name) {
return _this4.graph.nodes[name];
}); // If no outputs are specified, then use the default outputs of the model.
if (outputNodes.length === 0) {
outputNodes = this._outputs;
}
_getExecutionSubgraph = getExecutionSubgraph(inputs, outputNodes, this.weightMap, this._initNodes), usedNodes = _getExecutionSubgraph.usedNodes, missingInputs = _getExecutionSubgraph.missingInputs, dynamicNode = _getExecutionSubgraph.dynamicNode, syncInputs = _getExecutionSubgraph.syncInputs; // First nodes to execute include inputNodes, weights, and initNodes.
stack = [].concat(inputNodes, this.graph.weights, this._initNodes || []).map(function (node) {
return {
node: node,
contexts: context.currentContext
};
});
tensorsMap = Object.assign({}, this.weightMap);
Object.keys(inputs).forEach(function (name) {
var _parseNodeName2 = parseNodeName(name),
nodeName = _parseNodeName2[0],
index = _parseNodeName2[1];
var tensors = [];
tensors[index] = inputs[name];
tensorsMap[nodeName] = tensors;
});
intermediateTensorConsumerCount = {};
tensorsToKeep = this.getFrozenTensorIds(tensorsMap);
added = {};
case 12:
if (!(stack.length > 0)) {
_context4.next = 18;
break;
}
promises = this.processStack(inputNodes, stack, context, tensorsMap, added, tensorsToKeep, outputNodeNames, intermediateTensorConsumerCount, usedNodes);
_context4.next = 16;
return Promise.all(promises);
case 16:
_context4.next = 12;
break;
case 18:
if (dynamicNode == null && !isFunctionExecution) {
console.warn("This model execution did not contain any nodes with control flow " + "or dynamic output shapes. You can use model.execute() instead.");
}
missingOutputs = outputNodes.filter(function (node) {
return !isControlFlow(node) && !getTensor(node.name, tensorsMap, context);
}).map(function (node) {
return node.name;
});
if (!(missingOutputs.length > 0)) {
_context4.next = 24;
break;
}
alternativeMsg = '';
if (dynamicNode != null) {
alternativeMsg = "Alternatively, to avoid the dynamic ops, use model.execute() " + ("and specify the inputs [" + syncInputs + "]");
}
throw new Error("Cannot compute the outputs [" + missingOutputs + "] from the provided " + ("inputs [" + names + "]. Consider providing the following inputs: ") + ("[" + missingInputs + "]. " + alternativeMsg));
case 24:
return _context4.abrupt("return", tensorsMap);
case 25:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function executeWithControlFlow(_x11, _x12, _x13, _x14) {
return _executeWithControlFlow.apply(this, arguments);
}
return executeWithControlFlow;
}();
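  /*
   * The loop above alternates between draining the node stack synchronously
   * (processStack) and awaiting the batch of promises produced by async ops, so
   * the execution order of control flow frames is resolved at runtime rather
   * than precomputed by compile().
   */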
_proto.processStack = function processStack(inputNodes, stack, context, tensorMap, added, tensorsToKeep, outputNames, intermediateTensorConsumerCount, usedNodes) {
var _this5 = this;
var promises = [];
var _loop = function _loop() {
var item = stack.pop();
context.currentContext = item.contexts;
var nodeName = ''; // The tensor of the Enter op with isConstant set should be set
// in the parent scope, so it will be available as constant for the
// whole loop.
if (item.node.op === 'Enter' && getParamValue('isConstant', item.node, tensorMap, context)) {
var _getNodeNameAndIndex = getNodeNameAndIndex(item.node.name, context);
nodeName = _getNodeNameAndIndex[0];
      } // only process nodes that are not in the tensorMap yet; this includes
// inputNodes and internal initNodes.
if (tensorMap[item.node.name] == null) {
var tensors = executeOp$j(item.node, tensorMap, context, _this5._resourceManager);
if (!nodeName) {
var _getNodeNameAndIndex2 = getNodeNameAndIndex(item.node.name, context);
nodeName = _getNodeNameAndIndex2[0];
}
var currentContext = context.currentContext;
if (isPromise(tensors)) {
promises.push(tensors.then(function (t) {
tensorMap[nodeName] = t;
context.currentContext = currentContext;
_this5.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount);
_this5.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
return t;
}));
} else {
tensorMap[nodeName] = tensors;
_this5.checkTensorForDisposal(nodeName, item.node, tensorMap, context, tensorsToKeep, outputNames, intermediateTensorConsumerCount);
_this5.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
}
} else {
_this5.processChildNodes(item.node, stack, context, tensorMap, added, usedNodes);
}
};
while (stack.length > 0) {
_loop();
}
return promises;
};
_proto.processChildNodes = function processChildNodes(node, stack, context, tensorMap, added, usedNodes) {
node.children.forEach(function (childNode) {
var _getNodeNameAndIndex3 = getNodeNameAndIndex(childNode.name, context),
nodeName = _getNodeNameAndIndex3[0];
if (added[nodeName] || !usedNodes.has(childNode.name)) {
return;
} // Merge op can be pushed if any of its inputs has value.
if (childNode.op === 'Merge') {
if (childNode.inputNames.some(function (name) {
return !!getTensor(name, tensorMap, context);
})) {
added[nodeName] = true;
stack.push({
contexts: context.currentContext,
node: childNode
});
}
      } else // Otherwise all inputs must have value.
if (childNode.inputNames.every(function (name) {
return !!getTensor(name, tensorMap, context);
})) {
added[nodeName] = true;
stack.push({
contexts: context.currentContext,
node: childNode
});
}
});
}
/**
* Releases the memory used by the weight tensors.
*/
;
_proto.dispose = function dispose() {
var _this6 = this;
Object.keys(this.weightMap).forEach(function (key) {
return _this6.weightMap[key].forEach(function (tensor) {
return tensor.dispose();
});
});
};
_proto.checkInputShapeAndType = function checkInputShapeAndType(inputs) {
var _this7 = this;
Object.keys(inputs).forEach(function (name) {
var input = inputs[name];
var _parseNodeName3 = parseNodeName(name),
nodeName = _parseNodeName3[0];
var node = _this7.graph.nodes[nodeName];
if (node.attrParams['shape'] && node.attrParams['shape'].value) {
var shape = node.attrParams['shape'].value;
var match = shape.length === input.shape.length && input.shape.every(function (dim, index) {
return shape[index] === -1 || shape[index] === dim;
});
assert(match, function () {
return "The shape of dict['" + node.name + "'] provided in " + ("model.execute(dict) must be [" + shape + "], but was ") + ("[" + input.shape + "]");
});
}
if (node.attrParams['dtype'] && node.attrParams['dtype'].value) {
assert(input.dtype === node.attrParams['dtype'].value, function () {
return "The dtype of dict['" + node.name + "'] provided in " + "model.execute(dict) must be " + (node.attrParams['dtype'].value + ", but was " + input.dtype);
});
}
});
};
_proto.mapInputs = function mapInputs(inputs) {
var result = {};
for (var inputName in inputs) {
if (this._signature != null && this._signature.inputs != null && this._signature.inputs[inputName] != null) {
var tensor = this._signature.inputs[inputName];
result[tensor.name] = inputs[inputName];
} else {
result[inputName] = inputs[inputName];
}
}
return result;
};
_proto.checkInputs = function checkInputs(inputs) {
var _this8 = this;
var notInGraph = Object.keys(inputs).filter(function (name) {
var _parseNodeName4 = parseNodeName(name),
nodeName = _parseNodeName4[0];
return _this8.graph.nodes[nodeName] == null;
});
if (notInGraph.length > 0) {
throw new Error("The dict provided in model.execute(dict) has " + ("keys: [" + notInGraph + "] that are not part of graph"));
}
};
_proto.mapOutputs = function mapOutputs(outputs) {
var _this9 = this;
return outputs.map(function (name) {
if (_this9._signature != null && _this9._signature.outputs != null && _this9._signature.outputs[name] != null) {
var tensor = _this9._signature.outputs[name];
return tensor.name;
}
return name;
}, {});
};
_proto.checkOutputs = function checkOutputs(outputs) {
var _this10 = this;
outputs.forEach(function (name) {
var _parseNodeName5 = parseNodeName(name),
normalizedName = _parseNodeName5[0];
if (!_this10.graph.nodes[normalizedName]) {
throw new Error("The output '" + name + "' is not found in the graph");
}
});
};
_createClass(GraphExecutor, [{
key: "weightIds",
get: function get() {
return this.parent ? this.parent.weightIds : this._weightIds;
}
}, {
key: "functionExecutorMap",
get: function get() {
return this.parent ? this.parent.functionExecutorMap : this._functionExecutorMap;
}
}, {
key: "weightMap",
get: function get() {
return this.parent ? this.parent.weightMap : this._weightMap;
},
set: function set(weightMap) {
var _ref;
var weightIds = Object.keys(weightMap).map(function (key) {
return weightMap[key].map(function (tensor) {
return tensor.id;
});
});
this._weightIds = (_ref = []).concat.apply(_ref, weightIds);
this._weightMap = weightMap;
}
/**
* Set `ResourceManager` shared by executors of a model.
* @param resourceManager: `ResourceManager` of the `GraphModel`.
*/
}, {
key: "resourceManager",
set: function set(resourceManager) {
this._resourceManager = resourceManager;
}
}, {
key: "inputs",
get: function get() {
return this._inputs.map(function (node) {
return {
name: node.name,
shape: node.attrParams['shape'] ? node.attrParams['shape'].value : undefined,
dtype: node.attrParams['dtype'] ? node.attrParams['dtype'].value : undefined
};
});
}
}, {
key: "outputs",
get: function get() {
return this._outputs.map(function (node) {
return {
name: node.name,
shape: node.attrParams['shape'] ? node.attrParams['shape'].value : undefined,
dtype: node.attrParams['dtype'] ? node.attrParams['dtype'].value : undefined
};
});
}
}, {
key: "inputNodes",
get: function get() {
return this._inputs.map(function (node) {
return node.signatureKey || node.name;
});
}
}, {
key: "outputNodes",
get: function get() {
return this._outputs.map(function (node) {
var name = node.signatureKey || node.name;
return node.defaultOutput ? name + ":" + node.defaultOutput : name;
});
}
}, {
key: "functions",
get: function get() {
var _this11 = this;
return Object.keys(this._functions).reduce(function (map, key) {
map[key] = _this11._functions[key].signature;
return map;
}, {});
}
}]);
return GraphExecutor;
}();
/**
* Contains global resources of a model.
*/
var ResourceManager = /*#__PURE__*/function () {
function ResourceManager(hashTableNameToHandle, hashTableMap) {
if (hashTableNameToHandle === void 0) {
hashTableNameToHandle = {};
}
if (hashTableMap === void 0) {
hashTableMap = {};
}
this.hashTableNameToHandle = hashTableNameToHandle;
this.hashTableMap = hashTableMap;
}
/**
* Register a `HashTable` in the resource manager.
*
* The `HashTable` can be retrieved by `resourceManager.getHashTableById`,
* where id is the table handle tensor's id.
*
* @param name Op node name that creates the `HashTable`.
* @param hashTable The `HashTable` to be added to resource manager.
*/
var _proto = ResourceManager.prototype;
_proto.addHashTable = function addHashTable(name, hashTable) {
this.hashTableNameToHandle[name] = hashTable.handle;
this.hashTableMap[hashTable.id] = hashTable;
}
/**
* Get the table handle by node name.
* @param name Op node name that creates the `HashTable`. This name is also
* used in the inputs list of lookup and import `HashTable` ops.
*/
;
_proto.getHashTableHandleByName = function getHashTableHandleByName(name) {
return this.hashTableNameToHandle[name];
}
/**
* Get the actual `HashTable` by its handle tensor's id.
* @param id The id of the handle tensor.
*/
;
_proto.getHashTableById = function getHashTableById(id) {
return this.hashTableMap[id];
}
/**
* Dispose `ResourceManager`, including its hashTables and tensors in them.
*/
;
_proto.dispose = function dispose() {
for (var key in this.hashTableMap) {
this.hashTableMap[key].clearAndClose();
delete this.hashTableMap[key];
}
for (var name in this.hashTableNameToHandle) {
this.hashTableNameToHandle[name].dispose();
delete this.hashTableNameToHandle[name];
}
};
return ResourceManager;
}();
var TFHUB_SEARCH_PARAM = '?tfjs-format=file';
var DEFAULT_MODEL_NAME = 'model.json';
/**
* A `tf.GraphModel` is a directed, acyclic graph built from a
* SavedModel GraphDef and allows inference execution.
*
* A `tf.GraphModel` can only be created by loading from a model converted from
* a [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) using
* the command line converter tool and loaded via `tf.loadGraphModel`.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
var GraphModel = /*#__PURE__*/function () {
/**
* @param modelUrl url for the model, or an `io.IOHandler`.
   * @param loadOptions options for loading the model, including `requestInit`
   * (Request options, which allow sending credentials and custom headers) and
   * `onProgress` (an optional progress callback fired periodically before the
   * load completes).
*/
function GraphModel(modelUrl, loadOptions) {
if (loadOptions === void 0) {
loadOptions = {};
}
this.modelUrl = modelUrl;
this.loadOptions = loadOptions;
this.version = 'n/a';
if (loadOptions == null) {
this.loadOptions = {};
}
this.resourceManager = new ResourceManager();
} // Returns the version information for the tensorflow model GraphDef.
var _proto = GraphModel.prototype;
_proto.findIOHandler = function findIOHandler() {
var path = this.modelUrl;
if (path.load != null) {
// Path is an IO Handler.
this.handler = path;
} else if (this.loadOptions.requestInit != null) {
this.handler = browserHTTPRequest(path, this.loadOptions);
} else {
var handlers = getLoadHandlers(path, this.loadOptions);
if (handlers.length === 0) {
// For backward compatibility: if no load handler can be found,
// assume it is a relative http path.
handlers.push(browserHTTPRequest(path, this.loadOptions));
} else if (handlers.length > 1) {
throw new Error("Found more than one (" + handlers.length + ") load handlers for " + ("URL '" + [path] + "'"));
}
this.handler = handlers[0];
}
}
/**
* Loads the model and weight files, construct the in memory weight map and
* compile the inference graph.
*/
;
_proto.load =
/*#__PURE__*/
function () {
var _load = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var artifacts;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
this.findIOHandler();
if (!(this.handler.load == null)) {
_context.next = 3;
break;
}
throw new Error('Cannot proceed with model loading because the IOHandler provided ' + 'does not have the `load` method implemented.');
case 3:
_context.next = 5;
return this.handler.load();
case 5:
artifacts = _context.sent;
return _context.abrupt("return", this.loadSync(artifacts));
case 7:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function load() {
return _load.apply(this, arguments);
}
return load;
}()
/**
* Synchronously construct the in memory weight map and
* compile the inference graph. Also initialize hashtable if any.
*
* @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
*/
;
_proto.loadSync = function loadSync(artifacts) {
this.artifacts = artifacts;
var graph = this.artifacts.modelTopology;
var signature;
if (this.artifacts.userDefinedMetadata != null && this.artifacts.userDefinedMetadata.signature != null) {
signature = // tslint:disable-next-line:no-any
this.artifacts.userDefinedMetadata.signature;
} else {
signature = this.artifacts.signature;
}
this.signature = signature;
this.version = graph.versions.producer + "." + graph.versions.minConsumer;
var weightMap = decodeWeights(this.artifacts.weightData, this.artifacts.weightSpecs);
this.executor = new GraphExecutor(OperationMapper.Instance.transformGraph(graph, this.signature));
this.executor.weightMap = this.convertTensorMapToTensorsMap(weightMap); // Attach a model-level resourceManager to each executor to share resources,
// such as `HashTable`.
this.executor.resourceManager = this.resourceManager;
if (artifacts.modelInitializer != null && artifacts.modelInitializer.node != null) {
var initializer = OperationMapper.Instance.transformGraph(artifacts.modelInitializer);
this.initializer = new GraphExecutor(initializer);
this.initializer.weightMap = this.executor.weightMap; // Attach a model-level resourceManager to the initializer, the
// hashTables created from when executing the initializer will be stored
// in the resourceManager.
this.initializer.resourceManager = this.resourceManager;
this.initializer.executeAsync({}, []);
}
return true;
}
/**
* Save the configuration and/or weights of the GraphModel.
*
* An `IOHandler` is an object that has a `save` method of the proper
* signature defined. The `save` method manages the storing or
* transmission of serialized data ("artifacts") that represent the
* model's topology and weights onto or via a specific medium, such as
* file downloads, local storage, IndexedDB in the web browser and HTTP
* requests to a server. TensorFlow.js provides `IOHandler`
* implementations for a number of frequently used saving mediums, such as
* `tf.io.browserDownloads` and `tf.io.browserLocalStorage`. See `tf.io`
* for more details.
*
* This method also allows you to refer to certain types of `IOHandler`s
* as URL-like string shortcuts, such as 'localstorage://' and
* 'indexeddb://'.
*
* Example 1: Save `model`'s topology and weights to browser [local
* storage](https://developer.mozilla.org/en-US/docs/Web/API/Window/localStorage);
* then load it back.
*
* ```js
* const modelUrl =
* 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
* const model = await tf.loadGraphModel(modelUrl);
* const zeros = tf.zeros([1, 224, 224, 3]);
* model.predict(zeros).print();
*
* const saveResults = await model.save('localstorage://my-model-1');
*
* const loadedModel = await tf.loadGraphModel('localstorage://my-model-1');
* console.log('Prediction from loaded model:');
   * loadedModel.predict(zeros).print();
* ```
*
* @param handlerOrURL An instance of `IOHandler` or a URL-like,
* scheme-based string shortcut for `IOHandler`.
* @param config Options for saving the model.
* @returns A `Promise` of `SaveResult`, which summarizes the result of
* the saving, such as byte sizes of the saved artifacts for the model's
* topology and weight values.
*
* @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true}
*/
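  /*
   * Illustrative sketch (not part of the original docs): the same URL-shortcut
   * scheme also covers IndexedDB in the browser. The name 'my-model-1' below
   * is a placeholder.
   *
   * ```js
   * // Persist the loaded graph model to IndexedDB, then restore it.
   * const saveResult = await model.save('indexeddb://my-model-1');
   * const restored = await tf.loadGraphModel('indexeddb://my-model-1');
   * ```
   */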
;
_proto.save =
/*#__PURE__*/
function () {
var _save = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(handlerOrURL, config) {
var handlers;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
if (!(typeof handlerOrURL === 'string')) {
_context2.next = 9;
break;
}
handlers = getSaveHandlers(handlerOrURL);
if (!(handlers.length === 0)) {
_context2.next = 6;
break;
}
throw new Error("Cannot find any save handlers for URL '" + handlerOrURL + "'");
case 6:
if (!(handlers.length > 1)) {
_context2.next = 8;
break;
}
throw new Error("Found more than one (" + handlers.length + ") save handlers for " + ("URL '" + handlerOrURL + "'"));
case 8:
handlerOrURL = handlers[0];
case 9:
if (!(handlerOrURL.save == null)) {
_context2.next = 11;
break;
}
throw new Error('GraphModel.save() cannot proceed because the IOHandler ' + 'provided does not have the `save` attribute defined.');
case 11:
return _context2.abrupt("return", handlerOrURL.save(this.artifacts));
case 12:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function save(_x, _x2) {
return _save.apply(this, arguments);
}
return save;
}()
/**
* Execute the inference for the input tensors.
*
   * @param input The input tensors. When the model has a single input, the
   * inputs param should be a `tf.Tensor`. For models with multiple inputs,
   * the inputs param should be either a `tf.Tensor`[] (if the input order is
   * fixed) or a NamedTensorMap.
   *
   * For models with multiple inputs, we recommend using a NamedTensorMap as
   * the input type. If you use `tf.Tensor`[], the order of the array needs to
   * follow the order of the inputNodes array. @see {@link GraphModel.inputNodes}
   *
   * You can also feed any intermediate nodes using the NamedTensorMap as the
   * input type. For example, given the graph
   * InputNode => Intermediate => OutputNode,
   * you can execute the subgraph Intermediate => OutputNode by calling
   * model.execute({'IntermediateNode': tf.tensor(...)});
   *
   * This is useful for models that use tf.dynamic_rnn, where the intermediate
   * state needs to be fed manually.
   *
   * For batch inference execution, the tensors for each input need to be
   * concatenated together. For example with MobileNet, the required input
   * shape is [1, 224, 224, 3], which represents [batch, height, width,
   * channel]. If we provide a batch of 100 images, the input tensor should
   * have shape [100, 224, 224, 3].
*
* @param config Prediction configuration for specifying the batch size and
* output node names. Currently the batch size option is ignored for graph
* model.
*
   * @returns Inference result tensors. The output is a single `tf.Tensor`
   * if the model has a single output node; otherwise a Tensor[] or
   * NamedTensorMap is returned for models with multiple outputs.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
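  /*
   * Illustrative sketch (not part of the original docs): feeding a model by
   * input node name with a NamedTensorMap. The node name 'input' and the
   * shape below are placeholders; inspect `model.inputs` for the real ones.
   *
   * ```js
   * const model = await tf.loadGraphModel(modelUrl);
   * console.log(model.inputs);  // expected names, shapes and dtypes
   * // For a single-output model the result is a single tensor.
   * model.predict({'input': tf.zeros([1, 224, 224, 3])}).print();
   * ```
   */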
;
_proto.predict = function predict(inputs, config) {
return this.execute(inputs, this.outputNodes);
};
_proto.normalizeInputs = function normalizeInputs(inputs) {
if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) {
// The input is already a NamedTensorMap.
return inputs;
}
inputs = Array.isArray(inputs) ? inputs : [inputs];
if (inputs.length !== this.inputNodes.length) {
throw new Error('Input tensor count mismatch,' + ("the graph model has " + this.inputNodes.length + " placeholders, ") + ("while there are " + inputs.length + " input tensors."));
}
return this.inputNodes.reduce(function (map, inputName, i) {
map[inputName] = inputs[i];
return map;
}, {});
};
_proto.normalizeOutputs = function normalizeOutputs(outputs) {
outputs = outputs || this.outputNodes;
return !Array.isArray(outputs) ? [outputs] : outputs;
}
/**
* Executes inference for the model for given input tensors.
* @param inputs tensor, tensor array or tensor map of the inputs for the
* model, keyed by the input node names.
   * @param outputs output node name(s) from the TensorFlow model; if no
   * outputs are specified, the default outputs of the model are used. You can
   * inspect intermediate nodes of the model by adding them to the outputs
   * array.
   *
   * @returns A single tensor if a single output is requested, or if no outputs
   * are specified and there is only one default output; otherwise a tensor
   * array is returned. The order of the tensor array matches the outputs if
   * provided, otherwise it follows the model's outputNodes attribute.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
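  /*
   * Illustrative sketch (node names are placeholders, not from the original
   * docs): requesting specific output nodes, including intermediate ones.
   *
   * ```js
   * // Tensors are returned in the same order as the requested output names.
   * const [logits, embedding] =
   *     model.execute(inputTensor, ['Logits', 'Embedding']);
   * ```
   */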
;
_proto.execute = function execute(inputs, outputs) {
inputs = this.normalizeInputs(inputs);
outputs = this.normalizeOutputs(outputs);
var result = this.executor.execute(inputs, outputs);
return result.length > 1 ? result : result[0];
}
/**
* Executes inference for the model for given input tensors in async
* fashion, use this method when your model contains control flow ops.
* @param inputs tensor, tensor array or tensor map of the inputs for the
* model, keyed by the input node names.
   * @param outputs output node name(s) from the TensorFlow model; if no
   * outputs are specified, the default outputs of the model are used. You can
   * inspect intermediate nodes of the model by adding them to the outputs
   * array.
   *
   * @returns A Promise of a single tensor if a single output is requested, or
   * if no outputs are specified and there is only one default output;
   * otherwise a Promise of a tensor array is returned.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
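  /*
   * Illustrative sketch (not part of the original docs): models with control
   * flow ops (e.g. loops or conditionals) must use the async path. The input
   * node name 'input' is a placeholder; see `model.inputNodes`.
   *
   * ```js
   * const result = await model.executeAsync({'input': inputTensor});
   * ```
   */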
;
_proto.executeAsync =
/*#__PURE__*/
function () {
var _executeAsync = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(inputs, outputs) {
var result;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
inputs = this.normalizeInputs(inputs);
outputs = this.normalizeOutputs(outputs);
_context3.next = 4;
return this.executor.executeAsync(inputs, outputs);
case 4:
result = _context3.sent;
return _context3.abrupt("return", result.length > 1 ? result : result[0]);
case 6:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function executeAsync(_x3, _x4) {
return _executeAsync.apply(this, arguments);
}
return executeAsync;
}();
_proto.convertTensorMapToTensorsMap = function convertTensorMapToTensorsMap(map) {
return Object.keys(map).reduce(function (newMap, key) {
newMap[key] = [map[key]];
return newMap;
}, {});
}
/**
* Releases the memory used by the weight tensors and resourceManager.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
;
_proto.dispose = function dispose() {
this.executor.dispose();
if (this.initializer) {
this.initializer.dispose();
}
this.resourceManager.dispose();
};
_createClass(GraphModel, [{
key: "modelVersion",
get: function get() {
return this.version;
}
}, {
key: "inputNodes",
get: function get() {
return this.executor.inputNodes;
}
}, {
key: "outputNodes",
get: function get() {
return this.executor.outputNodes;
}
}, {
key: "inputs",
get: function get() {
return this.executor.inputs;
}
}, {
key: "outputs",
get: function get() {
return this.executor.outputs;
}
}, {
key: "weights",
get: function get() {
return this.executor.weightMap;
}
}, {
key: "metadata",
get: function get() {
return this.artifacts.userDefinedMetadata;
}
}, {
key: "modelSignature",
get: function get() {
return this.signature;
}
}]);
return GraphModel;
}();
/**
* Load a graph model given a URL to the model definition.
*
* Example of loading MobileNetV2 from a URL and making a prediction with a
* zeros input:
*
* ```js
* const modelUrl =
* 'https://storage.googleapis.com/tfjs-models/savedmodel/mobilenet_v2_1.0_224/model.json';
* const model = await tf.loadGraphModel(modelUrl);
* const zeros = tf.zeros([1, 224, 224, 3]);
* model.predict(zeros).print();
* ```
*
* Example of loading MobileNetV2 from a TF Hub URL and making a prediction with
* a zeros input:
*
* ```js
* const modelUrl =
* 'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
* const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
* const zeros = tf.zeros([1, 224, 224, 3]);
* model.predict(zeros).print();
* ```
* @param modelUrl The url or an `io.IOHandler` that loads the model.
 * @param options Options for the HTTP request, which allow sending credentials
 * and custom headers.
*
* @doc {heading: 'Models', subheading: 'Loading'}
*/
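/*
 * Illustrative sketch (not part of the original docs): passing request options
 * when the model files sit behind authentication. The `requestInit` object is
 * forwarded to the underlying fetch call.
 *
 * ```js
 * const model = await tf.loadGraphModel(modelUrl, {
 *   requestInit: {credentials: 'include'}
 * });
 * ```
 */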
function loadGraphModel(_x5, _x6) {
return _loadGraphModel.apply(this, arguments);
}
function _loadGraphModel() {
_loadGraphModel = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(modelUrl, options) {
var model;
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
if (options === void 0) {
options = {};
}
if (!(modelUrl == null)) {
_context4.next = 3;
break;
}
throw new Error('modelUrl in loadGraphModel() cannot be null. Please provide a url ' + 'or an IOHandler that loads the model');
case 3:
if (options == null) {
options = {};
}
if (options.fromTFHub) {
if (modelUrl.load == null) {
if (!modelUrl.endsWith('/')) {
modelUrl = modelUrl + '/';
}
modelUrl = "" + modelUrl + DEFAULT_MODEL_NAME + TFHUB_SEARCH_PARAM;
}
}
model = new GraphModel(modelUrl, options);
_context4.next = 8;
return model.load();
case 8:
return _context4.abrupt("return", model);
case 9:
case "end":
return _context4.stop();
}
}
}, _callee4);
}));
return _loadGraphModel.apply(this, arguments);
}
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
var version$3 = '3.9.0';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Apply a mapping function to a nested structure in a recursive manner.
*
* The result of the mapping is an object with the same nested structure (i.e.,
* of arrays and dicts) as the input, except that some subtrees are replaced,
* according to the results of the mapping function.
*
* Mappings are memoized. Thus, if the nested structure contains the same
* object in multiple positions, the output will contain the same mapped object
* in those positions. Cycles are not supported, however.
*
* @param input: The object to which to apply the mapping function.
* @param mapFn: A function that expects a single node of the object tree, and
* returns a `DeepMapResult`. The `DeepMapResult` either provides a
* replacement value for that node (i.e., replacing the subtree), or indicates
* that the node should be processed recursively.
*/
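/*
 * Illustrative sketch: `deepMap` is an internal tfjs-data helper (not exported
 * on the public `tf` namespace). A `mapFn` either replaces a node
 * ({value, recurse: false}) or asks for recursion ({value: null, recurse: true}).
 *
 * ```js
 * const doubled = deepMap({a: 1, b: [2, 3]}, node =>
 *     typeof node === 'number'
 *         ? {value: node * 2, recurse: false}
 *         : {value: null, recurse: true});
 * // doubled is {a: 2, b: [4, 6]}
 * ```
 */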
function deepMap(input, mapFn) {
return deepMapInternal(input, mapFn);
}
/**
* @param seen: A Map of known object mappings (i.e., memoized results of
* `mapFn()`)
 * @param containedIn: A set containing objects on the reference path currently
* being processed (used to detect cycles).
*/
function deepMapInternal(input, mapFn, seen, containedIn) {
if (seen === void 0) {
seen = new Map();
}
if (containedIn === void 0) {
containedIn = new Set();
}
if (input == null) {
return null;
}
if (containedIn.has(input)) {
throw new Error('Circular references are not supported.');
}
if (seen.has(input)) {
return seen.get(input);
}
var result = mapFn(input);
if (result.recurse && result.value !== null) {
throw new Error('A deep map function may not return both a value and recurse=true.');
}
if (!result.recurse) {
seen.set(input, result.value);
return result.value;
} else if (isIterable$1(input)) {
// tslint:disable-next-line:no-any
var mappedIterable = Array.isArray(input) ? [] : {};
containedIn.add(input);
for (var k in input) {
var child = input[k];
var childResult = deepMapInternal(child, mapFn, seen, containedIn);
mappedIterable[k] = childResult;
}
containedIn.delete(input);
return mappedIterable;
} else {
throw new Error("Can't recurse into non-iterable type: " + input);
}
} // TODO(soergel, kangyizhang) Reconsider naming of deepZip() to avoid confusion
// with zip()
/**
* Zip nested structures together in a recursive manner.
*
* This has the effect of transposing or pivoting data, e.g. converting it from
* a row-major representation to a column-major representation.
*
* For example, `deepZip([{a: 1, b: 2}, {a: 3, b: 4}])` returns
* `{a: [1, 3], b: [2, 4]}`.
*
* The inputs should all have the same nested structure (i.e., of arrays and
* dicts). The result is a single object with the same nested structure, where
* the leaves are arrays collecting the values of the inputs at that location
* (or, optionally, the result of a custom function applied to those arrays).
*
* @param inputs: An array of the objects to zip together.
* @param zipFn: (optional) A function that expects an array of elements at a
* single node of the object tree, and returns a `DeepMapResult`. The
* `DeepMapResult` either provides a result value for that node (i.e.,
* representing the subtree), or indicates that the node should be processed
* recursively. The default zipFn recurses as far as possible and places
* arrays at the leaves.
*/
function deepZip(inputs, zipFn) {
if (zipFn === void 0) {
zipFn = zipToList;
}
return deepZipInternal(inputs, zipFn);
}
/**
 * @param containedIn: A set containing objects on the reference path currently
* being processed (used to detect cycles).
*/
function deepZipInternal(inputs, zipFn, containedIn) {
if (containedIn === void 0) {
containedIn = new Set();
}
// The recursion follows the structure of input 0; it's assumed that all the
// other inputs have the same structure.
var input = inputs[0];
if (containedIn.has(input)) {
throw new Error('Circular references are not supported.');
}
var result = zipFn(inputs);
if (result.recurse && result.value !== null) {
throw new Error('A deep zip function may not return both a value and recurse=true.');
}
if (!result.recurse) {
return result.value;
} else if (isIterable$1(input)) {
// tslint:disable-next-line:no-any
var mappedIterable = Array.isArray(input) ? [] : {};
containedIn.add(input);
var _loop = function _loop(k) {
var children = inputs.map(function (x) {
return x[k];
});
var childResult = deepZipInternal(children, zipFn, containedIn);
mappedIterable[k] = childResult;
};
for (var k in input) {
_loop(k);
}
containedIn.delete(input);
return mappedIterable;
} else {
throw new Error("Can't recurse into non-iterable type: " + input);
}
} // tslint:disable-next-line:no-any
function zipToList(x) {
if (x === null) {
return null;
} // TODO(soergel): validate array type?
if (isIterable$1(x[0])) {
return {
value: null,
recurse: true
};
} else {
return {
value: x,
recurse: false
};
}
}
/**
* Apply an async mapping function to a nested structure in a recursive manner.
*
* This first creates a nested structure of Promises, and then awaits all of
* those, resulting in a single Promise for a resolved nested structure.
*
* The result of the mapping is an object with the same nested structure (i.e.,
* of arrays and dicts) as the input, except that some subtrees are replaced,
* according to the results of the mapping function.
*
* Mappings are memoized. Thus, if the nested structure contains the same
* object in multiple positions, the output will contain the same mapped object
* in those positions. Cycles are not supported, however.
*
* @param input: The object to which to apply the mapping function.
* @param mapFn: A function that expects a single node of the object tree, and
* returns a `DeepMapAsyncResult`. The `DeepMapAsyncResult` either provides
* a `Promise` for a replacement value for that node (i.e., replacing the
* subtree), or indicates that the node should be processed recursively. Note
* that the decision whether or not to recurse must be made immediately; only
* the mapped value may be promised.
*/
function deepMapAndAwaitAll(_x, _x2) {
return _deepMapAndAwaitAll.apply(this, arguments);
}
/**
* Determine whether the argument is iterable.
*
* @returns true if the argument is an array or any non-Tensor object.
*/
// tslint:disable-next-line:no-any
function _deepMapAndAwaitAll() {
_deepMapAndAwaitAll = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(input, mapFn) {
var seen, _i, _Array$from, key, value, mappedValue, result;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
seen = new Map(); // First do a normal deepMap, collecting Promises in 'seen' as a side effect.
deepMapInternal(input, mapFn, seen); // Replace the Promises in 'seen' in place.
// Note TypeScript provides no async map iteration, and regular map iteration
// is broken too, so sadly we have to do Array.from() to make it work.
// (There's no advantage to Promise.all(), and that would be tricky anyway.)
_i = 0, _Array$from = Array.from(seen.keys());
case 3:
if (!(_i < _Array$from.length)) {
_context.next = 14;
break;
}
key = _Array$from[_i];
value = seen.get(key);
if (!isPromise(value)) {
_context.next = 11;
break;
}
_context.next = 9;
return value;
case 9:
mappedValue = _context.sent;
seen.set(key, mappedValue);
case 11:
_i++;
_context.next = 3;
break;
case 14:
// Normal deepMap again, this time filling in the resolved values.
// It's unfortunate that we have to do two passes.
// TODO(soergel): test performance and think harder about a fast solution.
result = deepMapInternal(input, mapFn, seen);
return _context.abrupt("return", result);
case 16:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _deepMapAndAwaitAll.apply(this, arguments);
}
function isIterable$1(obj) {
var isTextDecoder = false;
if (env().get('IS_BROWSER')) {
isTextDecoder = obj instanceof TextDecoder;
} else {
// tslint:disable-next-line:no-require-imports
var _require = require('string_decoder'),
StringDecoder = _require.StringDecoder;
isTextDecoder = obj instanceof StringDecoder;
}
return obj != null && !ArrayBuffer.isView(obj) && (Array.isArray(obj) || typeof obj === 'object' && !(obj instanceof Tensor) && !(obj instanceof Promise) && !isTextDecoder);
}
/**
* Determine whether the argument can be converted to Tensor.
*
* Tensors, primitives, arrays, and TypedArrays all qualify; anything else does
* not.
*
* @returns true if the argument can be converted to Tensor.
*/
// tslint:disable-next-line:no-any
function canTensorify(obj) {
return obj == null || isPrimitive(obj) || Array.isArray(obj) || typeof obj === 'object' && obj instanceof Tensor || isTypedArray$1(obj);
}
/**
* Returns true if the given `value` is a primitive type. Otherwise returns
 * false. This is equivalent to node util.isPrimitive
*/
function isPrimitive(value) {
return value === null || typeof value !== 'object' && typeof value !== 'function';
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
function deepClone(container) {
return deepMap(container, cloneIfTensor);
} // tslint:disable-next-line: no-any
function cloneIfTensor(item) {
if (item instanceof Tensor) {
return {
value: item.clone(),
recurse: false
};
} else if (isIterable$1(item)) {
return {
value: null,
recurse: true
};
} else {
return {
value: item,
recurse: false
};
}
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
* A ring buffer, providing O(1) FIFO, LIFO, and related operations.
*/
var RingBuffer = /*#__PURE__*/function () {
/**
* Constructs a `RingBuffer`.
   * @param capacity The number of items that the buffer can accommodate.
*/
function RingBuffer(capacity) {
this.capacity = capacity; // Note we store the indices in the range 0 <= index < 2*capacity.
// This allows us to distinguish the full from the empty case.
// See https://www.snellman.net/blog/archive/2016-12-13-ring-buffers/
this.begin = 0; // inclusive
this.end = 0; // exclusive
if (capacity == null) {
throw new RangeError('Can\'t create a ring buffer of unknown capacity.');
}
if (capacity < 1) {
throw new RangeError('Can\'t create ring buffer of capacity < 1.');
}
this.data = new Array(capacity);
this.doubledCapacity = 2 * capacity;
}
/**
* Map any index into the range 0 <= index < 2*capacity.
*/
var _proto = RingBuffer.prototype;
_proto.wrap = function wrap(index) {
// don't trust % on negative numbers
while (index < 0) {
index += this.doubledCapacity;
}
return index % this.doubledCapacity;
};
_proto.get = function get(index) {
if (index < 0) {
throw new RangeError('Can\'t get item at a negative index.');
}
return this.data[index % this.capacity];
};
_proto.set = function set(index, value) {
if (index < 0) {
throw new RangeError('Can\'t set item at a negative index.');
}
this.data[index % this.capacity] = value;
}
/**
* Returns the current number of items in the buffer.
*/
;
_proto.length = function length() {
var length = this.end - this.begin;
if (length < 0) {
length = this.doubledCapacity + length;
}
return length;
}
/**
* Reports whether the buffer is full.
* @returns true if the number of items in the buffer equals its capacity, and
* false otherwise.
*/
;
_proto.isFull = function isFull() {
return this.length() === this.capacity;
}
/**
* Reports whether the buffer is empty.
* @returns true if the number of items in the buffer equals zero, and
* false otherwise.
*/
;
_proto.isEmpty = function isEmpty() {
return this.length() === 0;
}
/**
* Adds an item to the end of the buffer.
*/
;
_proto.push = function push(value) {
if (this.isFull()) {
throw new RangeError('Ring buffer is full.');
}
this.set(this.end, value);
this.end = this.wrap(this.end + 1);
}
/**
* Adds many items to the end of the buffer, in order.
*/
;
_proto.pushAll = function pushAll(values) {
for (var _iterator = _createForOfIteratorHelperLoose(values), _step; !(_step = _iterator()).done;) {
var value = _step.value;
this.push(value);
}
}
/**
* Removes and returns the last item in the buffer.
*/
;
_proto.pop = function pop() {
if (this.isEmpty()) {
throw new RangeError('Ring buffer is empty.');
}
this.end = this.wrap(this.end - 1);
var result = this.get(this.end);
this.set(this.end, undefined);
return result;
}
/**
* Adds an item to the beginning of the buffer.
*/
;
_proto.unshift = function unshift(value) {
if (this.isFull()) {
throw new RangeError('Ring buffer is full.');
}
this.begin = this.wrap(this.begin - 1);
this.set(this.begin, value);
}
/**
* Removes and returns the first item in the buffer.
*/
;
_proto.shift = function shift() {
if (this.isEmpty()) {
throw new RangeError('Ring buffer is empty.');
}
var result = this.get(this.begin);
this.set(this.begin, undefined);
this.begin = this.wrap(this.begin + 1);
return result;
}
/**
* Removes and returns a specific item in the buffer, and moves the last item
* to the vacated slot. This is useful for implementing a shuffling stream.
* Note that this operation necessarily scrambles the original order.
*
* @param relativeIndex: the index of the item to remove, relative to the
* first item in the buffer (e.g., hiding the ring nature of the underlying
* storage).
*/
;
_proto.shuffleExcise = function shuffleExcise(relativeIndex) {
if (this.isEmpty()) {
throw new RangeError('Ring buffer is empty.');
}
var index = this.wrap(this.begin + relativeIndex);
var result = this.get(index);
this.set(index, this.pop());
return result;
};
return RingBuffer;
}();
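/*
 * Illustrative sketch: `RingBuffer` is internal to tfjs-data (not exported on
 * the public `tf` namespace); its O(1) FIFO/LIFO behavior looks like this:
 *
 * ```js
 * const rb = new RingBuffer(3);
 * rb.push('a'); rb.push('b'); rb.push('c');
 * rb.isFull();  // true
 * rb.shift();   // 'a' (FIFO, from the front)
 * rb.pop();     // 'c' (LIFO, from the back)
 * rb.length();  // 1
 * ```
 */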
var GrowingRingBuffer = /*#__PURE__*/function (_RingBuffer) {
_inheritsLoose(GrowingRingBuffer, _RingBuffer);
/**
* Constructs a `GrowingRingBuffer`.
*/
function GrowingRingBuffer() {
return _RingBuffer.call(this, GrowingRingBuffer.INITIAL_CAPACITY) || this;
}
var _proto = GrowingRingBuffer.prototype;
_proto.isFull = function isFull() {
return false;
};
_proto.push = function push(value) {
if (_RingBuffer.prototype.isFull.call(this)) {
this.expand();
}
_RingBuffer.prototype.push.call(this, value);
};
_proto.unshift = function unshift(value) {
if (_RingBuffer.prototype.isFull.call(this)) {
this.expand();
}
_RingBuffer.prototype.unshift.call(this, value);
}
/**
* Doubles the capacity of the buffer.
*/
;
_proto.expand = function expand() {
var newCapacity = this.capacity * 2;
var newData = new Array(newCapacity);
var len = this.length(); // Rotate the buffer to start at index 0 again, since we can't just
// allocate more space at the end.
for (var i = 0; i < len; i++) {
newData[i] = this.get(this.wrap(this.begin + i));
}
this.data = newData;
this.capacity = newCapacity;
this.doubledCapacity = 2 * this.capacity;
this.begin = 0;
this.end = len;
};
return GrowingRingBuffer;
}(RingBuffer);
GrowingRingBuffer.INITIAL_CAPACITY = 32;
// This lets us avoid using either third-party stream libraries or
// recent TypeScript language support requiring polyfills.
/**
* Create a `LazyIterator` from an array of items.
*/
function iteratorFromItems(items) {
return new ArrayIterator(items);
}
/**
* Create a `LazyIterator` of incrementing integers.
*/
function iteratorFromIncrementing(start) {
var i = start;
return iteratorFromFunction(function () {
return {
value: i++,
done: false
};
});
}
/**
* Create a `LazyIterator` from a function.
*
* ```js
* let i = -1;
* const func = () =>
* ++i < 5 ? {value: i, done: false} : {value: null, done: true};
* const iter = tf.data.iteratorFromFunction(func);
* await iter.forEachAsync(e => console.log(e));
* ```
*
* @param func A function that produces data on each call.
*/
function iteratorFromFunction(func) {
return new FunctionCallIterator(func);
}
/**
* Create a `LazyIterator` by concatenating underlying streams, which are
* themselves provided as a stream.
*
* This can also be thought of as a "stream flatten" operation.
*
* @param baseIterators A stream of streams to be concatenated.
* @param baseErrorHandler An optional function that can intercept `Error`s
* raised during a `next()` call on the base stream. This function can decide
* whether the error should be propagated, whether the error should be
* ignored, or whether the base stream should be terminated.
*/
function iteratorFromConcatenated(baseIterators, baseErrorHandler) {
return new ChainedIterator(baseIterators, baseErrorHandler);
}
/**
* Create a `LazyIterator` by concatenating streams produced by calling a
* stream-generating function a given number of times.
*
* Since a `LazyIterator` is read-once, it cannot be repeated, but this
* function can be used to achieve a similar effect:
*
 *   iteratorFromConcatenatedFunction(() => new MyIterator(), 6);
*
* @param iteratorFunc: A function that produces a new stream on each call.
* @param count: The number of times to call the function.
* @param baseErrorHandler An optional function that can intercept `Error`s
* raised during a `next()` call on the base stream. This function can decide
* whether the error should be propagated, whether the error should be
* ignored, or whether the base stream should be terminated.
*/
function iteratorFromConcatenatedFunction(iteratorFunc, count, baseErrorHandler) {
return iteratorFromConcatenated(iteratorFromFunction(iteratorFunc).take(count), baseErrorHandler);
}
/**
* Create a `LazyIterator` by zipping together an array, dict, or nested
* structure of `LazyIterator`s (and perhaps additional constants).
*
* The underlying streams must provide elements in a consistent order such
* that they correspond.
*
* Typically, the underlying streams should have the same number of
* elements. If they do not, the behavior is determined by the
* `mismatchMode` argument.
*
* The nested structure of the `iterators` argument determines the
* structure of elements in the resulting iterator.
*
* @param iterators: An array or object containing LazyIterators at the
* leaves.
* @param mismatchMode: Determines what to do when one underlying iterator
* is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
* causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
 * causes the zipped iterator to terminate with the first underlying
* streams, so elements remaining on the longer streams are ignored.
* `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
* in nulls for the exhausted streams, until all streams are exhausted.
*/
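/*
 * Illustrative sketch using the internal helpers defined in this module (the
 * public, dataset-level analogue is `tf.data.zip`):
 *
 * ```js
 * const nums = iteratorFromItems([1, 2, 3]);
 * const strs = iteratorFromItems(['x', 'y', 'z']);
 * const zipped = iteratorFromZipped({num: nums, str: strs});
 * await zipped.toArray();
 * // [{num: 1, str: 'x'}, {num: 2, str: 'y'}, {num: 3, str: 'z'}]
 * ```
 */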
function iteratorFromZipped(iterators, mismatchMode) {
if (mismatchMode === void 0) {
mismatchMode = ZipMismatchMode.FAIL;
}
return new ZipIterator(iterators, mismatchMode);
}
/**
* An asynchronous iterator, providing lazy access to a potentially
* unbounded stream of elements.
*
* Iterator can be obtained from a dataset:
* `const iter = await dataset.iterator();`
*/
var LazyIterator = /*#__PURE__*/function () {
function LazyIterator() {}
var _proto = LazyIterator.prototype;
/**
* Collect all remaining elements of a bounded stream into an array.
* Obviously this will succeed only for small streams that fit in memory.
* Useful for testing.
*
* @returns A Promise for an array of stream elements, which will resolve
* when the stream is exhausted.
*/
_proto.toArray =
/*#__PURE__*/
function () {
var _toArray = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var result, x;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
result = [];
_context.next = 3;
return this.next();
case 3:
x = _context.sent;
case 4:
if (x.done) {
_context.next = 11;
break;
}
result.push(x.value);
_context.next = 8;
return this.next();
case 8:
x = _context.sent;
_context.next = 4;
break;
case 11:
return _context.abrupt("return", result);
case 12:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function toArray() {
return _toArray.apply(this, arguments);
}
return toArray;
}()
/**
* Collect all elements of this dataset into an array with prefetching 100
* elements. This is useful for testing, because the prefetch changes the
* order in which the Promises are resolved along the processing pipeline.
* This may help expose bugs where results are dependent on the order of
* Promise resolution rather than on the logical order of the stream (i.e.,
* due to hidden mutable state).
*
* @returns A Promise for an array of stream elements, which will resolve
* when the stream is exhausted.
*/
;
_proto.toArrayForTest =
/*#__PURE__*/
function () {
var _toArrayForTest = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var stream, result, x;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
stream = this.prefetch(100);
result = [];
_context2.next = 4;
return stream.next();
case 4:
x = _context2.sent;
case 5:
if (x.done) {
_context2.next = 12;
break;
}
result.push(x.value);
_context2.next = 9;
return stream.next();
case 9:
x = _context2.sent;
_context2.next = 5;
break;
case 12:
return _context2.abrupt("return", result);
case 13:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function toArrayForTest() {
return _toArrayForTest.apply(this, arguments);
}
return toArrayForTest;
}()
/**
* Draw items from the stream until it is exhausted.
*
* This can be useful when the stream has side effects but no output. In
* that case, calling this function guarantees that the stream will be
* fully processed.
*/
;
_proto.resolveFully =
/*#__PURE__*/
function () {
var _resolveFully = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3() {
var x;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
_context3.next = 2;
return this.next();
case 2:
x = _context3.sent;
case 3:
if (x.done) {
_context3.next = 9;
break;
}
_context3.next = 6;
return this.next();
case 6:
x = _context3.sent;
_context3.next = 3;
break;
case 9:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function resolveFully() {
return _resolveFully.apply(this, arguments);
}
return resolveFully;
}()
/**
* Draw items from the stream until it is exhausted, or a predicate fails.
*
* This can be useful when the stream has side effects but no output. In
* that case, calling this function guarantees that the stream will be
* fully processed.
*/
;
_proto.resolveWhile =
/*#__PURE__*/
function () {
var _resolveWhile = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(predicate) {
var x, shouldContinue;
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
_context4.next = 2;
return this.next();
case 2:
x = _context4.sent;
shouldContinue = predicate(x.value);
case 4:
if (!(!x.done && shouldContinue)) {
_context4.next = 11;
break;
}
_context4.next = 7;
return this.next();
case 7:
x = _context4.sent;
shouldContinue = predicate(x.value);
_context4.next = 4;
break;
case 11:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function resolveWhile(_x) {
return _resolveWhile.apply(this, arguments);
}
return resolveWhile;
}()
/**
* Handles errors thrown on this stream using a provided handler function.
*
* @param handler A function that handles any `Error` thrown during a `next()`
* call and returns true if the stream should continue (dropping the failed
* call) or false if the stream should quietly terminate. If the handler
* itself throws (or rethrows) an `Error`, that will be propagated.
*
* @returns A `LazyIterator` of elements passed through from upstream,
* possibly filtering or terminating on upstream `next()` calls that
* throw an `Error`.
*/
;
_proto.handleErrors = function handleErrors(handler) {
return new ErrorHandlingLazyIterator(this, handler);
} // TODO(soergel): Implement reduce() etc.
/**
* Filters this stream according to `predicate`.
*
* @param predicate A function mapping a stream element to a boolean or a
* `Promise` for one.
*
* @returns A `LazyIterator` of elements for which the predicate was true.
*/
;
_proto.filter = function filter(predicate) {
return new FilterIterator(this, predicate);
}
/**
* Maps this stream through a 1-to-1 transform.
*
* @param transform A function mapping a stream element to a transformed
* element.
*
* @returns A `LazyIterator` of transformed elements.
*/
;
_proto.map = function map(transform) {
return new MapIterator(this, transform);
}
/**
* Maps this stream through an async 1-to-1 transform.
*
* @param transform A function mapping a stream element to a `Promise` for a
* transformed stream element.
*
* @returns A `LazyIterator` of transformed elements.
*/
;
_proto.mapAsync = function mapAsync(transform) {
return new AsyncMapIterator(this, transform);
}
/**
* Maps this stream through a 1-to-1 transform, forcing serial execution.
*
* @param transform A function mapping a stream element to a transformed
* element.
*
* @returns A `LazyIterator` of transformed elements.
*/
;
_proto.serialMapAsync = function serialMapAsync(transform) {
return new AsyncMapIterator(this, transform).serial();
}
/**
* Maps this stream through a 1-to-many transform.
*
* @param transform A function mapping a stream element to an array of
* transformed elements.
*
   * @returns A `LazyIterator` of transformed elements.
*/
;
_proto.flatmap = function flatmap(transform) {
return new FlatmapIterator(this, transform);
}
/**
* Apply a function to every element of the stream.
*
* @param f A function to apply to each stream element.
*/
;
_proto.forEachAsync =
/*#__PURE__*/
function () {
var _forEachAsync = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee5(f) {
return regeneratorRuntime.wrap(function _callee5$(_context5) {
while (1) {
switch (_context5.prev = _context5.next) {
case 0:
return _context5.abrupt("return", this.map(f).resolveFully());
case 1:
case "end":
return _context5.stop();
}
}
}, _callee5, this);
}));
function forEachAsync(_x2) {
return _forEachAsync.apply(this, arguments);
}
return forEachAsync;
}()
/**
* Apply a function to every element of the stream, forcing serial execution.
*
* @param f A function to apply to each stream element. Should return 'true'
* to indicate that the stream should continue, or 'false' to cause it to
* terminate.
*/
;
_proto.serialForEach =
/*#__PURE__*/
function () {
var _serialForEach = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee6(f) {
return regeneratorRuntime.wrap(function _callee6$(_context6) {
while (1) {
switch (_context6.prev = _context6.next) {
case 0:
return _context6.abrupt("return", this.serialMapAsync(f).resolveWhile(function (x) {
return x === true;
}));
case 1:
case "end":
return _context6.stop();
}
}
}, _callee6, this);
}));
function serialForEach(_x3) {
return _serialForEach.apply(this, arguments);
}
return serialForEach;
}()
/**
* Groups elements into batches, represented as arrays of elements.
*
* We can think of the elements of this iterator as 'rows' (even if they are
* nested structures). By the same token, consecutive values for a given
* key within the elements form a 'column'. This matches the usual sense of
* 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
*
* Thus, "Row-major" means that the resulting batch is simply a collection of
   * rows: `[row1, row2, row3, ...]`. This is in contrast to the column-major
* form, which is needed for vectorized computation.
*
* @param batchSize The number of elements desired per batch.
* @param smallLastBatch Whether to emit the final batch when it has fewer
* than batchSize elements. Default true.
* @returns A `LazyIterator` of batches of elements, represented as arrays
* of the original element type.
*/
;
_proto.rowMajorBatch = function rowMajorBatch(batchSize, smallLastBatch) {
if (smallLastBatch === void 0) {
smallLastBatch = true;
}
return new RowMajorBatchIterator(this, batchSize, smallLastBatch);
}
/**
* Groups elements into batches, represented in column-major form.
*
* We can think of the elements of this iterator as 'rows' (even if they are
* nested structures). By the same token, consecutive values for a given
* key within the elements form a 'column'. This matches the usual sense of
* 'row' and 'column' when processing tabular data (e.g., parsing a CSV).
*
* Thus, "column-major" means that the resulting batch is a (potentially
* nested) structure representing the columns. Each column entry, then,
* contains a collection of the values found in that column for a range of
* input elements. This representation allows for vectorized computation, in
* contrast to the row-major form.
*
* The inputs should all have the same nested structure (i.e., of arrays and
* dicts). The result is a single object with the same nested structure,
* where the leaves are arrays collecting the values of the inputs at that
* location (or, optionally, the result of a custom function applied to those
* arrays).
*
* @param batchSize The number of elements desired per batch.
* @param smallLastBatch Whether to emit the final batch when it has fewer
* than batchSize elements. Default true.
* @param zipFn: (optional) A function that expects an array of elements at a
* single node of the object tree, and returns a `DeepMapResult`. The
* `DeepMapResult` either provides a result value for that node (i.e.,
* representing the subtree), or indicates that the node should be processed
* recursively. The default zipFn recurses as far as possible and places
* arrays at the leaves.
* @returns A `LazyIterator` of batches of elements, represented as an object
* with collections at the leaves.
*/
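  /*
   * Illustrative sketch (not part of the original docs), assuming elements are
   * plain objects of numbers:
   *
   * ```js
   * const it = await tf.data.array([{a: 1, b: 2}, {a: 3, b: 4}]).iterator();
   * const batches = await it.columnMajorBatch(2).toArray();
   * // batches[0] is {a: [1, 3], b: [2, 4]}
   * ```
   */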
;
_proto.columnMajorBatch = function columnMajorBatch(batchSize, smallLastBatch, // tslint:disable-next-line:no-any
zipFn) {
if (smallLastBatch === void 0) {
smallLastBatch = true;
}
if (zipFn === void 0) {
zipFn = zipToList;
}
// First collect the desired number of input elements as a row-major batch.
var rowBatches = this.rowMajorBatch(batchSize, smallLastBatch); // Now 'rotate' or 'pivot' the data, collecting all values from each column
// in the batch (i.e., for each key within the elements) into an array.
return rowBatches.map(function (x) {
return deepZip(x, zipFn);
});
}
/**
* Concatenate this `LazyIterator` with another.
*
* @param iterator A `LazyIterator` to be concatenated onto this one.
* @param baseErrorHandler An optional function that can intercept `Error`s
* raised during a `next()` call on the base stream. This function can
* decide whether the error should be propagated, whether the error should
* be ignored, or whether the base stream should be terminated.
* @returns A `LazyIterator`.
*/
;
_proto.concatenate = function concatenate(iterator, baseErrorHandler) {
return new ChainedIterator(iteratorFromItems([this, iterator]), baseErrorHandler);
}
/**
* Limits this stream to return at most `count` items.
*
* @param count The maximum number of items to provide from the stream. If
* a negative or undefined value is given, the entire stream is returned
* unaltered.
*/
;
_proto.take = function take(count) {
if (count < 0 || count == null) {
return this;
}
return new TakeIterator(this, count);
}
/**
* Skips the first `count` items in this stream.
*
* @param count The number of items to skip. If a negative or undefined
* value is given, the entire stream is returned unaltered.
*/
;
_proto.skip = function skip(count) {
if (count < 0 || count == null) {
return this;
}
return new SkipIterator(this, count);
}
/**
* Prefetch the first `bufferSize` items in this stream.
*
* Note this prefetches Promises, but makes no guarantees about when those
* Promises resolve.
*
* @param bufferSize: An integer specifying the number of elements to be
* prefetched.
*/
;
_proto.prefetch = function prefetch(bufferSize) {
return new PrefetchIterator(this, bufferSize);
} // TODO(soergel): deep sharded shuffle, where supported
/**
* Randomly shuffles the elements of this stream.
*
   * @param windowSize: An integer specifying the number of elements from
* this stream from which the new stream will sample.
* @param seed: (Optional.) An integer specifying the random seed that
* will be used to create the distribution.
*/
;
_proto.shuffle = function shuffle(windowSize, seed) {
return new ShuffleIterator(this, windowSize, seed);
}
/**
* Force an iterator to execute serially: each next() call will await the
* prior one, so that they cannot execute concurrently.
*/
;
_proto.serial = function serial() {
return new SerialIterator(this);
};
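/*
 * A minimal chaining sketch for the methods above (assumes the module-internal
 * `iteratorFromItems` helper defined earlier in this file; it is not part of
 * the public `tf.data` namespace):
 * ```js
 * const it = iteratorFromItems([1, 2, 3, 4, 5, 6, 7, 8])
 *     .skip(2)           // drops 1 and 2
 *     .take(4)           // keeps 3, 4, 5, 6
 *     .shuffle(3, '42')  // sliding-window shuffle with a fixed seed
 *     .prefetch(2)
 *     .serial();         // forces next() calls to resolve in order
 * console.log(await it.toArray());  // the four kept items, in shuffled order
 * ```
 */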
return LazyIterator;
}(); // ============================================================================
// The following private classes serve to implement the chainable methods
// on LazyIterator. Unfortunately they can't be placed in separate files,
// due to resulting trouble with circular imports.
// ============================================================================
// Iterators that just extend LazyIterator directly
// ============================================================================
var ArrayIterator = /*#__PURE__*/function (_LazyIterator) {
_inheritsLoose(ArrayIterator, _LazyIterator);
function ArrayIterator(items) {
var _this;
_this = _LazyIterator.call(this) || this;
_this.items = items;
_this.trav = 0;
return _this;
}
var _proto2 = ArrayIterator.prototype;
_proto2.summary = function summary() {
return "Array of " + this.items.length + " items";
};
_proto2.next = /*#__PURE__*/function () {
var _next = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee7() {
var item;
return regeneratorRuntime.wrap(function _callee7$(_context7) {
while (1) {
switch (_context7.prev = _context7.next) {
case 0:
if (!(this.trav >= this.items.length)) {
_context7.next = 2;
break;
}
return _context7.abrupt("return", {
value: null,
done: true
});
case 2:
item = this.items[this.trav];
this.trav++;
return _context7.abrupt("return", {
value: deepClone(item),
done: false
});
case 5:
case "end":
return _context7.stop();
}
}
}, _callee7, this);
}));
function next() {
return _next.apply(this, arguments);
}
return next;
}();
return ArrayIterator;
}(LazyIterator);
var FunctionCallIterator = /*#__PURE__*/function (_LazyIterator2) {
_inheritsLoose(FunctionCallIterator, _LazyIterator2);
function FunctionCallIterator(nextFn) {
var _this2;
_this2 = _LazyIterator2.call(this) || this;
_this2.nextFn = nextFn;
return _this2;
}
var _proto3 = FunctionCallIterator.prototype;
_proto3.summary = function summary() {
return "Function call";
};
_proto3.next = /*#__PURE__*/function () {
var _next2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee8() {
return regeneratorRuntime.wrap(function _callee8$(_context8) {
while (1) {
switch (_context8.prev = _context8.next) {
case 0:
_context8.prev = 0;
return _context8.abrupt("return", this.nextFn());
case 4:
_context8.prev = 4;
_context8.t0 = _context8["catch"](0);
// Modify the error message but leave the stack trace intact
_context8.t0.message = "Error thrown while iterating through a dataset: " + _context8.t0.message;
throw _context8.t0;
case 8:
case "end":
return _context8.stop();
}
}
}, _callee8, this, [[0, 4]]);
}));
function next() {
return _next2.apply(this, arguments);
}
return next;
}();
return FunctionCallIterator;
}(LazyIterator);
var SerialIterator = /*#__PURE__*/function (_LazyIterator3) {
_inheritsLoose(SerialIterator, _LazyIterator3);
function SerialIterator(upstream) {
var _this3;
_this3 = _LazyIterator3.call(this) || this;
_this3.upstream = upstream;
_this3.lastRead = Promise.resolve({
value: null,
done: false
});
return _this3;
}
var _proto4 = SerialIterator.prototype;
_proto4.summary = function summary() {
return this.upstream.summary() + " -> Serial";
};
_proto4.next = /*#__PURE__*/function () {
var _next3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee9() {
var _this4 = this;
return regeneratorRuntime.wrap(function _callee9$(_context9) {
while (1) {
switch (_context9.prev = _context9.next) {
case 0:
// This sets this.lastRead to a new Promise right away, as opposed to
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
// would not work because this.lastRead would be updated only after the
// promise resolves.
this.lastRead = this.lastRead.then(function () {
return _this4.serialNext();
});
return _context9.abrupt("return", this.lastRead);
case 2:
case "end":
return _context9.stop();
}
}
}, _callee9, this);
}));
function next() {
return _next3.apply(this, arguments);
}
return next;
}();
_proto4.serialNext = /*#__PURE__*/function () {
var _serialNext = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee10() {
return regeneratorRuntime.wrap(function _callee10$(_context10) {
while (1) {
switch (_context10.prev = _context10.next) {
case 0:
return _context10.abrupt("return", this.upstream.next());
case 1:
case "end":
return _context10.stop();
}
}
}, _callee10, this);
}));
function serialNext() {
return _serialNext.apply(this, arguments);
}
return serialNext;
}();
return SerialIterator;
}(LazyIterator);
var SkipIterator = /*#__PURE__*/function (_LazyIterator4) {
_inheritsLoose(SkipIterator, _LazyIterator4);
function SkipIterator(upstream, maxCount) {
var _this5;
_this5 = _LazyIterator4.call(this) || this;
_this5.upstream = upstream;
_this5.maxCount = maxCount; // Local state that should not be clobbered by out-of-order execution.
_this5.count = 0;
_this5.lastRead = Promise.resolve({
value: null,
done: false
});
return _this5;
}
var _proto5 = SkipIterator.prototype;
_proto5.summary = function summary() {
return this.upstream.summary() + " -> Skip";
};
_proto5.next = /*#__PURE__*/function () {
var _next4 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee11() {
var _this6 = this;
return regeneratorRuntime.wrap(function _callee11$(_context11) {
while (1) {
switch (_context11.prev = _context11.next) {
case 0:
// This sets this.lastRead to a new Promise right away, as opposed to
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
// would not work because this.lastRead would be updated only after the
// promise resolves.
this.lastRead = this.lastRead.then(function () {
return _this6.serialNext();
});
return _context11.abrupt("return", this.lastRead);
case 2:
case "end":
return _context11.stop();
}
}
}, _callee11, this);
}));
function next() {
return _next4.apply(this, arguments);
}
return next;
}();
_proto5.serialNext = /*#__PURE__*/function () {
var _serialNext2 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee12() {
var skipped;
return regeneratorRuntime.wrap(function _callee12$(_context12) {
while (1) {
switch (_context12.prev = _context12.next) {
case 0:
if (!(this.count++ < this.maxCount)) {
_context12.next = 9;
break;
}
_context12.next = 3;
return this.upstream.next();
case 3:
skipped = _context12.sent;
if (!skipped.done) {
_context12.next = 6;
break;
}
return _context12.abrupt("return", skipped);
case 6:
dispose(skipped.value);
_context12.next = 0;
break;
case 9:
return _context12.abrupt("return", this.upstream.next());
case 10:
case "end":
return _context12.stop();
}
}
}, _callee12, this);
}));
function serialNext() {
return _serialNext2.apply(this, arguments);
}
return serialNext;
}();
return SkipIterator;
}(LazyIterator);
var TakeIterator = /*#__PURE__*/function (_LazyIterator5) {
_inheritsLoose(TakeIterator, _LazyIterator5);
function TakeIterator(upstream, maxCount) {
var _this7;
_this7 = _LazyIterator5.call(this) || this;
_this7.upstream = upstream;
_this7.maxCount = maxCount;
_this7.count = 0;
return _this7;
}
var _proto6 = TakeIterator.prototype;
_proto6.summary = function summary() {
return this.upstream.summary() + " -> Take";
};
_proto6.next = /*#__PURE__*/function () {
var _next5 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee13() {
return regeneratorRuntime.wrap(function _callee13$(_context13) {
while (1) {
switch (_context13.prev = _context13.next) {
case 0:
if (!(this.count++ >= this.maxCount)) {
_context13.next = 2;
break;
}
return _context13.abrupt("return", {
value: null,
done: true
});
case 2:
return _context13.abrupt("return", this.upstream.next());
case 3:
case "end":
return _context13.stop();
}
}
}, _callee13, this);
}));
function next() {
return _next5.apply(this, arguments);
}
return next;
}();
return TakeIterator;
}(LazyIterator); // Note this batch just groups items into row-wise element arrays.
// Rotating these to a column-wise representation happens only at the dataset
// level.
var RowMajorBatchIterator = /*#__PURE__*/function (_LazyIterator6) {
_inheritsLoose(RowMajorBatchIterator, _LazyIterator6);
function RowMajorBatchIterator(upstream, batchSize, enableSmallLastBatch) {
var _this8;
if (enableSmallLastBatch === void 0) {
enableSmallLastBatch = true;
}
_this8 = _LazyIterator6.call(this) || this;
_this8.upstream = upstream;
_this8.batchSize = batchSize;
_this8.enableSmallLastBatch = enableSmallLastBatch;
_this8.lastRead = Promise.resolve({
value: null,
done: false
});
return _this8;
}
var _proto7 = RowMajorBatchIterator.prototype;
_proto7.summary = function summary() {
return this.upstream.summary() + " -> RowMajorBatch";
};
_proto7.next = /*#__PURE__*/function () {
var _next6 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee14() {
var _this9 = this;
return regeneratorRuntime.wrap(function _callee14$(_context14) {
while (1) {
switch (_context14.prev = _context14.next) {
case 0:
// This sets this.lastRead to a new Promise right away, as opposed to
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
// would not work because this.lastRead would be updated only after the
// promise resolves.
this.lastRead = this.lastRead.then(function () {
return _this9.serialNext();
});
return _context14.abrupt("return", this.lastRead);
case 2:
case "end":
return _context14.stop();
}
}
}, _callee14, this);
}));
function next() {
return _next6.apply(this, arguments);
}
return next;
}();
_proto7.serialNext = /*#__PURE__*/function () {
var _serialNext3 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee15() {
var batch, item;
return regeneratorRuntime.wrap(function _callee15$(_context15) {
while (1) {
switch (_context15.prev = _context15.next) {
case 0:
batch = [];
case 1:
if (!(batch.length < this.batchSize)) {
_context15.next = 12;
break;
}
_context15.next = 4;
return this.upstream.next();
case 4:
item = _context15.sent;
if (!item.done) {
_context15.next = 9;
break;
}
if (!(this.enableSmallLastBatch && batch.length > 0)) {
_context15.next = 8;
break;
}
return _context15.abrupt("return", {
value: batch,
done: false
});
case 8:
return _context15.abrupt("return", {
value: null,
done: true
});
case 9:
batch.push(item.value);
_context15.next = 1;
break;
case 12:
return _context15.abrupt("return", {
value: batch,
done: false
});
case 13:
case "end":
return _context15.stop();
}
}
}, _callee15, this);
}));
function serialNext() {
return _serialNext3.apply(this, arguments);
}
return serialNext;
}();
return RowMajorBatchIterator;
}(LazyIterator);
var FilterIterator = /*#__PURE__*/function (_LazyIterator7) {
_inheritsLoose(FilterIterator, _LazyIterator7);
function FilterIterator(upstream, predicate) {
var _this10;
_this10 = _LazyIterator7.call(this) || this;
_this10.upstream = upstream;
_this10.predicate = predicate;
_this10.lastRead = Promise.resolve({
value: null,
done: false
});
return _this10;
}
var _proto8 = FilterIterator.prototype;
_proto8.summary = function summary() {
return this.upstream.summary() + " -> Filter";
};
_proto8.next = /*#__PURE__*/function () {
var _next7 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee16() {
var _this11 = this;
return regeneratorRuntime.wrap(function _callee16$(_context16) {
while (1) {
switch (_context16.prev = _context16.next) {
case 0:
// This sets this.lastRead to a new Promise right away, as opposed to
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
// would not work because this.lastRead would be updated only after the
// promise resolves.
this.lastRead = this.lastRead.then(function () {
return _this11.serialNext();
});
return _context16.abrupt("return", this.lastRead);
case 2:
case "end":
return _context16.stop();
}
}
}, _callee16, this);
}));
function next() {
return _next7.apply(this, arguments);
}
return next;
}();
_proto8.serialNext = /*#__PURE__*/function () {
var _serialNext4 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee17() {
var item;
return regeneratorRuntime.wrap(function _callee17$(_context17) {
while (1) {
switch (_context17.prev = _context17.next) {
case 0:
if (!true) {
_context17.next = 9;
break;
}
_context17.next = 3;
return this.upstream.next();
case 3:
item = _context17.sent;
if (!(item.done || this.predicate(item.value))) {
_context17.next = 6;
break;
}
return _context17.abrupt("return", item);
case 6:
dispose(item.value);
_context17.next = 0;
break;
case 9:
case "end":
return _context17.stop();
}
}
}, _callee17, this);
}));
function serialNext() {
return _serialNext4.apply(this, arguments);
}
return serialNext;
}();
return FilterIterator;
}(LazyIterator);
var MapIterator = /*#__PURE__*/function (_LazyIterator8) {
_inheritsLoose(MapIterator, _LazyIterator8);
function MapIterator(upstream, transform) {
var _this12;
_this12 = _LazyIterator8.call(this) || this;
_this12.upstream = upstream;
_this12.transform = transform;
return _this12;
}
var _proto9 = MapIterator.prototype;
_proto9.summary = function summary() {
return this.upstream.summary() + " -> Map";
};
_proto9.next = /*#__PURE__*/function () {
var _next8 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee18() {
var item, inputTensors, mapped, outputTensors, _iterator, _step, t;
return regeneratorRuntime.wrap(function _callee18$(_context18) {
while (1) {
switch (_context18.prev = _context18.next) {
case 0:
_context18.next = 2;
return this.upstream.next();
case 2:
item = _context18.sent;
if (!item.done) {
_context18.next = 5;
break;
}
return _context18.abrupt("return", {
value: null,
done: true
});
case 5:
inputTensors = getTensorsInContainer(item.value); // Careful: the transform may mutate the item in place.
// That's why we have to remember the input Tensors above, and then
// below dispose only those that were not passed through to the output.
// Note too that the transform function is responsible for tidying
// any intermediate Tensors. Here we are concerned only about the
// inputs.
mapped = this.transform(item.value);
outputTensors = getTensorsInContainer(mapped); // TODO(soergel) faster intersection
// TODO(soergel) move to tf.disposeExcept(in, out)?
for (_iterator = _createForOfIteratorHelperLoose(inputTensors); !(_step = _iterator()).done;) {
t = _step.value;
if (!isTensorInList(t, outputTensors)) {
t.dispose();
}
}
return _context18.abrupt("return", {
value: mapped,
done: false
});
case 10:
case "end":
return _context18.stop();
}
}
}, _callee18, this);
}));
function next() {
return _next8.apply(this, arguments);
}
return next;
}();
return MapIterator;
}(LazyIterator);
var ErrorHandlingLazyIterator = /*#__PURE__*/function (_LazyIterator9) {
_inheritsLoose(ErrorHandlingLazyIterator, _LazyIterator9);
function ErrorHandlingLazyIterator(upstream, handler) {
var _this13;
_this13 = _LazyIterator9.call(this) || this;
_this13.upstream = upstream;
_this13.handler = handler;
_this13.count = 0;
_this13.lastRead = Promise.resolve({
value: null,
done: false
});
return _this13;
}
var _proto10 = ErrorHandlingLazyIterator.prototype;
_proto10.summary = function summary() {
return this.upstream.summary() + " -> handleErrors";
};
_proto10.next = /*#__PURE__*/function () {
var _next9 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee19() {
var _this14 = this;
return regeneratorRuntime.wrap(function _callee19$(_context19) {
while (1) {
switch (_context19.prev = _context19.next) {
case 0:
// This sets this.lastRead to a new Promise right away, as opposed to
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
// would not work because this.lastRead would be updated only after the
// promise resolves.
this.lastRead = this.lastRead.then(function () {
return _this14.serialNext();
});
return _context19.abrupt("return", this.lastRead);
case 2:
case "end":
return _context19.stop();
}
}
}, _callee19, this);
}));
function next() {
return _next9.apply(this, arguments);
}
return next;
}();
_proto10.serialNext = /*#__PURE__*/function () {
var _serialNext5 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee20() {
return regeneratorRuntime.wrap(function _callee20$(_context20) {
while (1) {
switch (_context20.prev = _context20.next) {
case 0:
if (!true) {
_context20.next = 13;
break;
}
_context20.prev = 1;
_context20.next = 4;
return this.upstream.next();
case 4:
return _context20.abrupt("return", _context20.sent);
case 7:
_context20.prev = 7;
_context20.t0 = _context20["catch"](1);
if (this.handler(_context20.t0)) {
_context20.next = 11;
break;
}
return _context20.abrupt("return", {
value: null,
done: true
});
case 11:
_context20.next = 0;
break;
case 13:
case "end":
return _context20.stop();
}
}
}, _callee20, this, [[1, 7]]);
}));
function serialNext() {
return _serialNext5.apply(this, arguments);
}
return serialNext;
}();
return ErrorHandlingLazyIterator;
}(LazyIterator);
var AsyncMapIterator = /*#__PURE__*/function (_LazyIterator10) {
_inheritsLoose(AsyncMapIterator, _LazyIterator10);
function AsyncMapIterator(upstream, transform) {
var _this15;
_this15 = _LazyIterator10.call(this) || this;
_this15.upstream = upstream;
_this15.transform = transform;
return _this15;
}
var _proto11 = AsyncMapIterator.prototype;
_proto11.summary = function summary() {
return this.upstream.summary() + " -> AsyncMap";
};
_proto11.next = /*#__PURE__*/function () {
var _next10 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee21() {
var item, inputTensors, mapped, outputTensors, _iterator2, _step2, t;
return regeneratorRuntime.wrap(function _callee21$(_context21) {
while (1) {
switch (_context21.prev = _context21.next) {
case 0:
_context21.next = 2;
return this.upstream.next();
case 2:
item = _context21.sent;
if (!item.done) {
_context21.next = 5;
break;
}
return _context21.abrupt("return", {
value: null,
done: true
});
case 5:
inputTensors = getTensorsInContainer(item.value); // Careful: the transform may mutate the item in place.
// That's why we have to remember the input Tensors above, and then
// below dispose only those that were not passed through to the output.
// Note too that the transform function is responsible for tidying
// any intermediate Tensors. Here we are concerned only about the
// inputs.
_context21.next = 8;
return this.transform(item.value);
case 8:
mapped = _context21.sent;
outputTensors = getTensorsInContainer(mapped); // TODO(soergel) faster intersection
// TODO(soergel) move to tf.disposeExcept(in, out)?
for (_iterator2 = _createForOfIteratorHelperLoose(inputTensors); !(_step2 = _iterator2()).done;) {
t = _step2.value;
if (!isTensorInList(t, outputTensors)) {
t.dispose();
}
}
return _context21.abrupt("return", {
value: mapped,
done: false
});
case 12:
case "end":
return _context21.stop();
}
}
}, _callee21, this);
}));
function next() {
return _next10.apply(this, arguments);
}
return next;
}();
return AsyncMapIterator;
}(LazyIterator); // Iterators that maintain a queue of pending items
// ============================================================================
/**
* A base class for transforming streams that operate by maintaining an
* output queue of elements that are ready to return via next(). This is
* commonly required when the transformation is 1-to-many: A call to next()
* may trigger a call to the underlying stream, which will produce many
* mapped elements of this stream-- of which we need to return only one, so
* we have to queue the rest.
*/
var OneToManyIterator = /*#__PURE__*/function (_LazyIterator11) {
_inheritsLoose(OneToManyIterator, _LazyIterator11);
function OneToManyIterator() {
var _this16;
_this16 = _LazyIterator11.call(this) || this;
_this16.outputQueue = new GrowingRingBuffer();
_this16.lastRead = Promise.resolve({
value: null,
done: false
});
return _this16;
}
var _proto12 = OneToManyIterator.prototype;
_proto12.next = /*#__PURE__*/function () {
var _next11 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee22() {
var _this17 = this;
return regeneratorRuntime.wrap(function _callee22$(_context22) {
while (1) {
switch (_context22.prev = _context22.next) {
case 0:
// This sets this.lastRead to a new Promise right away, as opposed to
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
// would not work because this.lastRead would be updated only after the
// promise resolves.
this.lastRead = this.lastRead.then(function () {
return _this17.serialNext();
});
return _context22.abrupt("return", this.lastRead);
case 2:
case "end":
return _context22.stop();
}
}
}, _callee22, this);
}));
function next() {
return _next11.apply(this, arguments);
}
return next;
}();
_proto12.serialNext = /*#__PURE__*/function () {
var _serialNext6 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee23() {
return regeneratorRuntime.wrap(function _callee23$(_context23) {
while (1) {
switch (_context23.prev = _context23.next) {
case 0:
if (!(this.outputQueue.length() === 0)) {
_context23.next = 7;
break;
}
_context23.next = 3;
return this.pump();
case 3:
if (_context23.sent) {
_context23.next = 5;
break;
}
return _context23.abrupt("return", {
value: null,
done: true
});
case 5:
_context23.next = 0;
break;
case 7:
return _context23.abrupt("return", {
value: this.outputQueue.shift(),
done: false
});
case 8:
case "end":
return _context23.stop();
}
}
}, _callee23, this);
}));
function serialNext() {
return _serialNext6.apply(this, arguments);
}
return serialNext;
}();
return OneToManyIterator;
}(LazyIterator);
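/*
 * A minimal sketch of a custom OneToManyIterator subclass (hypothetical; the
 * FlatmapIterator below is the in-library example of this pattern). `pump()`
 * reads one upstream item, pushes zero or more outputs onto `outputQueue`, and
 * returns false only when the upstream is exhausted:
 * ```js
 * class DuplicateIterator extends OneToManyIterator {
 *   constructor(upstream) { super(); this.upstream = upstream; }
 *   summary() { return this.upstream.summary() + ' -> Duplicate'; }
 *   async pump() {
 *     const item = await this.upstream.next();
 *     if (item.done) { return false; }
 *     // Emit each upstream element twice.
 *     this.outputQueue.push(item.value);
 *     this.outputQueue.push(item.value);
 *     return true;
 *   }
 * }
 * ```
 */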
var FlatmapIterator = /*#__PURE__*/function (_OneToManyIterator) {
_inheritsLoose(FlatmapIterator, _OneToManyIterator);
function FlatmapIterator(upstream, transform) {
var _this18;
_this18 = _OneToManyIterator.call(this) || this;
_this18.upstream = upstream;
_this18.transform = transform;
return _this18;
}
var _proto13 = FlatmapIterator.prototype;
_proto13.summary = function summary() {
return this.upstream.summary() + " -> Flatmap";
};
_proto13.pump = /*#__PURE__*/function () {
var _pump = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee24() {
var item, inputTensors, mappedArray, outputTensors, _iterator3, _step3, t;
return regeneratorRuntime.wrap(function _callee24$(_context24) {
while (1) {
switch (_context24.prev = _context24.next) {
case 0:
_context24.next = 2;
return this.upstream.next();
case 2:
item = _context24.sent;
if (!item.done) {
_context24.next = 5;
break;
}
return _context24.abrupt("return", false);
case 5:
inputTensors = getTensorsInContainer(item.value); // Careful: the transform may mutate the item in place.
// That's why we have to remember the input Tensors above, and then
// below dispose only those that were not passed through to the output.
// Note too that the transform function is responsible for tidying any
// intermediate Tensors. Here we are concerned only about the inputs.
mappedArray = this.transform(item.value);
outputTensors = getTensorsInContainer(mappedArray);
this.outputQueue.pushAll(mappedArray); // TODO(soergel) faster intersection, and deduplicate outputTensors
// TODO(soergel) move to tf.disposeExcept(in, out)?
for (_iterator3 = _createForOfIteratorHelperLoose(inputTensors); !(_step3 = _iterator3()).done;) {
t = _step3.value;
if (!isTensorInList(t, outputTensors)) {
t.dispose();
}
}
return _context24.abrupt("return", true);
case 11:
case "end":
return _context24.stop();
}
}
}, _callee24, this);
}));
function pump() {
return _pump.apply(this, arguments);
}
return pump;
}();
return FlatmapIterator;
}(OneToManyIterator);
/**
* Provides a `LazyIterator` that concatenates a stream of underlying
* streams.
*
* Doing this in a concurrency-safe way requires some trickery. In
* particular, we want this stream to return the elements from the
* underlying streams in the correct order according to when next() was
* called, even if the resulting Promises resolve in a different order.
*/
var ChainedIterator = /*#__PURE__*/function (_LazyIterator12) {
_inheritsLoose(ChainedIterator, _LazyIterator12);
function ChainedIterator(iterators, baseErrorHandler) {
var _this19;
_this19 = _LazyIterator12.call(this) || this;
_this19.baseErrorHandler = baseErrorHandler; // Strict Promise execution order:
// a next() call may not even begin until the previous one completes.
_this19.lastRead = null; // Local state that should not be clobbered by out-of-order execution.
_this19.iterator = null;
_this19.moreIterators = iterators;
return _this19;
}
var _proto14 = ChainedIterator.prototype;
_proto14.summary = function summary() {
var upstreamSummaries = 'TODO: fill in upstream of chained summaries';
return upstreamSummaries + " -> Chained";
};
_proto14.next = /*#__PURE__*/function () {
var _next12 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee25() {
return regeneratorRuntime.wrap(function _callee25$(_context25) {
while (1) {
switch (_context25.prev = _context25.next) {
case 0:
this.lastRead = this.readFromChain(this.lastRead);
return _context25.abrupt("return", this.lastRead);
case 2:
case "end":
return _context25.stop();
}
}
}, _callee25, this);
}));
function next() {
return _next12.apply(this, arguments);
}
return next;
}();
_proto14.readFromChain = /*#__PURE__*/function () {
var _readFromChain = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee26(lastRead) {
var iteratorResult, itemResult;
return regeneratorRuntime.wrap(function _callee26$(_context26) {
while (1) {
switch (_context26.prev = _context26.next) {
case 0:
_context26.next = 2;
return lastRead;
case 2:
if (!(this.iterator == null)) {
_context26.next = 10;
break;
}
_context26.next = 5;
return this.moreIterators.next();
case 5:
iteratorResult = _context26.sent;
if (!iteratorResult.done) {
_context26.next = 8;
break;
}
return _context26.abrupt("return", {
value: null,
done: true
});
case 8:
this.iterator = iteratorResult.value;
if (this.baseErrorHandler != null) {
this.iterator = this.iterator.handleErrors(this.baseErrorHandler);
}
case 10:
_context26.next = 12;
return this.iterator.next();
case 12:
itemResult = _context26.sent;
if (!itemResult.done) {
_context26.next = 16;
break;
}
this.iterator = null;
return _context26.abrupt("return", this.readFromChain(lastRead));
case 16:
return _context26.abrupt("return", itemResult);
case 17:
case "end":
return _context26.stop();
}
}
}, _callee26, this);
}));
function readFromChain(_x4) {
return _readFromChain.apply(this, arguments);
}
return readFromChain;
}();
return ChainedIterator;
}(LazyIterator);
var ZipMismatchMode;
(function (ZipMismatchMode) {
ZipMismatchMode[ZipMismatchMode["FAIL"] = 0] = "FAIL";
ZipMismatchMode[ZipMismatchMode["SHORTEST"] = 1] = "SHORTEST";
ZipMismatchMode[ZipMismatchMode["LONGEST"] = 2] = "LONGEST"; // use nulls for exhausted streams; use up the longest stream.
})(ZipMismatchMode || (ZipMismatchMode = {}));
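/*
 * A minimal sketch of the three modes (assumes the module-internal
 * `ZipIterator` and `iteratorFromItems` helpers defined in this file; the
 * public entry point is `tf.data.zip()`):
 * ```js
 * const zipped = new ZipIterator(
 *     [iteratorFromItems([1, 2, 3]), iteratorFromItems([4, 5])],
 *     ZipMismatchMode.SHORTEST);
 * // FAIL (default): yields [1, 4], [2, 5], then throws because the streams
 * //                 have different lengths.
 * // SHORTEST:       yields [1, 4], [2, 5], then terminates.
 * // LONGEST:        yields [1, 4], [2, 5], [3, null], then terminates.
 * console.log(await zipped.toArray());  // [[1, 4], [2, 5]]
 * ```
 */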
/**
* Provides a `LazyIterator` that zips together an array, dict, or nested
* structure of `LazyIterator`s (and perhaps additional constants).
*
* The underlying streams must provide elements in a consistent order such
* that they correspond.
*
* Typically, the underlying streams should have the same number of
* elements. If they do not, the behavior is determined by the
* `mismatchMode` argument.
*
* The nested structure of the `iterators` argument determines the
* structure of elements in the resulting iterator.
*
* Doing this in a concurrency-safe way requires some trickery. In
* particular, we want this stream to return the elements from the
* underlying streams in the correct order according to when next() was
* called, even if the resulting Promises resolve in a different order.
*
* @param iterators: An array or object containing LazyIterators at the
* leaves.
* @param mismatchMode: Determines what to do when one underlying iterator
* is exhausted before the others. `ZipMismatchMode.FAIL` (the default)
* causes an error to be thrown in this case. `ZipMismatchMode.SHORTEST`
 * causes the zipped iterator to terminate as soon as the first underlying
 * stream is exhausted, so elements remaining on the longer streams are
 * ignored.
* `ZipMismatchMode.LONGEST` causes the zipped stream to continue, filling
* in nulls for the exhausted streams, until all streams are exhausted.
*/
var ZipIterator = /*#__PURE__*/function (_LazyIterator13) {
_inheritsLoose(ZipIterator, _LazyIterator13);
function ZipIterator(iterators, mismatchMode) {
var _this20;
if (mismatchMode === void 0) {
mismatchMode = ZipMismatchMode.FAIL;
}
_this20 = _LazyIterator13.call(this) || this;
_this20.iterators = iterators;
_this20.mismatchMode = mismatchMode;
_this20.count = 0;
_this20.currentPromise = null;
return _this20;
}
var _proto15 = ZipIterator.prototype;
_proto15.summary = function summary() {
var upstreamSummaries = 'TODO: fill in upstream of zip summaries';
return "{" + upstreamSummaries + "} -> Zip";
};
_proto15.nextState = /*#__PURE__*/function () {
var _nextState = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee27(afterState) {
var numIterators, iteratorsDone, getNext, mapped;
return regeneratorRuntime.wrap(function _callee27$(_context27) {
while (1) {
switch (_context27.prev = _context27.next) {
case 0:
getNext = function _getNext(container) {
if (container instanceof LazyIterator) {
var result = container.next();
return {
value: result.then(function (x) {
numIterators++;
if (x.done) {
iteratorsDone++;
}
return x.value;
}),
recurse: false
};
} else {
return {
value: null,
recurse: true
};
}
};
_context27.next = 3;
return afterState;
case 3:
// Collect underlying iterator "done" signals as a side effect in
// getNext()
numIterators = 0;
iteratorsDone = 0;
_context27.next = 7;
return deepMapAndAwaitAll(this.iterators, getNext);
case 7:
mapped = _context27.sent;
if (!(numIterators === iteratorsDone)) {
_context27.next = 10;
break;
}
return _context27.abrupt("return", {
value: null,
done: true
});
case 10:
if (!(iteratorsDone > 0)) {
_context27.next = 16;
break;
}
_context27.t0 = this.mismatchMode;
_context27.next = _context27.t0 === ZipMismatchMode.FAIL ? 14 : _context27.t0 === ZipMismatchMode.SHORTEST ? 15 : _context27.t0 === ZipMismatchMode.LONGEST ? 16 : 16;
break;
case 14:
throw new Error('Zipped streams should have the same length. ' + ("Mismatched at element " + this.count + "."));
case 15:
return _context27.abrupt("return", {
value: null,
done: true
});
case 16:
this.count++;
return _context27.abrupt("return", {
value: mapped,
done: false
});
case 18:
case "end":
return _context27.stop();
}
}
}, _callee27, this);
}));
function nextState(_x5) {
return _nextState.apply(this, arguments);
}
return nextState;
}();
_proto15.next = /*#__PURE__*/function () {
var _next13 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee28() {
return regeneratorRuntime.wrap(function _callee28$(_context28) {
while (1) {
switch (_context28.prev = _context28.next) {
case 0:
this.currentPromise = this.nextState(this.currentPromise);
return _context28.abrupt("return", this.currentPromise);
case 2:
case "end":
return _context28.stop();
}
}
}, _callee28, this);
}));
function next() {
return _next13.apply(this, arguments);
}
return next;
}();
return ZipIterator;
}(LazyIterator); // Iterators that maintain a ring buffer of pending promises
// ============================================================================
/**
* A stream that prefetches a given number of items from an upstream source,
* returning them in FIFO order.
*
* Note this prefetches Promises, but makes no guarantees about when those
* Promises resolve.
*/
var PrefetchIterator = /*#__PURE__*/function (_LazyIterator14) {
_inheritsLoose(PrefetchIterator, _LazyIterator14);
function PrefetchIterator(upstream, bufferSize) {
var _this21;
_this21 = _LazyIterator14.call(this) || this;
_this21.upstream = upstream;
_this21.bufferSize = bufferSize;
_this21.buffer = new RingBuffer(bufferSize);
return _this21;
}
var _proto16 = PrefetchIterator.prototype;
_proto16.summary = function summary() {
return this.upstream.summary() + " -> Prefetch";
}
/**
* Refill the prefetch buffer. Returns only after the buffer is full, or
* the upstream source is exhausted.
*/
;
_proto16.refill = function refill() {
while (!this.buffer.isFull()) {
var v = this.upstream.next();
this.buffer.push(v);
}
};
_proto16.next = function next() {
this.refill(); // This shift will never throw an error because the buffer is always
// full after a refill. If the stream is exhausted, the buffer will be
// full of Promises that will resolve to the end-of-stream signal.
return this.buffer.shift();
};
return PrefetchIterator;
}(LazyIterator);
/**
* A stream that performs a sliding-window random shuffle on an upstream
* source. This is like a `PrefetchIterator` except that the items are
* returned in randomized order. Mixing naturally improves as the buffer
* size increases.
*/
var ShuffleIterator = /*#__PURE__*/function (_PrefetchIterator) {
_inheritsLoose(ShuffleIterator, _PrefetchIterator);
function ShuffleIterator(upstream, windowSize, seed) {
var _this22;
_this22 = _PrefetchIterator.call(this, upstream, windowSize) || this;
_this22.upstream = upstream;
_this22.windowSize = windowSize; // Local state that should not be clobbered by out-of-order execution.
_this22.upstreamExhausted = false;
_this22.random = seedrandom_1(seed || now().toString());
_this22.lastRead = Promise.resolve({
value: null,
done: false
});
return _this22;
}
var _proto17 = ShuffleIterator.prototype;
_proto17.next = /*#__PURE__*/function () {
var _next14 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee29() {
var _this23 = this;
return regeneratorRuntime.wrap(function _callee29$(_context29) {
while (1) {
switch (_context29.prev = _context29.next) {
case 0:
// This sets this.lastRead to a new Promise right away, as opposed to
// saying `await this.lastRead; this.lastRead = this.serialNext();` which
// would not work because this.lastRead would be updated only after the
// promise resolves.
this.lastRead = this.lastRead.then(function () {
return _this23.serialNext();
});
return _context29.abrupt("return", this.lastRead);
case 2:
case "end":
return _context29.stop();
}
}
}, _callee29, this);
}));
function next() {
return _next14.apply(this, arguments);
}
return next;
}();
_proto17.randomInt = function randomInt(max) {
return Math.floor(this.random() * max);
};
_proto17.chooseIndex = function chooseIndex() {
return this.randomInt(this.buffer.length());
};
_proto17.serialNext = /*#__PURE__*/function () {
var _serialNext7 = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee30() {
var chosenIndex, result;
return regeneratorRuntime.wrap(function _callee30$(_context30) {
while (1) {
switch (_context30.prev = _context30.next) {
case 0:
// TODO(soergel): consider performance
if (!this.upstreamExhausted) {
this.refill();
}
case 1:
if (this.buffer.isEmpty()) {
_context30.next = 14;
break;
}
chosenIndex = this.chooseIndex();
_context30.next = 5;
return this.buffer.shuffleExcise(chosenIndex);
case 5:
result = _context30.sent;
if (!result.done) {
_context30.next = 10;
break;
}
this.upstreamExhausted = true;
_context30.next = 12;
break;
case 10:
this.refill();
return _context30.abrupt("return", result);
case 12:
_context30.next = 1;
break;
case 14:
return _context30.abrupt("return", {
value: null,
done: true
});
case 15:
case "end":
return _context30.stop();
}
}
}, _callee30, this);
}));
function serialNext() {
return _serialNext7.apply(this, arguments);
}
return serialNext;
}();
return ShuffleIterator;
}(PrefetchIterator);
/**
* Represents a potentially large list of independent data elements (typically
* 'samples' or 'examples').
*
* A 'data example' may be a primitive, an array, a map from string keys to
* values, or any nested structure of these.
*
* A `Dataset` represents an ordered collection of elements, together with a
* chain of transformations to be performed on those elements. Each
* transformation is a method of `Dataset` that returns another `Dataset`, so
* these may be chained, e.g.
* `const processedDataset = rawDataset.filter(...).map(...).batch(...)`.
*
* Data loading and transformation is done in a lazy, streaming fashion. The
* dataset may be iterated over multiple times; each iteration starts the data
* loading anew and recapitulates the transformations.
*
* A `Dataset` is typically processed as a stream of unbatched examples --i.e.,
* its transformations are applied one example at a time. Batching produces a
* new `Dataset` where each element is a batch. Batching should usually come
* last in a pipeline, because data transformations are easier to express on a
* per-example basis than on a per-batch basis.
*
* The following code examples are calling `await dataset.forEachAsync(...)` to
* iterate once over the entire dataset in order to print out the data.
*
* @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
*/
var Dataset = /*#__PURE__*/function () {
function Dataset() {
this.size = null;
} // TODO(soergel): Make Datasets report whether repeated iterator() calls
// produce the same result (e.g., reading from a file) or different results
// (e.g., from the webcam). Currently we don't make this distinction but it
// could be important for the user to know.
// abstract isDeterministic(): boolean;
/**
* Groups elements into batches.
*
* It is assumed that each of the incoming dataset elements has the same
* structure-- i.e. the same set of keys at each location in an object
* hierarchy. For each key, the resulting `Dataset` provides a batched
* element collecting all of the incoming values for that key.
*
* * Incoming primitives are grouped into a 1-D Tensor.
* * Incoming Tensors are grouped into a new Tensor where the 0'th axis is
* the batch dimension.
* * Incoming arrays are converted to Tensor and then batched.
* * A nested array is interpreted as an n-D Tensor, so the batched result
* has n+1 dimensions.
* * An array that cannot be converted to Tensor produces an error.
*
* If an array should not be batched as a unit, it should first be converted
* to an object with integer keys.
*
* Here are a few examples:
*
* Batch a dataset of numbers:
* ```js
* const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8]).batch(4);
* await a.forEachAsync(e => e.print());
* ```
*
* Batch a dataset of arrays:
* ```js
* const b = tf.data.array([[1], [2], [3], [4], [5], [6], [7], [8]]).batch(4);
* await b.forEachAsync(e => e.print());
* ```
*
* Batch a dataset of objects:
* ```js
* const c = tf.data.array([{a: 1, b: 11}, {a: 2, b: 12}, {a: 3, b: 13},
* {a: 4, b: 14}, {a: 5, b: 15}, {a: 6, b: 16}, {a: 7, b: 17},
* {a: 8, b: 18}]).batch(4);
* await c.forEachAsync(e => {
* console.log('{');
* for(var key in e) {
* console.log(key+':');
* e[key].print();
* }
* console.log('}');
* })
* ```
*
* @param batchSize The number of elements desired per batch.
* @param smallLastBatch Whether to emit the final batch when it has fewer
* than batchSize elements. Default true.
* @returns A `Dataset`, from which a stream of batches can be obtained.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
var _proto = Dataset.prototype;
_proto.batch = function batch(batchSize, smallLastBatch) {
if (smallLastBatch === void 0) {
smallLastBatch = true;
}
var base = this;
assert(batchSize > 0, function () {
return "batchSize needs to be positive, but it is\n " + batchSize;
});
var size;
if (this.size === Infinity || this.size == null) {
// If the size of this dataset is infinity or null, the new size stays the
// same.
size = this.size;
} else if (smallLastBatch) {
// If the size of this dataset is known and the small last batch is included,
// the new size is the total number of batches, counting a partial last batch.
size = Math.ceil(this.size / batchSize);
} else {
// If the size of this dataset is known and the small last batch is excluded,
// the new size is the number of full batches only.
size = Math.floor(this.size / batchSize);
}
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.next = 2;
return base.iterator();
case 2:
return _context.abrupt("return", _context.sent.columnMajorBatch(batchSize, smallLastBatch, deepBatchConcat));
case 3:
case "end":
return _context.stop();
}
}
}, _callee);
})), size);
}
/**
* Concatenates this `Dataset` with another.
*
* ```js
* const a = tf.data.array([1, 2, 3]);
* const b = tf.data.array([4, 5, 6]);
* const c = a.concatenate(b);
* await c.forEachAsync(e => console.log(e));
* ```
*
* @param dataset A `Dataset` to be concatenated onto this one.
* @returns A `Dataset`.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.concatenate = function concatenate(dataset) {
var base = this;
var size;
if (this.size === Infinity || dataset.size === Infinity) {
// If either of the two datasets has infinite size, the new size is
// infinity.
size = Infinity;
} else if (this.size != null && dataset.size != null) {
// If the sizes of both datasets are known and finite, the new size is the
// sum of the two sizes.
size = this.size + dataset.size;
} else {
// Otherwise (neither size is infinite, and at least one is null), the new
// size is null.
size = null;
}
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return base.iterator();
case 2:
_context2.t0 = _context2.sent;
_context2.next = 5;
return dataset.iterator();
case 5:
_context2.t1 = _context2.sent;
return _context2.abrupt("return", _context2.t0.concatenate.call(_context2.t0, _context2.t1));
case 7:
case "end":
return _context2.stop();
}
}
}, _callee2);
})), size);
}
/**
* Filters this dataset according to `predicate`.
*
* ```js
* const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
* .filter(x => x%2 === 0);
* await a.forEachAsync(e => console.log(e));
* ```
*
* @param predicate A function mapping a dataset element to a boolean or a
* `Promise` for one.
*
* @returns A `Dataset` of elements for which the predicate was true.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.filter = function filter(predicate) {
var base = this;
var size;
if (this.size === Infinity) {
// If the size of this dataset is infinity, new size is infinity
size = Infinity;
} else {
// If this dataset has a finite number of elements, the new size is null
// because the number of elements passing the predicate is unknown.
size = null;
}
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3() {
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
_context3.next = 2;
return base.iterator();
case 2:
return _context3.abrupt("return", _context3.sent.filter(function (x) {
return tidy(function () {
return predicate(x);
});
}));
case 3:
case "end":
return _context3.stop();
}
}
}, _callee3);
})), size);
}
/**
* Apply a function to every element of the dataset.
*
* After the function is applied to a dataset element, any Tensors contained
* within that element are disposed.
*
* ```js
* const a = tf.data.array([1, 2, 3]);
* await a.forEachAsync(e => console.log(e));
* ```
*
* @param f A function to apply to each dataset element.
* @returns A `Promise` that resolves after all elements have been processed.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.forEachAsync =
/*#__PURE__*/
function () {
var _forEachAsync = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(f) {
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
_context4.next = 2;
return this.iterator();
case 2:
return _context4.abrupt("return", _context4.sent.forEachAsync(f));
case 3:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function forEachAsync(_x) {
return _forEachAsync.apply(this, arguments);
}
return forEachAsync;
}()
/**
* Maps this dataset through a 1-to-1 transform.
*
* ```js
* const a = tf.data.array([1, 2, 3]).map(x => x*x);
* await a.forEachAsync(e => console.log(e));
* ```
*
* @param transform A function mapping a dataset element to a transformed
* dataset element.
*
* @returns A `Dataset` of transformed elements.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.map = function map(transform) {
var base = this;
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee5() {
return regeneratorRuntime.wrap(function _callee5$(_context5) {
while (1) {
switch (_context5.prev = _context5.next) {
case 0:
_context5.next = 2;
return base.iterator();
case 2:
return _context5.abrupt("return", _context5.sent.map(function (x) {
return tidy(function () {
return transform(x);
});
}));
case 3:
case "end":
return _context5.stop();
}
}
}, _callee5);
})), this.size);
}
/**
* Maps this dataset through an async 1-to-1 transform.
*
* ```js
* const a =
* tf.data.array([1, 2, 3]).mapAsync(x => new Promise(function(resolve){
* setTimeout(() => {
* resolve(x * x);
* }, Math.random()*1000 + 500);
* }));
* console.log(await a.toArray());
* ```
*
* @param transform A function mapping a dataset element to a `Promise` for a
* transformed dataset element. This transform is responsible for disposing
* any intermediate `Tensor`s, i.e. by wrapping its computation in
* `tf.tidy()`; that cannot be automated here (as it is in the synchronous
* `map()` case).
*
* @returns A `Dataset` of transformed elements.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.mapAsync = function mapAsync(transform) {
var base = this;
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee6() {
return regeneratorRuntime.wrap(function _callee6$(_context6) {
while (1) {
switch (_context6.prev = _context6.next) {
case 0:
_context6.next = 2;
return base.iterator();
case 2:
return _context6.abrupt("return", _context6.sent.mapAsync(transform));
case 3:
case "end":
return _context6.stop();
}
}
}, _callee6);
})), this.size);
}
/**
* Creates a `Dataset` that prefetches elements from this dataset.
*
* @param bufferSize: An integer specifying the number of elements to be
* prefetched.
* @returns A `Dataset`.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.prefetch = function prefetch(bufferSize) {
if (bufferSize == null) {
throw new RangeError('`Dataset.prefetch()` requires bufferSize to be specified.');
}
var base = this;
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee7() {
return regeneratorRuntime.wrap(function _callee7$(_context7) {
while (1) {
switch (_context7.prev = _context7.next) {
case 0:
_context7.next = 2;
return base.iterator();
case 2:
return _context7.abrupt("return", _context7.sent.prefetch(bufferSize));
case 3:
case "end":
return _context7.stop();
}
}
}, _callee7);
})), this.size);
}
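/*
 * A minimal usage sketch (follows the same `tf.data.array` pattern as the
 * examples above):
 * ```js
 * const a = tf.data.array([1, 2, 3, 4, 5, 6]).batch(2).prefetch(2);
 * await a.forEachAsync(e => e.print());
 * ```
 */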
/**
* Repeats this dataset `count` times.
*
* NOTE: If this dataset is a function of global state (e.g. a random number
* generator), then different repetitions may produce different elements.
*
* ```js
* const a = tf.data.array([1, 2, 3]).repeat(3);
* await a.forEachAsync(e => console.log(e));
* ```
*
* @param count: (Optional) An integer, representing the number of times
* the dataset should be repeated. The default behavior (if `count` is
 * `undefined` or negative) is for the dataset to be repeated indefinitely.
* @returns A `Dataset`.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.repeat = function repeat(count) {
var base = this;
var size;
if (this.size != null && count > 0) {
// If this dataset has a known size and count is positive, the new size is
// the current size multiplied by count. This also covers the case where the
// current size is infinity.
size = this.size * count;
} else if (count === 0) {
// If count is 0, new size is 0.
size = 0;
} else if (this.size != null && (count === undefined || count < 0)) {
// If this dataset has a known size and count is undefined or negative, the
// dataset will be repeated indefinitely and the new size is infinity.
size = Infinity;
} else {
// If the size of this dataset is null, the new dataset's size is null.
size = null;
}
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee9() {
var iteratorIterator;
return regeneratorRuntime.wrap(function _callee9$(_context9) {
while (1) {
switch (_context9.prev = _context9.next) {
case 0:
iteratorIterator = iteratorFromFunction( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee8() {
return regeneratorRuntime.wrap(function _callee8$(_context8) {
while (1) {
switch (_context8.prev = _context8.next) {
case 0:
_context8.next = 2;
return base.iterator();
case 2:
_context8.t0 = _context8.sent;
return _context8.abrupt("return", {
value: _context8.t0,
done: false
});
case 4:
case "end":
return _context8.stop();
}
}
}, _callee8);
})));
return _context9.abrupt("return", iteratorFromConcatenated(iteratorIterator.take(count)));
case 2:
case "end":
return _context9.stop();
}
}
}, _callee9);
})), size);
}
/**
* Creates a `Dataset` that skips `count` initial elements from this dataset.
*
* ```js
* const a = tf.data.array([1, 2, 3, 4, 5, 6]).skip(3);
* await a.forEachAsync(e => console.log(e));
* ```
*
* @param count: The number of elements of this dataset that should be skipped
* to form the new dataset. If `count` is greater than the size of this
* dataset, the new dataset will contain no elements. If `count`
* is `undefined` or negative, skips the entire dataset.
*
* @returns A `Dataset`.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.skip = function skip(count) {
var base = this;
var size;
if (this.size != null && count >= 0 && this.size >= count) {
// If the size of this dataset is greater than count, the new dataset's
// size is the current size minus the number of skipped elements. This also
// covers the case where the current size is infinity.
size = this.size - count;
} else if (this.size != null && (this.size < count || count === undefined || count < 0)) {
// If the size of this dataset is smaller than count, or count is
// undefined or negative, the entire dataset is skipped and the new size is 0.
size = 0;
} else {
// If the size of this dataset is null, the new dataset's size is null.
size = null;
}
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee10() {
return regeneratorRuntime.wrap(function _callee10$(_context10) {
while (1) {
switch (_context10.prev = _context10.next) {
case 0:
_context10.next = 2;
return base.iterator();
case 2:
return _context10.abrupt("return", _context10.sent.skip(count));
case 3:
case "end":
return _context10.stop();
}
}
}, _callee10);
})), size);
}
/**
* Pseudorandomly shuffles the elements of this dataset. This is done in a
* streaming manner, by sampling from a given number of prefetched elements.
*
* ```js
* const a = tf.data.array([1, 2, 3, 4, 5, 6]).shuffle(3);
* await a.forEachAsync(e => console.log(e));
* ```
*
* @param bufferSize: An integer specifying the number of elements from this
* dataset from which the new dataset will sample.
* @param seed: (Optional) An integer specifying the random seed that will
* be used to create the distribution.
* @param reshuffleEachIteration: (Optional) A boolean, which if true
* indicates that the dataset should be pseudorandomly reshuffled each time
* it is iterated over. If false, elements will be returned in the same
* shuffled order on each iteration. (Defaults to `true`.)
* @returns A `Dataset`.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.shuffle = function shuffle(bufferSize, seed, reshuffleEachIteration) {
if (reshuffleEachIteration === void 0) {
reshuffleEachIteration = true;
}
if (bufferSize == null || bufferSize < 0) {
if (this.size == null) {
throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified.');
} else {
throw new RangeError('`Dataset.shuffle()` requires bufferSize to be specified. ' + 'If your data fits in main memory (for regular JS objects), ' + 'and/or GPU memory (for `tf.Tensor`s), consider setting ' + ("bufferSize to the dataset size (" + this.size + " elements)"));
}
}
var base = this;
var random = seedrandom_1(seed || now().toString());
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee11() {
var seed2;
return regeneratorRuntime.wrap(function _callee11$(_context11) {
while (1) {
switch (_context11.prev = _context11.next) {
case 0:
seed2 = random.int32();
if (reshuffleEachIteration) {
seed2 += random.int32();
}
_context11.next = 4;
return base.iterator();
case 4:
return _context11.abrupt("return", _context11.sent.shuffle(bufferSize, seed2.toString()));
case 5:
case "end":
return _context11.stop();
}
}
}, _callee11);
})), this.size);
}
/**
* Creates a `Dataset` with at most `count` initial elements from this
* dataset.
*
* ```js
* const a = tf.data.array([1, 2, 3, 4, 5, 6]).take(3);
* await a.forEachAsync(e => console.log(e));
* ```
*
* @param count: The number of elements of this dataset that should be taken
* to form the new dataset. If `count` is `undefined` or negative, or if
* `count` is greater than the size of this dataset, the new dataset will
* contain all elements of this dataset.
* @returns A `Dataset`.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.take = function take(count) {
var base = this;
var size;
if (this.size != null && this.size > count) {
// If the size of this dataset is greater than count, the new dataset's
// size is count.
size = count;
} else if (this.size != null && this.size <= count) {
// If the size of this dataset is equal to or smaller than count, the new
// dataset's size is the size of this dataset.
size = this.size;
} else {
// If the size of this dataset is null, the new dataset's size is null.
size = null;
}
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee12() {
return regeneratorRuntime.wrap(function _callee12$(_context12) {
while (1) {
switch (_context12.prev = _context12.next) {
case 0:
_context12.next = 2;
return base.iterator();
case 2:
return _context12.abrupt("return", _context12.sent.take(count));
case 3:
case "end":
return _context12.stop();
}
}
}, _callee12);
})), size);
}
/**
* Collect all elements of this dataset into an array.
*
* Obviously this will succeed only for small datasets that fit in memory.
* Useful for testing and generally should be avoided if possible.
*
* ```js
* const a = tf.data.array([1, 2, 3, 4, 5, 6]);
* console.log(await a.toArray());
* ```
*
* @returns A Promise for an array of elements, which will resolve
* when a new stream has been obtained and fully consumed.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
;
_proto.toArray =
/*#__PURE__*/
function () {
var _toArray = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee13() {
return regeneratorRuntime.wrap(function _callee13$(_context13) {
while (1) {
switch (_context13.prev = _context13.next) {
case 0:
if (!(this.size === Infinity)) {
_context13.next = 2;
break;
}
throw new Error('Can not convert infinite data stream to array.');
case 2:
_context13.next = 4;
return this.iterator();
case 4:
return _context13.abrupt("return", _context13.sent.toArray());
case 5:
case "end":
return _context13.stop();
}
}
}, _callee13, this);
}));
function toArray() {
return _toArray.apply(this, arguments);
}
return toArray;
}()
/**
 * Collect all elements of this dataset into an array, prefetching 100
 * elements. This is useful for testing, because the prefetch changes the
* order in which the Promises are resolved along the processing pipeline.
* This may help expose bugs where results are dependent on the order of
* Promise resolution rather than on the logical order of the stream (i.e.,
* due to hidden mutable state).
*
* @returns A Promise for an array of elements, which will resolve
* when a new stream has been obtained and fully consumed.
*/
;
_proto.toArrayForTest =
/*#__PURE__*/
function () {
var _toArrayForTest = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee14() {
return regeneratorRuntime.wrap(function _callee14$(_context14) {
while (1) {
switch (_context14.prev = _context14.next) {
case 0:
if (!(this.size === Infinity)) {
_context14.next = 2;
break;
}
throw new Error('Can not convert infinite data stream to array.');
case 2:
_context14.next = 4;
return this.iterator();
case 4:
return _context14.abrupt("return", _context14.sent.toArrayForTest());
case 5:
case "end":
return _context14.stop();
}
}
}, _callee14, this);
}));
function toArrayForTest() {
return _toArrayForTest.apply(this, arguments);
}
return toArrayForTest;
}();
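  /*
   * A minimal usage sketch (not from the original docs; assumes a small
   * in-memory dataset): `toArrayForTest()` behaves like `toArray()`, but with
   * a prefetch buffer of 100 elements.
   *
   * ```js
   * const ds = tf.data.array([1, 2, 3]);
   * console.log(await ds.toArrayForTest()); // [1, 2, 3]
   * ```
   */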
return Dataset;
}(); // TODO(soergel): deep sharded shuffle, where supported
Dataset.MAX_BUFFER_SIZE = 10000;
/**
* Create a `Dataset` defined by a provided iterator() function.
*
* ```js
* let i = -1;
* const func = () =>
* ++i < 5 ? {value: i, done: false} : {value: null, done: true};
* const iter = tf.data.iteratorFromFunction(func);
 * const ds = tf.data.datasetFromIteratorFn(async () => iter);
* await ds.forEachAsync(e => console.log(e));
* ```
*/
function datasetFromIteratorFn(iteratorFn, size) {
if (size === void 0) {
size = null;
}
return new ( /*#__PURE__*/function (_Dataset) {
_inheritsLoose(_class, _Dataset);
function _class() {
var _this;
_this = _Dataset.apply(this, arguments) || this;
_this.size = size;
return _this;
}
/*
* Provide a new stream of elements. Note this will also start new streams
* from any underlying `Dataset`s.
*/
var _proto2 = _class.prototype;
_proto2.iterator =
/*#__PURE__*/
function () {
var _iterator = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee15() {
return regeneratorRuntime.wrap(function _callee15$(_context15) {
while (1) {
switch (_context15.prev = _context15.next) {
case 0:
return _context15.abrupt("return", iteratorFn());
case 1:
case "end":
return _context15.stop();
}
}
}, _callee15);
}));
function iterator() {
return _iterator.apply(this, arguments);
}
return iterator;
}();
return _class;
}(Dataset))();
}
/**
* Create a `Dataset` from an array of elements.
*
* Create a Dataset from an array of objects:
* ```js
* const a = tf.data.array([{'item': 1}, {'item': 2}, {'item': 3}]);
* await a.forEachAsync(e => console.log(e));
* ```
*
* Create a Dataset from an array of numbers:
* ```js
* const a = tf.data.array([4, 5, 6]);
* await a.forEachAsync(e => console.log(e));
* ```
* @param items An array of elements that will be parsed as items in a dataset.
*
* @doc {heading: 'Data', subheading: 'Creation', namespace: 'data'}
*/
function array(items) {
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee16() {
return regeneratorRuntime.wrap(function _callee16$(_context16) {
while (1) {
switch (_context16.prev = _context16.next) {
case 0:
return _context16.abrupt("return", iteratorFromItems(items));
case 1:
case "end":
return _context16.stop();
}
}
}, _callee16);
})), items.length);
}
/**
* Create a `Dataset` by zipping together an array, dict, or nested
* structure of `Dataset`s (and perhaps additional constants).
* The underlying datasets must provide elements in a consistent order such that
* they correspond.
*
* The number of elements in the resulting dataset is the same as the size of
* the smallest dataset in datasets.
*
* The nested structure of the `datasets` argument determines the
* structure of elements in the resulting iterator.
*
* Note this means that, given an array of two datasets that produce dict
* elements, the result is a dataset that produces elements that are arrays
* of two dicts:
*
* Zip an array of datasets:
* ```js
* console.log('Zip two datasets of objects:');
* const ds1 = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
* const ds2 = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
* const ds3 = tf.data.zip([ds1, ds2]);
* await ds3.forEachAsync(e => console.log(JSON.stringify(e)));
*
* // If the goal is to merge the dicts in order to produce elements like
* // {a: ..., b: ...}, this requires a second step such as:
* console.log('Merge the objects:');
* const ds4 = ds3.map(x => {return {a: x[0].a, b: x[1].b}});
* await ds4.forEachAsync(e => console.log(e));
* ```
*
* Zip a dict of datasets:
* ```js
* const a = tf.data.array([{a: 1}, {a: 2}, {a: 3}]);
* const b = tf.data.array([{b: 4}, {b: 5}, {b: 6}]);
* const c = tf.data.zip({c: a, d: b});
* await c.forEachAsync(e => console.log(JSON.stringify(e)));
* ```
*
* @doc {heading: 'Data', subheading: 'Operations', namespace: 'data'}
*/
function zip(datasets) {
// manually type-check the argument for JS users
if (!isIterable$1(datasets)) {
throw new Error('The argument to zip() must be an object or array.');
}
var size;
if (Array.isArray(datasets)) {
for (var i = 0; i < datasets.length; i++) {
size = size == null ? datasets[i].size : Math.min(size, datasets[i].size);
}
} else if (datasets instanceof Object) {
for (var ds in datasets) {
size = size == null ? datasets[ds].size : Math.min(size, datasets[ds].size);
}
}
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee17() {
var streams;
return regeneratorRuntime.wrap(function _callee17$(_context17) {
while (1) {
switch (_context17.prev = _context17.next) {
case 0:
_context17.next = 2;
return deepMapAndAwaitAll(datasets, function (d) {
if (d instanceof Dataset) {
return {
value: d.iterator(),
recurse: false
};
} else if (isIterable$1(d)) {
return {
value: null,
recurse: true
};
} else {
throw new Error('Leaves of the structure passed to zip() must be Datasets, ' + 'not primitives.');
}
});
case 2:
streams = _context17.sent;
return _context17.abrupt("return", iteratorFromZipped(streams, ZipMismatchMode.SHORTEST));
case 4:
case "end":
return _context17.stop();
}
}
}, _callee17);
})), size);
}
/**
* A zip function for use with deepZip, passed via the columnMajorBatch call.
*
* Accepts an array of identically-structured nested elements and either batches
* them (if they are primitives, numeric arrays, or Tensors) or requests
* recursion (if not).
*/
// tslint:disable-next-line:no-any
function deepBatchConcat(rows) {
if (rows === null) {
return null;
} // use the first item to decide whether to recurse or batch here.
var exampleRow = rows[0];
if (canTensorify(exampleRow)) {
// rows is an array of primitives, Tensors, or arrays. Batch them.
var value = batchConcat(rows);
return {
value: value,
recurse: false
};
} // the example row is an object, so recurse into it.
return {
value: null,
recurse: true
};
}
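  /*
   * Illustrative sketch of the one-level contract (internal helper, not part
   * of the public API): arrays of primitives or Tensors are batched directly,
   * while arrays of plain objects request recursion into their fields.
   *
   * ```js
   * deepBatchConcat([1, 2, 3]);        // {value: Tensor [1, 2, 3], recurse: false}
   * deepBatchConcat([{a: 1}, {a: 2}]); // {value: null, recurse: true}
   * ```
   */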
/**
* Assembles a list of same-shaped numbers, number arrays, or Tensors
* into a single new Tensor where axis 0 is the batch dimension.
*/
function batchConcat(arrays) {
if (arrays.length === 0) {
// We can't return an empty Tensor because we don't know the element shape.
throw new Error('Can\'t make a batch of zero elements.');
}
if (arrays[0] instanceof Tensor) {
// Input is an array of Tensors
return stack(arrays);
} else {
// Input is a possibly-nested array of numbers.
return tensor(arrays);
}
}
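  /*
   * Illustrative sketch (internal helper): both forms below produce a rank-2
   * Tensor of shape [2, 2], with the batch along axis 0.
   *
   * ```js
   * batchConcat([tf.tensor([1, 2]), tf.tensor([3, 4])]); // stack()ed Tensors
   * batchConcat([[1, 2], [3, 4]]);                       // nested number array
   * ```
   */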
/**
* Represents a potentially large collection of text lines.
*
* The results are not batched.
*/
var TextLineDataset = /*#__PURE__*/function (_Dataset) {
_inheritsLoose(TextLineDataset, _Dataset);
/**
* Create a `TextLineDataset`.
*
* @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
*/
function TextLineDataset(input) {
var _this;
_this = _Dataset.call(this) || this;
_this.input = input;
return _this;
}
var _proto = TextLineDataset.prototype;
_proto.iterator = /*#__PURE__*/function () {
var _iterator = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var inputIterator, utf8Iterator, lineIterator;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.next = 2;
return this.input.iterator();
case 2:
inputIterator = _context.sent;
utf8Iterator = inputIterator.decodeUTF8();
lineIterator = utf8Iterator.split('\n').map(function (line) {
// Windows/DOS-format text files have an extra '\r' line break at the end of each line; strip it.
if (line.endsWith('\r')) {
line = line.slice(0, -1);
}
return line;
});
return _context.abrupt("return", lineIterator);
case 6:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function iterator() {
return _iterator.apply(this, arguments);
}
return iterator;
}();
return TextLineDataset;
}(Dataset);
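  /*
   * A hedged usage sketch (not from the original docs; assumes a plain-text
   * file is reachable at `url` and that `URLDataSource` is exported on
   * `tf.data` in this build): stream the file line by line.
   *
   * ```js
   * const source = new tf.data.URLDataSource(url);
   * const lines = new tf.data.TextLineDataset(source);
   * await lines.take(3).forEachAsync(line => console.log(line));
   * ```
   */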
var CODE_QUOTE = '"';
var STATE_OUT = Symbol('out');
var STATE_FIELD = Symbol('field');
var STATE_QUOTE = Symbol('quote');
var STATE_QUOTE_AFTER_QUOTE = Symbol('quoteafterquote');
var STATE_WITHIN_QUOTE_IN_QUOTE = Symbol('quoteinquote');
/**
* Represents a potentially large collection of delimited text records.
*
* The produced `TensorContainer`s each contain one key-value pair for
 * every column of the table. When a field is empty in the incoming data, the
 * resulting value is `undefined`, or an error is thrown if the column is
 * marked as required. Values that can be parsed as numbers are emitted as
 * type `number`; other values are emitted as `string`.
*
* The results are not batched.
*
* @doc {heading: 'Data', subheading: 'Classes', namespace: 'data'}
*/
var CSVDataset = /*#__PURE__*/function (_Dataset) {
_inheritsLoose(CSVDataset, _Dataset);
/**
* Create a `CSVDataset`.
*
* @param input A `DataSource` providing a chunked, UTF8-encoded byte stream.
* @param csvConfig (Optional) A CSVConfig object that contains configurations
* of reading and decoding from CSV file(s).
*
* hasHeader: (Optional) A boolean value that indicates whether the first
* row of provided CSV file is a header line with column names, and should
* not be included in the data. Defaults to `true`.
*
* columnNames: (Optional) A list of strings that corresponds to
* the CSV column names, in order. If provided, it ignores the column
* names inferred from the header row. If not provided, infers the column
* names from the first row of the records. If hasHeader is false and
* columnNames is not provided, this method throws an error.
*
 * columnConfigs: (Optional) A dictionary whose keys are column names and
 * whose values are objects specifying whether the column is required, its
 * data type, its default value, and whether it is a label. If provided, the
 * keys must correspond to names provided in columnNames or inferred from the
 * file header lines. If isLabel is true for any column, each element of the
 * dataset is an object `{xs, ys}`: xs is a dict of features key/value pairs
 * and ys is a dict of labels key/value pairs. If no column is marked as a
 * label, returns a dict of features only.
*
* configuredColumnsOnly (Optional) If true, only columns provided in
* columnConfigs will be parsed and provided during iteration.
*
* delimiter (Optional) The string used to parse each line of the input
* file. Defaults to `,`.
*/
function CSVDataset(input, csvConfig) {
var _this;
_this = _Dataset.call(this) || this;
_this.input = input;
_this.hasHeader = true;
_this.fullColumnNames = null;
_this.columnNamesValidated = false;
_this.columnConfigs = null;
_this.configuredColumnsOnly = false;
_this.delimiter = ',';
_this.delimWhitespace = false;
_this.base = new TextLineDataset(input);
if (!csvConfig) {
csvConfig = {};
}
_this.hasHeader = csvConfig.hasHeader === false ? false : true;
_this.fullColumnNames = csvConfig.columnNames;
_this.columnConfigs = csvConfig.columnConfigs;
_this.configuredColumnsOnly = csvConfig.configuredColumnsOnly;
if (csvConfig.delimWhitespace) {
assert(csvConfig.delimiter == null, function () {
return 'Delimiter should not be provided when delimWhitespace is true.';
});
_this.delimWhitespace = true;
_this.delimiter = ' ';
} else {
_this.delimiter = csvConfig.delimiter ? csvConfig.delimiter : ',';
}
return _this;
}
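  /*
   * A hedged configuration sketch (not from the original docs; assumes a CSV
   * file reachable at `csvUrl` with columns `a`, `b` and `label`).
   * `tf.data.csv` constructs a CSVDataset with this config:
   *
   * ```js
   * const ds = tf.data.csv(csvUrl, {
   *   hasHeader: true,
   *   columnConfigs: {
   *     a: {required: true, dtype: 'float32'},
   *     b: {default: 0},
   *     label: {isLabel: true}
   *   },
   *   configuredColumnsOnly: true
   * });
   * // Each element is {xs: {a, b}, ys: {label}} because a label is configured.
   * await ds.take(1).forEachAsync(e => console.log(e));
   * ```
   */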
/**
 * Returns the column names of the CSV dataset. If `configuredColumnsOnly` is
 * true, returns the column names in `columnConfigs`. If `configuredColumnsOnly`
 * is false and `columnNames` is provided, returns `columnNames`. If
 * `configuredColumnsOnly` is false and `columnNames` is not provided, returns
 * all column names parsed from the CSV file. For example usage please see
 * `tf.data.csv`.
*
* @doc {heading: 'Data', subheading: 'Classes'}
*/
var _proto = CSVDataset.prototype;
_proto.columnNames =
/*#__PURE__*/
function () {
var _columnNames = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (this.columnNamesValidated) {
_context.next = 3;
break;
}
_context.next = 3;
return this.setColumnNames();
case 3:
return _context.abrupt("return", this.configuredColumnsOnly ? Object.keys(this.columnConfigs) : this.fullColumnNames);
case 4:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function columnNames() {
return _columnNames.apply(this, arguments);
}
return columnNames;
}()
/* 1) If `columnNames` is provided as string[], use this string[] as output
* keys in corresponding order. The length must match the number of inferred
* columns if `hasHeader` is true .
* 2) If `columnNames` is not provided, parse header line as `columnNames` if
* hasHeader is true. If `hasHeader` is false, throw an error.
* 3) If `columnConfigs` is provided, all the keys in `columnConfigs` must
* exist in parsed `columnNames`.
*/
;
_proto.setColumnNames =
/*#__PURE__*/
function () {
var _setColumnNames = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var _this2 = this;
var columnNamesFromFile, counts, duplicateNames, _i, _Object$keys, key, index;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.maybeReadHeaderLine();
case 2:
columnNamesFromFile = _context2.sent;
if (!(!this.fullColumnNames && !columnNamesFromFile)) {
_context2.next = 7;
break;
}
throw new Error('Column names must be provided if there is no header line.');
case 7:
if (this.fullColumnNames && columnNamesFromFile) {
// Check provided columnNames match header line.
assert(columnNamesFromFile.length === this.fullColumnNames.length, function () {
return 'The length of provided columnNames (' + _this2.fullColumnNames.length.toString() + ') does not match the length of the header line read from ' + 'file (' + columnNamesFromFile.length.toString() + ').';
});
}
case 8:
if (!this.fullColumnNames) {
this.fullColumnNames = columnNamesFromFile;
} // Check if there are duplicate column names.
counts = this.fullColumnNames.reduce(function (countAcc, name) {
countAcc[name] = countAcc[name] + 1 || 1;
return countAcc;
}, {});
duplicateNames = Object.keys(counts).filter(function (name) {
return counts[name] > 1;
});
assert(duplicateNames.length === 0, function () {
return 'Duplicate column names found: ' + duplicateNames.toString();
}); // Check if keys in columnConfigs match columnNames.
if (!this.columnConfigs) {
_context2.next = 22;
break;
}
_i = 0, _Object$keys = Object.keys(this.columnConfigs);
case 14:
if (!(_i < _Object$keys.length)) {
_context2.next = 22;
break;
}
key = _Object$keys[_i];
index = this.fullColumnNames.indexOf(key);
if (!(index === -1)) {
_context2.next = 19;
break;
}
throw new Error('The key "' + key + '" provided in columnConfigs does not match any of the column ' + 'names (' + this.fullColumnNames.toString() + ').');
case 19:
_i++;
_context2.next = 14;
break;
case 22:
this.columnNamesValidated = true;
case 23:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function setColumnNames() {
return _setColumnNames.apply(this, arguments);
}
return setColumnNames;
}();
_proto.maybeReadHeaderLine = /*#__PURE__*/function () {
var _maybeReadHeaderLine = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3() {
var iter, firstElement, firstLine, headers;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
if (!this.hasHeader) {
_context3.next = 14;
break;
}
_context3.next = 3;
return this.base.iterator();
case 3:
iter = _context3.sent;
_context3.next = 6;
return iter.next();
case 6:
firstElement = _context3.sent;
if (!firstElement.done) {
_context3.next = 9;
break;
}
throw new Error('No data was found for CSV parsing.');
case 9:
firstLine = firstElement.value;
headers = this.parseRow(firstLine, false);
return _context3.abrupt("return", headers);
case 14:
return _context3.abrupt("return", null);
case 15:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function maybeReadHeaderLine() {
return _maybeReadHeaderLine.apply(this, arguments);
}
return maybeReadHeaderLine;
}();
_proto.iterator = /*#__PURE__*/function () {
var _iterator = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4() {
var _this3 = this;
var lines;
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
if (this.columnNamesValidated) {
_context4.next = 3;
break;
}
_context4.next = 3;
return this.setColumnNames();
case 3:
_context4.next = 5;
return this.base.iterator();
case 5:
lines = _context4.sent;
if (this.hasHeader) {
// We previously read the first line to get the columnNames.
// Now that we're providing data, skip it.
lines = lines.skip(1);
}
return _context4.abrupt("return", lines.map(function (x) {
return _this3.makeDataElement(x);
}));
case 8:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function iterator() {
return _iterator.apply(this, arguments);
}
return iterator;
}();
_proto.makeDataElement = function makeDataElement(line) {
var values = this.parseRow(line);
var features = {};
var labels = {};
for (var i = 0; i < this.fullColumnNames.length; i++) {
var key = this.fullColumnNames[i];
var config = this.columnConfigs ? this.columnConfigs[key] : null;
if (this.configuredColumnsOnly && !config) {
// This column is not selected.
continue;
} else {
var value = values[i];
var parsedValue = null;
if (value === '') {
// If default value is provided, use it. If default value is not
// provided, set as undefined.
if (config && config.default !== undefined) {
parsedValue = config.default;
} else if (config && (config.required || config.isLabel)) {
throw new Error("Required column " + key + " is empty in this line: " + line);
} else {
parsedValue = undefined;
}
} else {
// A value is present, so parse it based on type
var valueAsNum = Number(value);
if (isNaN(valueAsNum)) {
// The value is not numeric. If this column is declared as boolean in the
// config, parse it as boolean; otherwise keep it as a string.
if (config && config.dtype === 'bool') {
parsedValue = this.getBoolean(value);
} else {
// Set value as string
parsedValue = value;
}
} else if (!config || !config.dtype) {
// If this value is a number and no type config is provided, return
// it as number.
parsedValue = valueAsNum;
} else {
// If this value is a number and data type is provided, parse it
// according to provided data type.
switch (config.dtype) {
case 'float32':
parsedValue = valueAsNum;
break;
case 'int32':
parsedValue = Math.floor(valueAsNum);
break;
case 'bool':
parsedValue = this.getBoolean(value);
break;
default:
parsedValue = valueAsNum;
}
}
} // Check if this column is label.
config && config.isLabel ? labels[key] = parsedValue : features[key] = parsedValue;
}
} // If label exists, return an object of features and labels as {xs:features,
// ys:labels}, otherwise return features only.
if (Object.keys(labels).length === 0) {
return features;
} else {
return {
xs: features,
ys: labels
};
}
};
_proto.getBoolean = function getBoolean(value) {
if (value === '1' || value.toLowerCase() === 'true') {
return 1;
} else {
return 0;
}
} // adapted from https://beta.observablehq.com/@mbostock/streaming-csv
;
_proto.parseRow = function parseRow(line, validateElementCount) {
if (validateElementCount === void 0) {
validateElementCount = true;
}
var result = [];
var readOffset = 0;
var readLength = line.length;
var currentState = STATE_OUT; // Goes through the line to parse quote.
for (var i = 0; i < readLength; i++) {
switch (currentState) {
// Before enter a new field
case STATE_OUT:
switch (line.charAt(i)) {
// Enter a quoted field
case CODE_QUOTE:
readOffset = i + 1;
currentState = STATE_QUOTE;
break;
// Read an empty field
case this.delimiter:
readOffset = i + 1; // If delimiter is white space and configured to collapse
// multiple white spaces, ignore this white space.
if (this.delimiter === ' ' && this.delimWhitespace) {
break;
}
result.push('');
currentState = STATE_OUT;
break;
// Enter an unquoted field
default:
currentState = STATE_FIELD;
readOffset = i;
break;
}
break;
// In an unquoted field
case STATE_FIELD:
switch (line.charAt(i)) {
// Exit an unquoted field, add it to result
case this.delimiter:
result.push(line.substring(readOffset, i));
currentState = STATE_OUT;
readOffset = i + 1;
break;
default:
}
break;
// In a quoted field
case STATE_QUOTE:
switch (line.charAt(i)) {
// Read a quote after a quote
case CODE_QUOTE:
currentState = STATE_QUOTE_AFTER_QUOTE;
break;
default:
}
break;
// This state means it's right after a second quote in a field
case STATE_QUOTE_AFTER_QUOTE:
switch (line.charAt(i)) {
// Finished a quoted field
case this.delimiter:
result.push(line.substring(readOffset, i - 1));
currentState = STATE_OUT;
readOffset = i + 1;
break;
// Finished a quoted part in a quoted field
case CODE_QUOTE:
currentState = STATE_QUOTE;
break;
// In a quoted part in a quoted field
default:
currentState = STATE_WITHIN_QUOTE_IN_QUOTE;
break;
}
break;
case STATE_WITHIN_QUOTE_IN_QUOTE:
switch (line.charAt(i)) {
// Exit a quoted part in a quoted field
case CODE_QUOTE:
currentState = STATE_QUOTE;
break;
default:
}
break;
default:
}
} // Adds last item based on if it is quoted.
if (currentState === STATE_QUOTE_AFTER_QUOTE) {
result.push(line.substring(readOffset, readLength - 1));
} else {
result.push(line.substring(readOffset));
} // Check if each row has the same number of elements as column names.
if (validateElementCount && result.length !== this.fullColumnNames.length) {
throw new Error("Invalid row in csv file. Should have " + this.fullColumnNames.length + " elements in a row, but got " + result);
}
return result;
};
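  /*
   * Illustrative sketch of the parser's behavior (internal helper;
   * `csvDataset` is a hypothetical CSVDataset instance, and the second
   * argument disables the column-count check): a delimiter inside a quoted
   * field does not split the field.
   *
   * ```js
   * csvDataset.parseRow('a,"b, c",d', false); // ['a', 'b, c', 'd']
   * ```
   */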
return CSVDataset;
}(Dataset); // TODO(soergel): add more basic datasets for parity with tf.data
// tf.data.FixedLengthRecordDataset()
// tf.data.TFRecordDataset()
/**
 * Provide a stream of tensors from a microphone audio stream. The tensors
 * represent audio data as a frequency-domain spectrogram generated with the
 * browser's native FFT. Tensors representing the time-domain waveform are
 * available based on configuration. Only works in the browser environment.
*/
var MicrophoneIterator = /*#__PURE__*/function (_LazyIterator) {
_inheritsLoose(MicrophoneIterator, _LazyIterator);
function MicrophoneIterator(microphoneConfig) {
var _this;
_this = _LazyIterator.call(this) || this;
_this.microphoneConfig = microphoneConfig;
_this.isClosed = false;
_this.fftSize = microphoneConfig.fftSize || 1024;
var fftSizeLog2 = Math.log2(_this.fftSize);
if (_this.fftSize < 0 || fftSizeLog2 < 4 || fftSizeLog2 > 14 || !Number.isInteger(fftSizeLog2)) {
throw new Error("Invalid fftSize: it must be a power of 2 between " + ("2 to 4 and 2 to 14, but got " + _this.fftSize));
}
_this.numFrames = microphoneConfig.numFramesPerSpectrogram || 43;
_this.sampleRateHz = microphoneConfig.sampleRateHz;
_this.columnTruncateLength = microphoneConfig.columnTruncateLength || _this.fftSize;
_this.audioTrackConstraints = microphoneConfig.audioTrackConstraints;
_this.smoothingTimeConstant = microphoneConfig.smoothingTimeConstant || 0;
_this.includeSpectrogram = microphoneConfig.includeSpectrogram === false ? false : true;
_this.includeWaveform = microphoneConfig.includeWaveform === true ? true : false;
if (!_this.includeSpectrogram && !_this.includeWaveform) {
throw new Error('Both includeSpectrogram and includeWaveform are false. ' + 'At least one type of data should be returned.');
}
return _this;
}
var _proto = MicrophoneIterator.prototype;
_proto.summary = function summary() {
return "microphone";
} // Construct a MicrophoneIterator and start the audio stream.
;
MicrophoneIterator.create =
/*#__PURE__*/
function () {
var _create = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(microphoneConfig) {
var microphoneIterator;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (microphoneConfig === void 0) {
microphoneConfig = {};
}
if (!env().get('IS_NODE')) {
_context.next = 3;
break;
}
throw new Error('microphone API is only supported in browser environment.');
case 3:
microphoneIterator = new MicrophoneIterator(microphoneConfig); // Call async function start() to initialize the audio stream.
_context.next = 6;
return microphoneIterator.start();
case 6:
return _context.abrupt("return", microphoneIterator);
case 7:
case "end":
return _context.stop();
}
}
}, _callee);
}));
function create(_x) {
return _create.apply(this, arguments);
}
return create;
}() // Start the audio stream and FFT.
;
_proto.start =
/*#__PURE__*/
function () {
var _start = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var ctxConstructor, streamSource;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.prev = 0;
_context2.next = 3;
return navigator.mediaDevices.getUserMedia({
audio: this.audioTrackConstraints == null ? true : this.audioTrackConstraints,
video: false
});
case 3:
this.stream = _context2.sent;
_context2.next = 9;
break;
case 6:
_context2.prev = 6;
_context2.t0 = _context2["catch"](0);
throw new Error("Error thrown while initializing video stream: " + _context2.t0.message);
case 9:
if (this.stream) {
_context2.next = 11;
break;
}
throw new Error('Could not obtain audio from microphone.');
case 11:
ctxConstructor = // tslint:disable-next-line:no-any
window.AudioContext || window.webkitAudioContext;
this.audioContext = new ctxConstructor();
if (this.sampleRateHz) {
_context2.next = 17;
break;
}
// If sample rate is not provided, use the available sample rate on
// device.
this.sampleRateHz = this.audioContext.sampleRate;
_context2.next = 19;
break;
case 17:
if (!(this.audioContext.sampleRate !== this.sampleRateHz)) {
_context2.next = 19;
break;
}
throw new Error("Mismatch in sampling rate: " + ("Expected: " + this.sampleRateHz + "; ") + ("Actual: " + this.audioContext.sampleRate));
case 19:
streamSource = this.audioContext.createMediaStreamSource(this.stream);
this.analyser = this.audioContext.createAnalyser();
this.analyser.fftSize = this.fftSize * 2;
this.analyser.smoothingTimeConstant = this.smoothingTimeConstant;
streamSource.connect(this.analyser);
this.freqData = new Float32Array(this.fftSize);
this.timeData = new Float32Array(this.fftSize);
return _context2.abrupt("return");
case 27:
case "end":
return _context2.stop();
}
}
}, _callee2, this, [[0, 6]]);
}));
function start() {
return _start.apply(this, arguments);
}
return start;
}();
_proto.next = /*#__PURE__*/function () {
var _next = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3() {
var spectrogramTensor, waveformTensor, audioDataQueue, freqData, timeData;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
if (!this.isClosed) {
_context3.next = 2;
break;
}
return _context3.abrupt("return", {
value: null,
done: true
});
case 2:
_context3.next = 4;
return this.getAudioData();
case 4:
audioDataQueue = _context3.sent;
if (this.includeSpectrogram) {
freqData = this.flattenQueue(audioDataQueue.freqDataQueue);
spectrogramTensor = this.getTensorFromAudioDataArray(freqData, [this.numFrames, this.columnTruncateLength, 1]);
}
if (this.includeWaveform) {
timeData = this.flattenQueue(audioDataQueue.timeDataQueue);
waveformTensor = this.getTensorFromAudioDataArray(timeData, [this.numFrames * this.fftSize, 1]);
}
return _context3.abrupt("return", {
value: {
'spectrogram': spectrogramTensor,
'waveform': waveformTensor
},
done: false
});
case 8:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function next() {
return _next.apply(this, arguments);
}
return next;
}() // Capture one result from the audio stream, and extract the value from
// iterator.next() result.
;
_proto.capture =
/*#__PURE__*/
function () {
var _capture = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4() {
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
_context4.next = 2;
return this.next();
case 2:
return _context4.abrupt("return", _context4.sent.value);
case 3:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function capture() {
return _capture.apply(this, arguments);
}
return capture;
}();
_proto.getAudioData = /*#__PURE__*/function () {
var _getAudioData = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee5() {
var _this2 = this;
var freqDataQueue, timeDataQueue, currentFrames;
return regeneratorRuntime.wrap(function _callee5$(_context5) {
while (1) {
switch (_context5.prev = _context5.next) {
case 0:
freqDataQueue = [];
timeDataQueue = [];
currentFrames = 0;
return _context5.abrupt("return", new Promise(function (resolve) {
var intervalID = setInterval(function () {
if (_this2.includeSpectrogram) {
_this2.analyser.getFloatFrequencyData(_this2.freqData); // If the audio stream is initializing, return empty queue.
if (_this2.freqData[0] === -Infinity) {
resolve({
freqDataQueue: freqDataQueue,
timeDataQueue: timeDataQueue
});
}
freqDataQueue.push(_this2.freqData.slice(0, _this2.columnTruncateLength));
}
if (_this2.includeWaveform) {
_this2.analyser.getFloatTimeDomainData(_this2.timeData);
timeDataQueue.push(_this2.timeData.slice());
} // Clear the interval and resolve when all frames have been collected
if (++currentFrames === _this2.numFrames) {
clearInterval(intervalID);
resolve({
freqDataQueue: freqDataQueue,
timeDataQueue: timeDataQueue
});
}
}, _this2.fftSize / _this2.sampleRateHz * 1e3);
}));
case 4:
case "end":
return _context5.stop();
}
}
}, _callee5);
}));
function getAudioData() {
return _getAudioData.apply(this, arguments);
}
return getAudioData;
}() // Stop the audio stream and pause the iterator.
;
_proto.stop = function stop() {
if (!this.isClosed) {
this.isClosed = true;
this.analyser.disconnect();
this.audioContext.close();
if (this.stream != null && this.stream.getTracks().length > 0) {
this.stream.getTracks()[0].stop();
}
}
} // Override toArray() function to prevent collecting.
;
_proto.toArray = function toArray() {
throw new Error('Can not convert infinite audio stream to array.');
} // Return audio sampling rate in Hz
;
_proto.getSampleRate = function getSampleRate() {
return this.sampleRateHz;
};
_proto.flattenQueue = function flattenQueue(queue) {
var frameSize = queue[0].length;
var freqData = new Float32Array(queue.length * frameSize);
queue.forEach(function (data, i) {
return freqData.set(data, i * frameSize);
});
return freqData;
};
_proto.getTensorFromAudioDataArray = function getTensorFromAudioDataArray(freqData, shape) {
var vals = new Float32Array(sizeFromShape(shape)); // If the data is shorter than the output shape, the front is padded with zeros.
vals.set(freqData, vals.length - freqData.length);
return tensor(vals, shape);
};
return MicrophoneIterator;
}(LazyIterator);
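  /*
   * A hedged usage sketch (browser only; assumes the user grants microphone
   * permission). `tf.data.microphone` creates and starts a MicrophoneIterator:
   *
   * ```js
   * const mic = await tf.data.microphone(
   *     {fftSize: 1024, numFramesPerSpectrogram: 43});
   * const {spectrogram} = await mic.capture();
   * console.log(spectrogram.shape); // [43, 1024, 1]
   * mic.stop();
   * ```
   */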
/**
* Provide a stream of image tensors from webcam video stream. Only works in
* browser environment.
*/
var WebcamIterator = /*#__PURE__*/function (_LazyIterator) {
_inheritsLoose(WebcamIterator, _LazyIterator);
function WebcamIterator(webcamVideoElement, webcamConfig) {
var _this;
_this = _LazyIterator.call(this) || this;
_this.webcamVideoElement = webcamVideoElement;
_this.webcamConfig = webcamConfig;
_this.isClosed = true;
_this.resize = false;
if (_this.needToResize()) {
_this.resize = true;
_this.cropSize = [_this.webcamConfig.resizeHeight, _this.webcamConfig.resizeWidth];
_this.cropBoxInd = tensor1d([0], 'int32');
if (_this.webcamConfig.centerCrop) {
// Calculate the box based on resizing shape.
var widthCroppingRatio = _this.webcamConfig.resizeWidth * 1.0 / _this.webcamVideoElement.width;
var heightCroppingRatio = _this.webcamConfig.resizeHeight * 1.0 / _this.webcamVideoElement.height;
var widthCropStart = (1 - widthCroppingRatio) / 2;
var heightCropStart = (1 - heightCroppingRatio) / 2;
var widthCropEnd = widthCropStart + widthCroppingRatio;
var heightCropEnd = heightCroppingRatio + heightCropStart;
_this.cropBox = tensor2d([heightCropStart, widthCropStart, heightCropEnd, widthCropEnd], [1, 4]);
} else {
_this.cropBox = tensor2d([0, 0, 1, 1], [1, 4]);
}
}
return _this;
}
var _proto = WebcamIterator.prototype;
_proto.summary = function summary() {
return "webcam";
} // Construct a WebcamIterator and start its video stream.
;
WebcamIterator.create =
/*#__PURE__*/
function () {
var _create = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(webcamVideoElement, webcamConfig) {
var webcamIterator;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (webcamConfig === void 0) {
webcamConfig = {};
}
if (!env().get('IS_NODE')) {
_context.next = 3;
break;
}
throw new Error('tf.data.webcam is only supported in browser environment.');
case 3:
if (webcamVideoElement) {
_context.next = 9;
break;
}
// If webcam video element is not provided, create a hidden video element
// with provided width and height.
webcamVideoElement = document.createElement('video');
if (!(!webcamConfig.resizeWidth || !webcamConfig.resizeHeight)) {
_context.next = 7;
break;
}
throw new Error('Please provide webcam video element, or resizeWidth and ' + 'resizeHeight to create a hidden video element.');
case 7:
webcamVideoElement.width = webcamConfig.resizeWidth;
webcamVideoElement.height = webcamConfig.resizeHeight;
case 9:
webcamIterator = new WebcamIterator(webcamVideoElement, webcamConfig); // Call async function to initialize the video stream.
_context.next = 12;
return webcamIterator.start();
case 12:
return _context.abrupt("return", webcamIterator);
case 13:
case "end":
return _context.stop();
}
}
}, _callee);
}));
function create(_x, _x2) {
return _create.apply(this, arguments);
}
return create;
}() // Async function to start video stream.
;
_proto.start =
/*#__PURE__*/
function () {
var _start = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var _this2 = this;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
if (this.webcamConfig.facingMode) {
assert(this.webcamConfig.facingMode === 'user' || this.webcamConfig.facingMode === 'environment', function () {
return "Invalid webcam facing mode: " + _this2.webcamConfig.facingMode + ". " + "Please provide 'user' or 'environment'";
});
}
_context2.prev = 1;
_context2.next = 4;
return navigator.mediaDevices.getUserMedia({
video: {
deviceId: this.webcamConfig.deviceId,
facingMode: this.webcamConfig.facingMode ? this.webcamConfig.facingMode : 'user',
width: this.webcamVideoElement.width,
height: this.webcamVideoElement.height
}
});
case 4:
this.stream = _context2.sent;
_context2.next = 11;
break;
case 7:
_context2.prev = 7;
_context2.t0 = _context2["catch"](1);
// Modify the error message but leave the stack trace intact
_context2.t0.message = "Error thrown while initializing video stream: " + _context2.t0.message;
throw _context2.t0;
case 11:
if (this.stream) {
_context2.next = 13;
break;
}
throw new Error('Could not obtain video from webcam.');
case 13:
// Older browsers may not have srcObject
try {
this.webcamVideoElement.srcObject = this.stream;
} catch (error) {
console.log(error);
this.webcamVideoElement.src = window.URL.createObjectURL(this.stream);
} // Start the webcam video stream
this.webcamVideoElement.play();
this.isClosed = false;
return _context2.abrupt("return", new Promise(function (resolve) {
// Add event listener to make sure the webcam has been fully initialized.
_this2.webcamVideoElement.onloadedmetadata = function () {
resolve();
};
}));
case 17:
case "end":
return _context2.stop();
}
}
}, _callee2, this, [[1, 7]]);
}));
function start() {
return _start.apply(this, arguments);
}
return start;
}();
_proto.next = /*#__PURE__*/function () {
var _next = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3() {
var img;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
if (!this.isClosed) {
_context3.next = 2;
break;
}
return _context3.abrupt("return", {
value: null,
done: true
});
case 2:
_context3.prev = 2;
img = fromPixels(this.webcamVideoElement);
_context3.next = 9;
break;
case 6:
_context3.prev = 6;
_context3.t0 = _context3["catch"](2);
throw new Error("Error thrown converting video to pixels: " + JSON.stringify(_context3.t0));
case 9:
if (!this.resize) {
_context3.next = 22;
break;
}
_context3.prev = 10;
return _context3.abrupt("return", {
value: this.cropAndResizeFrame(img),
done: false
});
case 14:
_context3.prev = 14;
_context3.t1 = _context3["catch"](10);
throw new Error("Error thrown cropping the video: " + _context3.t1.message);
case 17:
_context3.prev = 17;
img.dispose();
return _context3.finish(17);
case 20:
_context3.next = 23;
break;
case 22:
return _context3.abrupt("return", {
value: img,
done: false
});
case 23:
case "end":
return _context3.stop();
}
}
}, _callee3, this, [[2, 6], [10, 14, 17, 20]]);
}));
function next() {
return _next.apply(this, arguments);
}
return next;
}();
_proto.needToResize = function needToResize() {
// If resizeWidth and resizeHeight are provided, and different from the
// width and height of original HTMLVideoElement, then resizing and cropping
// is required.
if (this.webcamConfig.resizeWidth && this.webcamConfig.resizeHeight && (this.webcamVideoElement.width !== this.webcamConfig.resizeWidth || this.webcamVideoElement.height !== this.webcamConfig.resizeHeight)) {
return true;
}
return false;
} // Cropping and resizing each frame based on config
;
_proto.cropAndResizeFrame = function cropAndResizeFrame(img) {
var _this3 = this;
return tidy(function () {
var expandedImage = expandDims(cast(img, 'float32'), 0);
var resizedImage;
resizedImage = image.cropAndResize(expandedImage, _this3.cropBox, _this3.cropBoxInd, _this3.cropSize, 'bilinear'); // Extract image from batch cropping.
var shape = resizedImage.shape;
return reshape(resizedImage, shape.slice(1));
});
} // Capture one frame from the video stream, and extract the value from
// iterator.next() result.
;
_proto.capture =
/*#__PURE__*/
function () {
var _capture = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4() {
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
_context4.next = 2;
return this.next();
case 2:
return _context4.abrupt("return", _context4.sent.value);
case 3:
case "end":
return _context4.stop();
}
}
}, _callee4, this);
}));
function capture() {
return _capture.apply(this, arguments);
}
return capture;
}() // Stop the video stream and pause webcam iterator.
;
_proto.stop = function stop() {
var tracks = this.stream.getTracks();
tracks.forEach(function (track) {
return track.stop();
});
try {
this.webcamVideoElement.srcObject = null;
} catch (error) {
console.log(error);
this.webcamVideoElement.src = null;
}
this.isClosed = true;
} // Override toArray() function to prevent collecting.
;
_proto.toArray = function toArray() {
throw new Error('Can not convert infinite video stream to array.');
};
return WebcamIterator;
}(LazyIterator);
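  /*
   * A hedged usage sketch (browser only; assumes `videoEl` is an existing
   * HTMLVideoElement and the user grants camera permission). `tf.data.webcam`
   * creates and starts a WebcamIterator:
   *
   * ```js
   * const cam = await tf.data.webcam(videoEl);
   * const frame = await cam.capture(); // int32 Tensor of shape [height, width, 3]
   * frame.dispose();
   * cam.stop();
   * ```
   */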
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
/**
* Represents a data source readable as a stream of binary data chunks.
*
* Because `Dataset`s can be read repeatedly (via `Dataset.iterator()`), this
* provides a means to repeatedly create streams from the underlying data
* sources.
*/
var DataSource = function DataSource() {}; // TODO(soergel): consider convenience factory functions here
// in combination with chainable source->dataset above, e.g.:
// tf.data.url(...).asCsvDataset().shuffle().batch()
var StringIterator = /*#__PURE__*/function (_LazyIterator) {
_inheritsLoose(StringIterator, _LazyIterator);
function StringIterator() {
return _LazyIterator.apply(this, arguments) || this;
}
var _proto = StringIterator.prototype;
/**
* Splits a string stream on a given separator.
*
* It is assumed that the incoming chunk boundaries have no semantic meaning,
* so conceptually the incoming stream is treated simply as the concatenation
* of its elements.
*
* The outgoing stream provides chunks corresponding to the results of the
* standard string split() operation (even if such a chunk spanned incoming
* chunks). The separators are not included.
*
* A typical usage is to split a text file (represented as a stream with
* arbitrary chunk boundaries) into lines.
*
 * @param separator A character to split on.
*/
_proto.split = function split(separator) {
return new SplitIterator(this, separator);
};
return StringIterator;
}(LazyIterator); // ============================================================================
// The following private classes serve to implement the chainable methods
// on StringIterator. Unfortunately they can't be placed in separate files, due
// to resulting trouble with circular imports.
// ============================================================================
// We wanted multiple inheritance, e.g.
// class SplitIterator extends QueueIterator<string>, StringIterator
// but the TypeScript mixin approach is a bit hacky, so we take this adapter
// approach instead.
var SplitIterator = /*#__PURE__*/function (_StringIterator) {
_inheritsLoose(SplitIterator, _StringIterator);
function SplitIterator(upstream, separator) {
var _this;
_this = _StringIterator.call(this) || this;
_this.upstream = upstream;
_this.impl = new SplitIteratorImpl(upstream, separator);
return _this;
}
var _proto2 = SplitIterator.prototype;
_proto2.summary = function summary() {
return this.impl.summary();
};
_proto2.next = /*#__PURE__*/function () {
var _next = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
return _context.abrupt("return", this.impl.next());
case 1:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function next() {
return _next.apply(this, arguments);
}
return next;
}();
return SplitIterator;
}(StringIterator);
var SplitIteratorImpl = /*#__PURE__*/function (_OneToManyIterator) {
_inheritsLoose(SplitIteratorImpl, _OneToManyIterator);
function SplitIteratorImpl(upstream, separator) {
var _this2;
_this2 = _OneToManyIterator.call(this) || this;
_this2.upstream = upstream;
_this2.separator = separator; // A partial string at the end of an upstream chunk
_this2.carryover = '';
return _this2;
}
var _proto3 = SplitIteratorImpl.prototype;
_proto3.summary = function summary() {
return this.upstream.summary() + " -> Split('" + this.separator + "')";
};
_proto3.pump = /*#__PURE__*/function () {
var _pump = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var chunkResult, lines, _iterator, _step, line;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.upstream.next();
case 2:
chunkResult = _context2.sent;
if (!chunkResult.done) {
_context2.next = 9;
break;
}
if (!(this.carryover === '')) {
_context2.next = 6;
break;
}
return _context2.abrupt("return", false);
case 6:
// Pretend that the pump succeeded in order to emit the small last batch.
// The next pump() call will actually fail.
this.outputQueue.push(this.carryover);
this.carryover = '';
return _context2.abrupt("return", true);
case 9:
lines = chunkResult.value.split(this.separator); // Note the behavior: " ab ".split(' ') === ['', 'ab', '']
// Thus the carryover may be '' if the separator falls on a chunk
// boundary; this produces the correct result.
lines[0] = this.carryover + lines[0];
for (_iterator = _createForOfIteratorHelperLoose(lines.slice(0, -1)); !(_step = _iterator()).done;) {
line = _step.value;
this.outputQueue.push(line);
}
this.carryover = lines[lines.length - 1];
return _context2.abrupt("return", true);
case 14:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function pump() {
return _pump.apply(this, arguments);
}
return pump;
}();
return SplitIteratorImpl;
}(OneToManyIterator);
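  /*
   * Illustrative sketch of the carryover behavior: with separator '\n',
   * upstream chunks 'ab\nc' and 'd\ne' are emitted as 'ab', 'cd', 'e'. The
   * partial line 'c' at the end of the first chunk is held back and prepended
   * to the next chunk before splitting.
   */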
var ByteChunkIterator = /*#__PURE__*/function (_LazyIterator) {
_inheritsLoose(ByteChunkIterator, _LazyIterator);
function ByteChunkIterator() {
return _LazyIterator.apply(this, arguments) || this;
}
var _proto = ByteChunkIterator.prototype;
/**
* Decode a stream of UTF8-encoded byte arrays to a stream of strings.
*
 * The byte arrays produced from the ByteChunkIterator on which this is
* called will be interpreted as concatenated. No assumptions are made about
* the boundaries of the incoming chunks, so a multi-byte UTF8 encoding of a
* character may span the boundary between chunks. This naturally happens,
* for instance, when reading fixed-size byte arrays from a file.
*/
_proto.decodeUTF8 = function decodeUTF8() {
return new Utf8Iterator(this);
};
return ByteChunkIterator;
}(LazyIterator); // ============================================================================
// The following private classes serve to implement the chainable methods
// on ByteChunkIterator. Unfortunately they can't be placed in separate files,
// due to resulting trouble with circular imports.
// ============================================================================
// We wanted multiple inheritance, e.g.
// class Utf8Iterator extends QueueIterator<string>, StringIterator
// but the TypeScript mixin approach is a bit hacky, so we take this adapter
// approach instead.
var Utf8Iterator = /*#__PURE__*/function (_StringIterator) {
_inheritsLoose(Utf8Iterator, _StringIterator);
function Utf8Iterator(upstream) {
var _this;
_this = _StringIterator.call(this) || this;
_this.upstream = upstream;
_this.impl = new Utf8IteratorImpl(upstream);
return _this;
}
var _proto2 = Utf8Iterator.prototype;
_proto2.summary = function summary() {
return this.impl.summary();
};
_proto2.next = /*#__PURE__*/function () {
var _next = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
return _context.abrupt("return", this.impl.next());
case 1:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function next() {
return _next.apply(this, arguments);
}
return next;
}();
return Utf8Iterator;
}(StringIterator);
/**
* Decode a stream of UTF8-encoded byte arrays to a stream of strings.
*
* This is tricky because the incoming byte array boundaries may disrupt a
* multi-byte UTF8 character. Thus any incomplete character data at the end of
* a chunk must be carried over and prepended to the next chunk before
 * decoding. Luckily, the native decoders (TextDecoder in the browser and
 * string_decoder in Node) handle byte-array boundaries automatically.
*
* In the context of an input pipeline for machine learning, UTF8 decoding is
* needed to parse text files containing training examples or prediction
* requests (e.g., formatted as CSV or JSON). We cannot use the built-in
* decoding provided by FileReader.readAsText() because here we are in a
* streaming context, which FileReader does not support.
*
* @param upstream A `LazyIterator` of `Uint8Arrays` containing UTF8-encoded
* text, which should be interpreted as concatenated. No assumptions are
* made about the boundaries of the incoming chunks, so a multi-byte UTF8
* encoding of a character may span the boundary between chunks. This
* naturally happens, for instance, when reading fixed-size byte arrays from a
* file.
*/
var Utf8IteratorImpl = /*#__PURE__*/function (_OneToManyIterator) {
_inheritsLoose(Utf8IteratorImpl, _OneToManyIterator);
function Utf8IteratorImpl(upstream) {
var _this2;
_this2 = _OneToManyIterator.call(this) || this;
_this2.upstream = upstream;
if (env().get('IS_BROWSER')) {
_this2.decoder = new TextDecoder('utf-8');
} else {
// tslint:disable-next-line:no-require-imports
var _require = require('string_decoder'),
StringDecoder = _require.StringDecoder;
_this2.decoder = new StringDecoder('utf8');
}
return _this2;
}
var _proto3 = Utf8IteratorImpl.prototype;
_proto3.summary = function summary() {
return this.upstream.summary() + " -> Utf8";
};
_proto3.pump = /*#__PURE__*/function () {
var _pump = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var chunkResult, chunk, text;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return this.upstream.next();
case 2:
chunkResult = _context2.sent;
if (!chunkResult.done) {
_context2.next = 7;
break;
}
return _context2.abrupt("return", false);
case 7:
chunk = chunkResult.value;
case 8:
if (env().get('IS_BROWSER')) {
text = this.decoder.decode(chunk, {
stream: true
});
} else {
text = this.decoder.write(Buffer.from(chunk.buffer));
}
this.outputQueue.push(text);
return _context2.abrupt("return", true);
case 11:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function pump() {
return _pump.apply(this, arguments);
}
return pump;
}();
return Utf8IteratorImpl;
}(OneToManyIterator);
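  /*
   * Illustrative sketch: the 3-byte UTF-8 encoding of '€' (0xE2 0x82 0xAC) may
   * straddle a chunk boundary. TextDecoder called with {stream: true} (or
   * StringDecoder.write in Node) buffers the partial bytes, so the decoded
   * output is still a single '€'.
   */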
/**
* Provide a stream of chunks from a File, Blob, or Uint8Array.
* @param file The source File, Blob or Uint8Array.
* @param options Optional settings controlling file reading.
* @returns a lazy Iterator of Uint8Arrays containing sequential chunks of the
* input File, Blob or Uint8Array.
*/
var FileChunkIterator = /*#__PURE__*/function (_ByteChunkIterator) {
_inheritsLoose(FileChunkIterator, _ByteChunkIterator);
function FileChunkIterator(file, options) {
var _this;
if (options === void 0) {
options = {};
}
_this = _ByteChunkIterator.call(this) || this;
_this.file = file;
_this.options = options;
assert(file instanceof Uint8Array || (env().get('IS_BROWSER') ? file instanceof File || file instanceof Blob : false), function () {
return 'FileChunkIterator only supports File, Blob and Uint8Array ' + 'right now.';
});
_this.offset = options.offset || 0; // default 1MB chunk has tolerable perf on large files
_this.chunkSize = options.chunkSize || 1024 * 1024;
return _this;
}
var _proto = FileChunkIterator.prototype;
_proto.summary = function summary() {
return "FileChunks " + this.file;
};
_proto.next = /*#__PURE__*/function () {
var _next = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var _this2 = this;
var chunk;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!(this.offset >= (this.file instanceof Uint8Array ? this.file.byteLength : this.file.size))) {
_context.next = 2;
break;
}
return _context.abrupt("return", {
value: null,
done: true
});
case 2:
chunk = new Promise(function (resolve, reject) {
var end = _this2.offset + _this2.chunkSize;
if (_this2.file instanceof Uint8Array) {
// Note if end > this.file.byteLength, we just get a small last
// chunk.
resolve(new Uint8Array(_this2.file.slice(_this2.offset, end)));
} else {
// This branch assumes that this.file type is File or Blob, which
// means it is in the browser environment.
// TODO(soergel): is this a performance issue?
var fileReader = new FileReader();
fileReader.onload = function (event) {
var data = fileReader.result; // Not sure we can trust the return type of
// FileReader.readAsArrayBuffer. See e.g.
// https://github.com/node-file-api/FileReader/issues/2
if (data instanceof ArrayBuffer) {
data = new Uint8Array(data);
}
if (!(data instanceof Uint8Array)) {
return reject(new TypeError('FileReader returned unknown type.'));
}
resolve(data);
};
fileReader.onabort = function (event) {
return reject(new Error('Aborted'));
};
fileReader.onerror = function (event) {
return reject(new Error(event.type));
}; // TODO(soergel): better handle onabort, onerror
// Note if end > this.file.size, we just get a small last chunk.
var slice = _this2.file.slice(_this2.offset, end); // We can't use readAsText here (even if we know the file is text)
// because the slice boundary may fall within a multi-byte character.
fileReader.readAsArrayBuffer(slice);
}
_this2.offset = end;
});
_context.next = 5;
return chunk;
case 5:
_context.t0 = _context.sent;
return _context.abrupt("return", {
value: _context.t0,
done: false
});
case 7:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function next() {
return _next.apply(this, arguments);
}
return next;
}();
return FileChunkIterator;
}(ByteChunkIterator);
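  /*
   * A minimal sketch (not from the original docs): chunking an in-memory
   * Uint8Array with a small chunk size.
   *
   * ```js
   * const it = new FileChunkIterator(new Uint8Array([1, 2, 3, 4, 5]),
   *                                  {chunkSize: 4});
   * await it.next(); // {value: Uint8Array [1, 2, 3, 4], done: false}
   * await it.next(); // {value: Uint8Array [5], done: false}
   * await it.next(); // {value: null, done: true}
   * ```
   */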
/**
* Provide a stream of chunks from a URL.
*
 * Note this function first downloads the entire file into memory before providing
* the first element from the stream. This is because the Fetch API does not
* yet reliably provide a reader stream for the response body.
*/
function urlChunkIterator(_x, _x2) {
return _urlChunkIterator.apply(this, arguments);
} // Generate RequestInit from Request to match tf.util.fetch signature.
function _urlChunkIterator() {
_urlChunkIterator = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(url, options) {
var urlString, requestInit, response, uint8Array;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (options === void 0) {
options = {};
}
if (typeof url === 'string') {
urlString = url;
} else {
urlString = url.url;
requestInit = getRequestInitFromRequest(url);
}
_context.next = 4;
return fetch$2(urlString, requestInit);
case 4:
response = _context.sent;
if (!response.ok) {
_context.next = 14;
break;
}
_context.t0 = Uint8Array;
_context.next = 9;
return response.arrayBuffer();
case 9:
_context.t1 = _context.sent;
uint8Array = new _context.t0(_context.t1);
return _context.abrupt("return", new FileChunkIterator(uint8Array, options));
case 14:
throw new Error(response.statusText);
case 15:
case "end":
return _context.stop();
}
}
}, _callee);
}));
return _urlChunkIterator.apply(this, arguments);
}
var getRequestInitFromRequest = function getRequestInitFromRequest(request) {
var init = {
method: request.method,
headers: request.headers,
body: request.body,
mode: request.mode,
credentials: request.credentials,
cache: request.cache,
redirect: request.redirect,
referrer: request.referrer,
integrity: request.integrity
};
return init;
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* =============================================================================
*/
// Skip the tslint any type check because this method's purpose is to check
// the type of its input.
// tslint:disable-next-line:no-any
function isLocalPath(source) {
return typeof source === 'string' && source.substr(0, 7) === 'file://';
}
/**
* Represents a file, blob, or Uint8Array readable as a stream of binary data
* chunks.
*/
var FileDataSource = /*#__PURE__*/function (_DataSource) {
_inheritsLoose(FileDataSource, _DataSource);
/**
* Create a `FileDataSource`.
*
* @param input Local file path, or `File`/`Blob`/`Uint8Array` object to
 * read. Local file paths only work in the Node environment.
 * @param options Options passed to the underlying `FileChunkIterator`s,
 * such as {chunkSize: 1024}.
*/
function FileDataSource(input, options) {
var _this;
if (options === void 0) {
options = {};
}
_this = _DataSource.call(this) || this;
_this.input = input;
_this.options = options;
return _this;
}
var _proto = FileDataSource.prototype;
_proto.iterator = /*#__PURE__*/function () {
var _iterator = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
var fs;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (isLocalPath(this.input) && env().get('IS_NODE')) {
// tslint:disable-next-line:no-require-imports
fs = require('fs');
this.input = fs.readFileSync(this.input.substr(7));
} // TODO(kangyizhang): Add LocalFileChunkIterator to split local streaming
// with file in browser.
return _context.abrupt("return", new FileChunkIterator(this.input, this.options));
case 2:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function iterator() {
return _iterator.apply(this, arguments);
}
return iterator;
}();
return FileDataSource;
}(DataSource);
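/**
 * A minimal usage sketch for `FileDataSource`, reading a `Uint8Array` through
 * the underlying `FileChunkIterator` (`chunkSize` is the iterator option
 * assumed here):
 * ```js
 * const source = new tf.data.FileDataSource(
 *     new Uint8Array([1, 2, 3, 4]), {chunkSize: 2});
 * const it = await source.iterator();
 * const chunk = await it.next(); // {value: Uint8Array [1, 2], done: false}
 * ```
 */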
/*
* Represents a URL readable as a stream of binary data chunks.
*/
var URLDataSource = /*#__PURE__*/function (_DataSource) {
_inheritsLoose(URLDataSource, _DataSource);
/**
* Create a `URLDataSource`.
*
* @param url A source URL string, or a `Request` object.
* @param options Options passed to the underlying `FileChunkIterator`s,
   *     such as {chunkSize: 1024}.
*/
function URLDataSource(url, fileOptions) {
var _this;
if (fileOptions === void 0) {
fileOptions = {};
}
_this = _DataSource.call(this) || this;
_this.url = url;
_this.fileOptions = fileOptions;
return _this;
} // TODO(soergel): provide appropriate caching options. Currently this
// will download the URL anew for each call to iterator(). Since we have
// to treat the downloaded file as a blob/buffer anyway, we may as well retain
// it-- but that raises GC issues. Also we may want a persistent disk cache.
var _proto = URLDataSource.prototype;
_proto.iterator =
/*#__PURE__*/
function () {
var _iterator = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!isLocalPath(this.url)) {
_context.next = 4;
break;
}
return _context.abrupt("return", new FileDataSource(this.url, this.fileOptions).iterator());
case 4:
return _context.abrupt("return", urlChunkIterator(this.url, this.fileOptions));
case 5:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function iterator() {
return _iterator.apply(this, arguments);
}
return iterator;
}();
return URLDataSource;
}(DataSource);
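/**
 * A minimal usage sketch for `URLDataSource` (the URL below is a placeholder):
 * ```js
 * const source = new tf.data.URLDataSource('https://example.com/data.csv');
 * const it = await source.iterator();
 * const chunk = await it.next(); // {value: Uint8Array, done: false}
 * ```
 */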
/**
* Create a `CSVDataset` by reading and decoding CSV file(s) from provided URL
* or local path if it's in Node environment.
*
 * Note: If isLabel in columnConfigs is `true` for at least one column, each
 * element in the returned `CSVDataset` will be an object of
 * `{xs: features, ys: labels}`: xs is a dict of feature key/value pairs, ys
 * is a dict of label key/value pairs. If no column is marked as a label, each
 * element is a dict of features only.
*
* ```js
* const csvUrl =
* 'https://storage.googleapis.com/tfjs-examples/multivariate-linear-regression/data/boston-housing-train.csv';
*
* async function run() {
* // We want to predict the column "medv", which represents a median value of
* // a home (in $1000s), so we mark it as a label.
* const csvDataset = tf.data.csv(
* csvUrl, {
* columnConfigs: {
* medv: {
* isLabel: true
* }
* }
* });
*
* // Number of features is the number of column names minus one for the label
* // column.
* const numOfFeatures = (await csvDataset.columnNames()).length - 1;
*
* // Prepare the Dataset for training.
* const flattenedDataset =
* csvDataset
* .map(({xs, ys}) =>
* {
* // Convert xs(features) and ys(labels) from object form (keyed by
* // column name) to array form.
* return {xs:Object.values(xs), ys:Object.values(ys)};
* })
* .batch(10);
*
* // Define the model.
* const model = tf.sequential();
* model.add(tf.layers.dense({
* inputShape: [numOfFeatures],
* units: 1
* }));
* model.compile({
* optimizer: tf.train.sgd(0.000001),
* loss: 'meanSquaredError'
* });
*
* // Fit the model using the prepared Dataset
* return model.fitDataset(flattenedDataset, {
* epochs: 10,
* callbacks: {
* onEpochEnd: async (epoch, logs) => {
* console.log(epoch + ':' + logs.loss);
* }
* }
* });
* }
*
* await run();
* ```
*
* @param source URL or local path to get CSV file. If it's a local path, it
 *     must have the prefix `file://`, and it only works in a Node.js environment.
* @param csvConfig (Optional) A CSVConfig object that contains configurations
* of reading and decoding from CSV file(s).
*
* @doc {
* heading: 'Data',
* subheading: 'Creation',
* namespace: 'data',
* configParamIndices: [1]
* }
*/
function csv(source, csvConfig) {
if (csvConfig === void 0) {
csvConfig = {};
}
return new CSVDataset(new URLDataSource(source), csvConfig);
}
/**
* Create a `Dataset` that produces each element by calling a provided function.
*
* Note that repeated iterations over this `Dataset` may produce different
* results, because the function will be called anew for each element of each
* iteration.
*
* Also, beware that the sequence of calls to this function may be out of order
* in time with respect to the logical order of the Dataset. This is due to the
* asynchronous lazy nature of stream processing, and depends on downstream
* transformations (e.g. .shuffle()). If the provided function is pure, this is
* no problem, but if it is a closure over a mutable state (e.g., a traversal
* pointer), then the order of the produced elements may be scrambled.
*
* ```js
* let i = -1;
* const func = () =>
* ++i < 5 ? {value: i, done: false} : {value: null, done: true};
* const ds = tf.data.func(func);
* await ds.forEachAsync(e => console.log(e));
* ```
*
* @param f A function that produces one data element on each call.
*/
function func(f) {
var iter = iteratorFromFunction(f);
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee() {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
return _context.abrupt("return", iter);
case 1:
case "end":
return _context.stop();
}
}
}, _callee);
})));
}
/**
 * Create a `Dataset` that produces each element from a provided JavaScript
* generator, which is a function*
* (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions),
* or a function that returns an
* iterator
* (https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators#Generator_functions).
*
 * The returned iterator should have a `.next()` function that returns elements
 * in the format `{value: TensorContainer, done: boolean}`.
*
* Example of creating a dataset from an iterator factory:
* ```js
* function makeIterator() {
* const numElements = 10;
* let index = 0;
*
* const iterator = {
* next: () => {
* let result;
* if (index < numElements) {
* result = {value: index, done: false};
* index++;
* return result;
* }
* return {value: index, done: true};
* }
* };
* return iterator;
* }
* const ds = tf.data.generator(makeIterator);
* await ds.forEachAsync(e => console.log(e));
* ```
*
* Example of creating a dataset from a generator:
* ```js
* function* dataGenerator() {
* const numElements = 10;
* let index = 0;
* while (index < numElements) {
* const x = index;
* index++;
* yield x;
* }
* }
*
* const ds = tf.data.generator(dataGenerator);
* await ds.forEachAsync(e => console.log(e));
* ```
*
 * @param generator A JavaScript generator function that returns a JavaScript
* iterator.
*
* @doc {
* heading: 'Data',
* subheading: 'Creation',
* namespace: 'data',
* configParamIndices: [1]
* }
*/
function generator(generator) {
return datasetFromIteratorFn( /*#__PURE__*/_asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2() {
var gen;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
_context2.next = 2;
return generator();
case 2:
gen = _context2.sent;
return _context2.abrupt("return", iteratorFromFunction(function () {
return gen.next();
}));
case 4:
case "end":
return _context2.stop();
}
}
}, _callee2);
})));
}
/**
 * Create an iterator that generates `Tensor`s from a webcam video stream. This
 * API only works in a browser environment when the device has a webcam.
*
* Note: this code snippet only works when the device has a webcam. It will
* request permission to open the webcam when running.
* ```js
* const videoElement = document.createElement('video');
* videoElement.width = 100;
* videoElement.height = 100;
* const cam = await tf.data.webcam(videoElement);
* const img = await cam.capture();
* img.print();
* cam.stop();
* ```
*
* @param webcamVideoElement A `HTMLVideoElement` used to play video from
* webcam. If this element is not provided, a hidden `HTMLVideoElement` will
* be created. In that case, `resizeWidth` and `resizeHeight` must be
* provided to set the generated tensor shape.
* @param webcamConfig A `WebcamConfig` object that contains configurations of
* reading and manipulating data from webcam video stream.
*
* @doc {
* heading: 'Data',
* subheading: 'Creation',
* namespace: 'data',
* ignoreCI: true
* }
*/
function webcam(_x, _x2) {
return _webcam.apply(this, arguments);
}
/**
 * Create an iterator that generates frequency-domain spectrogram `Tensor`s from
 * a microphone audio stream using the browser's native FFT. This API only works
 * in a browser environment when the device has a microphone.
*
* Note: this code snippet only works when the device has a microphone. It will
* request permission to open the microphone when running.
* ```js
* const mic = await tf.data.microphone({
* fftSize: 1024,
* columnTruncateLength: 232,
* numFramesPerSpectrogram: 43,
* sampleRateHz:44100,
* includeSpectrogram: true,
* includeWaveform: true
* });
* const audioData = await mic.capture();
* const spectrogramTensor = audioData.spectrogram;
* spectrogramTensor.print();
* const waveformTensor = audioData.waveform;
* waveformTensor.print();
* mic.stop();
* ```
*
* @param microphoneConfig A `MicrophoneConfig` object that contains
* configurations of reading audio data from microphone.
*
* @doc {
* heading: 'Data',
* subheading: 'Creation',
* namespace: 'data',
* ignoreCI: true
* }
*/
function _webcam() {
_webcam = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(webcamVideoElement, webcamConfig) {
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
return _context3.abrupt("return", WebcamIterator.create(webcamVideoElement, webcamConfig));
case 1:
case "end":
return _context3.stop();
}
}
}, _callee3);
}));
return _webcam.apply(this, arguments);
}
function microphone(_x3) {
return _microphone.apply(this, arguments);
}
function _microphone() {
_microphone = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee4(microphoneConfig) {
return regeneratorRuntime.wrap(function _callee4$(_context4) {
while (1) {
switch (_context4.prev = _context4.next) {
case 0:
return _context4.abrupt("return", MicrophoneIterator.create(microphoneConfig));
case 1:
case "end":
return _context4.stop();
}
}
}, _callee4);
}));
return _microphone.apply(this, arguments);
}
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
var version$4 = '3.9.0';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var index$1 = {
__proto__: null,
array: array,
Dataset: Dataset,
zip: zip,
CSVDataset: CSVDataset,
TextLineDataset: TextLineDataset,
csv: csv,
func: func,
generator: generator,
microphone: microphone,
webcam: webcam,
FileDataSource: FileDataSource,
URLDataSource: URLDataSource,
version_data: version$4
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function assertNotComplex(tensor, opName) {
if (!Array.isArray(tensor)) {
tensor = [tensor];
}
tensor.forEach(function (t) {
if (t != null) {
assert(t.dtype !== 'complex64', function () {
return opName + " does not support complex64 tensors in the CPU backend.";
});
}
});
}
var whereImpl$1 = whereImpl;
var MathBackendCPU = /*#__PURE__*/function (_KernelBackend) {
_inheritsLoose(MathBackendCPU, _KernelBackend);
function MathBackendCPU() {
var _this;
_this = _KernelBackend.call(this) || this;
_this.blockSize = 48;
_this.firstUse = true;
_this.data = new DataStorage(_assertThisInitialized(_this), engine());
return _this;
}
var _proto = MathBackendCPU.prototype;
_proto.nextDataId = function nextDataId() {
return MathBackendCPU.nextDataId++;
};
_proto.write = function write(values, shape, dtype) {
if (this.firstUse) {
this.firstUse = false;
if (env().get('IS_NODE')) {
warn('\n============================\n' + 'Hi there 👋. Looks like you are running TensorFlow.js in ' + 'Node.js. To speed things up dramatically, install our node ' + 'backend, which binds to TensorFlow C++, by running ' + 'npm i @tensorflow/tfjs-node, ' + 'or npm i @tensorflow/tfjs-node-gpu if you have CUDA. ' + 'Then call require(\'@tensorflow/tfjs-node\'); (-gpu ' + 'suffix for CUDA) at the start of your program. ' + 'Visit https://github.com/tensorflow/tfjs-node for more details.' + '\n============================');
}
}
var dataId = {
id: this.nextDataId()
};
this.data.set(dataId, {
values: values,
dtype: dtype,
refCount: 1
});
return dataId;
}
/**
* Create a data bucket in cpu backend.
* @param shape Shape of the `TensorInfo`.
* @param dtype DType of the `TensorInfo`.
* @param values The value of the `TensorInfo` stored as a flattened array.
*/
;
_proto.makeTensorInfo = function makeTensorInfo(shape, dtype, values) {
var outId;
if (dtype === 'string' && values != null && values.length > 0 && isString(values[0])) {
var encodedValues = values.map(function (d) {
return encodeString(d);
});
outId = this.write(encodedValues, shape, dtype);
} else {
outId = this.write(values, shape, dtype);
}
return {
dataId: outId,
shape: shape,
dtype: dtype
};
}
/** Return refCount of a `TensorData`. */
;
_proto.refCount = function refCount(dataId) {
if (this.data.has(dataId)) {
var tensorData = this.data.get(dataId);
return tensorData.refCount;
}
return 0;
}
/** Increase refCount of a `TensorData`. */
;
_proto.incRef = function incRef(dataId) {
var tensorData = this.data.get(dataId);
tensorData.refCount++;
}
/** Decrease refCount of a `TensorData`. */
;
_proto.decRef = function decRef(dataId) {
if (this.data.has(dataId)) {
var tensorData = this.data.get(dataId);
tensorData.refCount--;
}
};
_proto.move = function move(dataId, values, shape, dtype, refCount) {
this.data.set(dataId, {
values: values,
dtype: dtype,
refCount: refCount
});
};
_proto.numDataIds = function numDataIds() {
return this.data.numDataIds();
};
_proto.read = /*#__PURE__*/function () {
var _read = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(dataId) {
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
return _context.abrupt("return", this.readSync(dataId));
case 1:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function read(_x) {
return _read.apply(this, arguments);
}
return read;
}();
_proto.readSync = function readSync(dataId) {
var _this$data$get = this.data.get(dataId),
dtype = _this$data$get.dtype,
complexTensorInfos = _this$data$get.complexTensorInfos;
if (dtype === 'complex64') {
var realValues = this.readSync(complexTensorInfos.real.dataId);
var imagValues = this.readSync(complexTensorInfos.imag.dataId);
return mergeRealAndImagArrays(realValues, imagValues);
}
return this.data.get(dataId).values;
};
_proto.bufferSync = function bufferSync(t) {
var data = this.readSync(t.dataId);
var decodedData = data;
if (t.dtype === 'string') {
try {
// Decode the bytes into string.
decodedData = data.map(function (d) {
return decodeString(d);
});
} catch (_a) {
throw new Error('Failed to decode encoded string bytes into utf-8');
}
}
return buffer(t.shape, t.dtype, decodedData);
};
_proto.makeOutput = function makeOutput(values, shape, dtype) {
var dataId = this.write(values, shape, dtype);
return engine().makeTensorFromDataId(dataId, shape, dtype, this);
}
/**
* Dispose the memory if the dataId has 0 refCount. Return true if the memory
* is released or memory is not managed in this backend, false if memory is
* not cleared.
* @param dataId
 * @param force Optional. Remove the data regardless of refCount.
*/
;
_proto.disposeData = function disposeData(dataId, force) {
if (force === void 0) {
force = false;
}
if (this.data.has(dataId)) {
this.data.get(dataId).refCount--;
if (!force && this.data.get(dataId).refCount > 0) {
return false;
}
var _this$data$get2 = this.data.get(dataId),
complexTensorInfos = _this$data$get2.complexTensorInfos;
if (complexTensorInfos != null) {
this.disposeData(complexTensorInfos.real.dataId, true);
this.disposeData(complexTensorInfos.imag.dataId, true);
}
this.data.delete(dataId);
}
return true;
};
_proto.disposeIntermediateTensorInfo = function disposeIntermediateTensorInfo(tensorInfo) {
this.disposeData(tensorInfo.dataId);
};
_proto.time = /*#__PURE__*/function () {
var _time = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(f) {
var start, kernelMs;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
start = now();
f();
kernelMs = now() - start;
return _context2.abrupt("return", {
kernelMs: kernelMs
});
case 4:
case "end":
return _context2.stop();
}
}
}, _callee2);
}));
function time(_x2) {
return _time.apply(this, arguments);
}
return time;
}();
_proto.memory = function memory() {
return {
// Unreliable due to automatic gc. The numbers above are cumulative.
unreliable: true,
reasons: ['The reported memory is an upper bound. Due to automatic garbage ' + 'collection, the true allocated memory may be less.']
};
};
_proto.where = function where(condition) {
assertNotComplex([condition], 'where');
var condVals = this.readSync(condition.dataId);
return whereImpl$1(condition.shape, condVals);
};
_proto.dispose = function dispose() {};
_proto.floatPrecision = function floatPrecision() {
return 32;
}
/** Returns the smallest representable number. */
;
_proto.epsilon = function epsilon() {
return _KernelBackend.prototype.epsilon.call(this);
};
return MathBackendCPU;
}(KernelBackend);
MathBackendCPU.nextDataId = 0;
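/**
 * A minimal sketch of the data-bucket API above, assuming the 'cpu' backend
 * registered by this bundle is active:
 * ```js
 * await tf.setBackend('cpu');
 * const backend = tf.engine().backend; // MathBackendCPU instance
 * const dataId = backend.write(new Float32Array([1, 2, 3]), [3], 'float32');
 * console.log(backend.readSync(dataId)); // Float32Array [1, 2, 3]
 * backend.disposeData(dataId); // refCount drops to 0 and the bucket is freed
 * ```
 */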
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function simpleAbsImpl(vals) {
var resultValues = new Float32Array(vals.length);
for (var i = 0; i < vals.length; ++i) {
resultValues[i] = Math.abs(vals[i]);
}
return resultValues;
}
var abs$9 = function abs(args) {
var x = args.inputs.x;
var cpuBackend = args.backend;
assertNotComplex(x, 'abs');
var resultValues = new Float32Array(sizeFromShape(x.shape));
var values = cpuBackend.data.get(x.dataId).values;
resultValues = simpleAbsImpl(values);
return cpuBackend.makeOutput(resultValues, x.shape, 'float32');
};
var absConfig = {
kernelName: Abs,
backendName: 'cpu',
kernelFunc: abs$9
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Template that creates an implementation for binary ops. Supports broadcasting.
*/
function createSimpleBinaryKernelImpl(op) {
return function (aShape, bShape, aVals, bVals, dtype) {
var newShape = assertAndGetBroadcastShape(aShape, bShape);
var resultRank = newShape.length;
var resultStrides = computeStrides(newShape);
var resultSize = sizeFromShape(newShape);
var result = getTypedArrayFromDType(dtype, resultSize);
var aRank = aShape.length;
var bRank = bShape.length;
var aStrides = computeStrides(aShape);
var bStrides = computeStrides(bShape);
var aBroadcastDims = getBroadcastDims(aShape, newShape);
var bBroadcastDims = getBroadcastDims(bShape, newShape);
if (aBroadcastDims.length + bBroadcastDims.length === 0) {
for (var i = 0; i < result.length; ++i) {
result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
}
} else {
var _loop = function _loop(_i) {
var loc = indexToLoc(_i, resultRank, resultStrides);
var aLoc = loc.slice(-aRank);
aBroadcastDims.forEach(function (d) {
return aLoc[d] = 0;
});
var aIndex = locToIndex(aLoc, aRank, aStrides);
var bLoc = loc.slice(-bRank);
bBroadcastDims.forEach(function (d) {
return bLoc[d] = 0;
});
var bIndex = locToIndex(bLoc, bRank, bStrides);
result[_i] = op(aVals[aIndex], bVals[bIndex]);
};
for (var _i = 0; _i < result.length; ++_i) {
_loop(_i);
}
}
return [result, newShape];
};
}
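/**
 * A minimal sketch of the template above: a broadcasting subtract
 * implementation (the name `subImpl` is illustrative, not a registered
 * kernel):
 * ```js
 * const subImpl = createSimpleBinaryKernelImpl((a, b) => a - b);
 * const [vals, shape] = subImpl(
 *     [2, 2], [2],
 *     new Float32Array([1, 2, 3, 4]), new Float32Array([10, 20]), 'float32');
 * // vals  -> Float32Array [-9, -18, -7, -16]
 * // shape -> [2, 2]
 * ```
 */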
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function complex$1(args) {
var inputs = args.inputs,
backend = args.backend;
var real = inputs.real,
imag = inputs.imag;
var realVals = backend.data.get(real.dataId).values;
var imagVals = backend.data.get(imag.dataId).values;
var complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
  var complex = backend.data.get(complexInfo.dataId); // The complex tensor owns the underlying real and imag tensorInfos;
  // only the complex tensor tracks refCount. When the complex data is
  // disposed, the underlying tensorData will be disposed too.
complex.complexTensorInfos = {
real: backend.makeTensorInfo(real.shape, 'float32', realVals),
imag: backend.makeTensorInfo(imag.shape, 'float32', imagVals)
};
return complexInfo;
}
var complexConfig = {
kernelName: Complex,
backendName: 'cpu',
kernelFunc: complex$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Generates a tensorInfo filled with zeros.
* @param backend cpu backend.
* @param shape Shape for the zeros tensor.
* @param dtype Optional. If set, the result has this dtype.
*/
function zeros$2(backend, shape, dtype) {
if (dtype === void 0) {
dtype = 'float32';
}
if (dtype === 'complex64') {
var real = zeros$2(backend, shape, 'float32');
var imag = zeros$2(backend, shape, 'float32');
return complex$1({
inputs: {
real: real,
imag: imag
},
backend: backend
});
}
var values = makeZerosTypedArray(sizeFromShape(shape), dtype);
return backend.makeTensorInfo(shape, dtype, values);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function identity$1(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
backend.incRef(x.dataId);
return {
dataId: x.dataId,
shape: x.shape,
dtype: x.dtype
};
}
var identityConfig = {
kernelName: Identity,
backendName: 'cpu',
kernelFunc: identity$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function real$1(args) {
var inputs = args.inputs,
backend = args.backend;
var input = inputs.input;
var real = backend.data.get(input.dataId).complexTensorInfos.real;
  var realVal = backend.data.get(real.dataId).values; // When the complex tensor is disposed, its underlying parts will be disposed too.
  // Make a new tensor out of the real values of the complex tensor. This makes
  // sure the values remain accessible even after the complex tensor is disposed.
return backend.makeTensorInfo(real.shape, real.dtype, realVal);
}
var realConfig = {
kernelName: Real,
backendName: 'cpu',
kernelFunc: real$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function cast$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var dtype = attrs.dtype; // Casting to complex64.
if (dtype === 'complex64') {
if (x.dtype === 'complex64') {
return identity$1({
inputs: {
x: x
},
backend: backend
});
}
var zerosTensorInfo = zeros$2(backend, x.shape, x.dtype);
var floatX = cast$2({
inputs: {
x: x
},
backend: backend,
attrs: {
dtype: 'float32'
}
});
var result = complex$1({
inputs: {
real: floatX,
imag: zerosTensorInfo
},
backend: backend
});
backend.disposeIntermediateTensorInfo(zerosTensorInfo);
backend.disposeIntermediateTensorInfo(floatX);
return result;
} // Casting from complex64
if (x.dtype === 'complex64') {
var realPart = real$1({
inputs: {
input: x
},
backend: backend
});
var _result = cast$2({
inputs: {
x: realPart
},
backend: backend,
attrs: {
dtype: dtype
}
});
backend.disposeIntermediateTensorInfo(realPart);
return _result;
}
if (!hasEncodingLoss(x.dtype, dtype)) {
// We don't change the underlying data, since we cast to higher
// precision.
var _result2 = identity$1({
inputs: {
x: x
},
backend: backend
});
return {
dataId: _result2.dataId,
shape: _result2.shape,
dtype: dtype
};
}
if (dtype === 'int32') {
var values = backend.data.get(x.dataId).values;
var resultValues = Int32Array.from(values);
return backend.makeTensorInfo(x.shape, 'int32', resultValues);
}
if (dtype === 'bool') {
// This is essentially the result of notEqual(x, 0). We avoid using
// kernel notEqual to avoid circular dependency, i.e. binary_utils ->
// cast -> notEqual -> binary_utils.
var xVals = backend.data.get(x.dataId).values;
var zero = toTypedArray([0], x.dtype);
var _createSimpleBinaryKe = createSimpleBinaryKernelImpl(function (a, b) {
return a !== b ? 1 : 0;
})(x.shape, [], xVals, zero, 'bool'),
resultData = _createSimpleBinaryKe[0],
resultShape = _createSimpleBinaryKe[1];
return backend.makeTensorInfo(resultShape, 'bool', resultData);
}
throw new Error("Error in Cast: failed to cast " + x.dtype + " to " + dtype);
}
var castConfig = {
kernelName: Cast,
backendName: 'cpu',
kernelFunc: cast$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Template that creates a `KernelFunc` for binary ops.
* @param name Kernel name.
* @param binaryKernelImpl A `SimpleBinaryKernelImpl` for the kernel.
 * @param binaryKernelComplexImpl Optional. If provided, represents a
* `ComplexBinaryKernelImpl` for the kernel, will be used when input dtype
* is `complex64`.
* @param dtype Optional. If set, the result has this dtype. Otherwise, the
* result has the same dtype as the first input. This is mainly used in
* comparison kernels, such as Equal, Less, Greater, etc.
*/
function binaryKernelFunc(name, simpleImpl, complexImpl, dtype) {
if (complexImpl == null) {
return function (_ref) {
var inputs = _ref.inputs,
backend = _ref.backend;
var a = inputs.a,
b = inputs.b;
var cpuBackend = backend;
assertNotComplex([a, b], name);
var aVals = cpuBackend.data.get(a.dataId).values;
var bVals = cpuBackend.data.get(b.dataId).values;
var decodedAVals = a.dtype === 'string' ? // tslint:disable-next-line: no-any
fromUint8ToStringArray(aVals) : aVals;
var decodedBVals = a.dtype === 'string' ? // tslint:disable-next-line: no-any
fromUint8ToStringArray(bVals) : bVals;
var $dtype = dtype || a.dtype;
var _simpleImpl = simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype),
resultData = _simpleImpl[0],
resultShape = _simpleImpl[1];
return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
};
}
return function (_ref2) {
var inputs = _ref2.inputs,
backend = _ref2.backend;
var a = inputs.a,
b = inputs.b;
var cpuBackend = backend;
if (a.dtype === 'complex64' || b.dtype === 'complex64') {
var $aComplex = cast$2({
inputs: {
x: a
},
backend: cpuBackend,
attrs: {
dtype: 'complex64'
}
});
var $aComplexVals = cpuBackend.data.get($aComplex.dataId);
var aReal = $aComplexVals.complexTensorInfos.real;
var aImag = $aComplexVals.complexTensorInfos.imag;
var aRealVals = cpuBackend.data.get(aReal.dataId).values;
var aImagVals = cpuBackend.data.get(aImag.dataId).values;
var $bComplex = cast$2({
inputs: {
x: b
},
backend: cpuBackend,
attrs: {
dtype: 'complex64'
}
});
var $bComplexVals = cpuBackend.data.get($bComplex.dataId);
var bReal = $bComplexVals.complexTensorInfos.real;
var bImag = $bComplexVals.complexTensorInfos.imag;
var bRealVals = cpuBackend.data.get(bReal.dataId).values;
var bImagVals = cpuBackend.data.get(bImag.dataId).values;
var _complexImpl = complexImpl(a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals),
resultRealData = _complexImpl[0],
resultImagData = _complexImpl[1],
resultShape = _complexImpl[2];
var resultReal = cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);
var resultImag = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);
var result = complex$1({
inputs: {
real: resultReal,
imag: resultImag
},
backend: cpuBackend
});
cpuBackend.disposeIntermediateTensorInfo($aComplex);
cpuBackend.disposeIntermediateTensorInfo($bComplex);
cpuBackend.disposeIntermediateTensorInfo(resultReal);
cpuBackend.disposeIntermediateTensorInfo(resultImag);
return result;
} else {
var aVals = cpuBackend.data.get(a.dataId).values;
var bVals = cpuBackend.data.get(b.dataId).values;
var $dtype = dtype || a.dtype;
var _simpleImpl2 = simpleImpl(a.shape, b.shape, aVals, bVals, $dtype),
resultData = _simpleImpl2[0],
_resultShape = _simpleImpl2[1];
return cpuBackend.makeTensorInfo(_resultShape, $dtype, resultData);
}
};
}
/**
* Template that creates the complex type implementation for binary ops.
 * Supports broadcasting.
*/
function createComplexBinaryKernelImpl(op) {
return function (aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) {
var resultShape = assertAndGetBroadcastShape(aShape, bShape);
var resultSize = sizeFromShape(resultShape);
var resultRank = resultShape.length;
var resultStrides = computeStrides(resultShape);
var resultRealVals = getTypedArrayFromDType('float32', resultSize);
var resultImagVals = getTypedArrayFromDType('float32', resultSize);
var aBroadcastDims = getBroadcastDims(aShape, resultShape);
var bBroadcastDims = getBroadcastDims(bShape, resultShape);
var aVals = mergeRealAndImagArrays(aRealVals, aImagVals);
var bVals = mergeRealAndImagArrays(bRealVals, bImagVals);
var aRank = aShape.length;
var aStrides = computeStrides(aShape);
var bRank = bShape.length;
var bStrides = computeStrides(bShape);
if (aBroadcastDims.length + bBroadcastDims.length === 0) {
for (var i = 0; i < resultRealVals.length; i++) {
var aIdx = i % aVals.length;
var bIdx = i % bVals.length;
var result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]);
resultRealVals[i] = result.real;
resultImagVals[i] = result.imag;
}
} else {
var _loop = function _loop(_i) {
var loc = indexToLoc(_i, resultRank, resultStrides);
var aLoc = loc.slice(-aRank);
aBroadcastDims.forEach(function (d) {
return aLoc[d] = 0;
});
var aIndex = locToIndex(aLoc, aRank, aStrides);
var bLoc = loc.slice(-bRank);
bBroadcastDims.forEach(function (d) {
return bLoc[d] = 0;
});
var bIndex = locToIndex(bLoc, bRank, bStrides);
var opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]);
resultRealVals[_i] = opResult.real;
resultImagVals[_i] = opResult.imag;
};
for (var _i = 0; _i < resultRealVals.length; _i++) {
_loop(_i);
}
}
return [resultRealVals, resultImagVals, resultShape];
};
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var addImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a + b;
});
var addComplexImpl = createComplexBinaryKernelImpl(function (aReal, aImag, bReal, bImag) {
return {
real: aReal + bReal,
imag: aImag + bImag
};
});
var add$4 = binaryKernelFunc(Add, addImpl, addComplexImpl);
var addConfig = {
kernelName: Add,
backendName: 'cpu',
kernelFunc: add$4
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) {
var weightsSize = sizeFromShape(weightsShape);
var outVals = makeZerosTypedArray(size, weightsDtype);
for (var i = 0; i < xVals.length; i++) {
var value = xVals[i];
if (value < 0) {
throw new Error('Input x must be non-negative!');
}
if (value >= size) {
continue;
}
if (weightsSize > 0) {
outVals[value] += weightsVals[i];
} else {
outVals[value] += 1;
}
}
return outVals;
}
function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput) {
if (binaryOutput === void 0) {
binaryOutput = false;
}
var numRows = xBuf.shape[0];
var numCols = xBuf.shape[1];
var outBuf = buffer([numRows, size], weightsBuf.dtype);
for (var i = 0; i < numRows; i++) {
for (var j = 0; j < numCols; j++) {
var value = xBuf.get(i, j);
if (value < 0) {
throw new Error('Input x must be non-negative!');
}
if (value >= size) {
continue;
}
if (binaryOutput) {
outBuf.set(1, i, value);
} else {
if (weightsBuf.size > 0) {
outBuf.set(outBuf.get(i, value) + weightsBuf.get(i, j), i, value);
} else {
outBuf.set(outBuf.get(i, value) + 1, i, value);
}
}
}
}
return outBuf;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Template that creates an implementation for a unary op.
*/
function createSimpleUnaryImpl(op) {
return function (values, dtype, attrs) {
var newValues = getTypedArrayFromDType(dtype, values.length);
for (var i = 0; i < values.length; ++i) {
newValues[i] = op(values[i], attrs);
}
return newValues;
};
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Template that creates a `KernelFunc` for unary ops.
* @param name Kernel name.
* @param op A `SimpleUnaryOperation` for the kernel.
* @param dtype Optional. If set, the result has this dtype. Otherwise, the
* result has the same dtype as the input. This is mainly used in certain
* kernels that return bool type, such as isFinite, isInf, etc.
*/
function unaryKernelFunc(name, op, dtype) {
return function (_ref) {
var inputs = _ref.inputs,
attrs = _ref.attrs,
backend = _ref.backend;
var x = inputs.x;
assertNotComplex(x, name);
if (x.dtype === 'string' || dtype === 'string') {
throw new Error('unaryKernelFunc does not support string input/output');
}
var cpuBackend = backend;
var values = cpuBackend.data.get(x.dataId).values;
var xSize = sizeFromShape(x.shape);
var $dtype = dtype || x.dtype;
var newValues = getArrayFromDType($dtype, xSize);
for (var i = 0; i < xSize; ++i) {
newValues[i] = op(values[i], attrs);
}
return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
};
}
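/**
 * A minimal sketch of the template above: a bool-dtype unary kernel, in the
 * style of the IsFinite kernel (the kernel name string is assumed here):
 * ```js
 * const isFiniteKernelFunc =
 *     unaryKernelFunc('IsFinite', xi => Number.isFinite(xi) ? 1 : 0, 'bool');
 * ```
 */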
/**
 * Template that creates a `KernelFunc` for unary ops from the given
 * `SimpleUnaryImpl`.
* @param name Kernel name.
* @param unaryImpl A `SimpleUnaryImpl` that implements the op.
* @param dtype Optional. If set, the result has this dtype. Otherwise, the
* result has the same dtype as the input. This is mainly used in certain
* kernels that return bool type, such as isFinite, isInf, etc.
*/
function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
return function (_ref2) {
var inputs = _ref2.inputs,
attrs = _ref2.attrs,
backend = _ref2.backend;
var x = inputs.x;
assertNotComplex(x, name);
if (x.dtype === 'string' || dtype === 'string') {
throw new Error('unaryKernelFunc does not support string input/output');
}
var cpuBackend = backend;
var values = cpuBackend.data.get(x.dataId).values;
var $dtype = dtype || x.dtype;
var newValues = unaryImpl(values, $dtype, attrs);
return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
};
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ceilImpl = createSimpleUnaryImpl(function (xi) {
return Math.ceil(xi);
});
var ceil$4 = unaryKernelFuncFromImpl(Ceil, ceilImpl);
var ceilConfig = {
kernelName: Ceil,
backendName: 'cpu',
kernelFunc: ceil$4
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function concatImpl(inputs, outShape, dtype, simplyConcat) {
var outVals = getArrayFromDType(dtype, sizeFromShape(outShape));
if (simplyConcat && dtype !== 'string') {
// Use built-in TypedArray.set() method for speed.
var offset = 0;
inputs.forEach(function (input) {
var size = sizeFromShape(input.shape);
outVals.set(input.vals, offset);
offset += size;
});
} else {
var colOffset = 0;
inputs.forEach(function (input) {
var decodedData = dtype === 'string' ? fromUint8ToStringArray(input.vals) : input.vals;
var tIdx = 0;
for (var row = 0; row < input.shape[0]; ++row) {
var resIdx = row * outShape[1] + colOffset;
for (var col = 0; col < input.shape[1]; ++col) {
outVals[resIdx + col] = decodedData[tIdx++];
}
}
colOffset += input.shape[1];
});
}
return outVals;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var equalImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a === b ? 1 : 0;
});
var equal$1 = binaryKernelFunc(Equal, equalImpl, null
/* complexImpl */
, 'bool');
var equalConfig = {
kernelName: Equal,
backendName: 'cpu',
kernelFunc: equal$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var expImpl = createSimpleUnaryImpl(function (xi) {
return Math.exp(xi);
});
var exp$4 = unaryKernelFuncFromImpl(Exp, expImpl);
var expConfig = {
kernelName: Exp,
backendName: 'cpu',
kernelFunc: exp$4
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var expm1Impl = createSimpleUnaryImpl(function (xi) {
return Math.expm1(xi);
});
var expm1$1 = unaryKernelFuncFromImpl(Expm1, expm1Impl);
var expm1Config = {
kernelName: Expm1,
backendName: 'cpu',
kernelFunc: expm1$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var floorImpl = createSimpleUnaryImpl(function (xi) {
return Math.floor(xi);
});
var floor$b = unaryKernelFuncFromImpl(Floor, floorImpl);
var floorConfig = {
kernelName: Floor,
backendName: 'cpu',
kernelFunc: floor$b
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) {
var outBuf = buffer([numSlices, sliceSize], dtype);
for (var i = 0; i < numSlices; i++) {
var index = [];
var flattenIndex = 0;
for (var j = 0; j < sliceRank; j++) {
var dim = indicesData[i * sliceRank + j];
flattenIndex += dim * strides[j];
index.push(dim);
}
if (flattenIndex < 0 || flattenIndex >= paramsSize / sliceSize) {
throw new Error("Invalid indices: " + index + " does not index into " + paramsShape);
}
for (var k = 0; k < sliceSize; k++) {
outBuf.values[i * sliceSize + k] = paramsBuf.get.apply(paramsBuf, paramsBuf.indexToLoc(flattenIndex * sliceSize + k));
}
}
return outBuf;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) {
var outBuf = buffer(flattenOutputShape, xBuf.dtype);
for (var i = 0; i < outBuf.size; ++i) {
var newLoc = outBuf.indexToLoc(i);
var originalLoc = newLoc.slice();
var batchIdx = originalLoc[0];
var indicesIdx = originalLoc[2];
var indicesIndex = indicesBuf.locToIndex([batchIdx, indicesIdx]);
originalLoc[2] = indicesBuf.values[indicesIndex];
var originalIndex = xBuf.locToIndex(originalLoc);
outBuf.values[i] = xBuf.values[originalIndex];
}
return outBuf;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var greaterImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a > b ? 1 : 0;
});
var greater$2 = binaryKernelFunc(Greater, greaterImpl, null
/* complexImpl */
, 'bool');
var greaterConfig = {
kernelName: Greater,
backendName: 'cpu',
kernelFunc: greater$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var greaterEqualImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a >= b ? 1 : 0;
});
var greaterEqual$1 = binaryKernelFunc(GreaterEqual, greaterEqualImpl, null
/* complexImpl */
, 'bool');
var greaterEqualConfig = {
kernelName: GreaterEqual,
backendName: 'cpu',
kernelFunc: greaterEqual$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var lessImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a < b ? 1 : 0;
});
var less$2 = binaryKernelFunc(Less, lessImpl, null
/* complexImpl */
, 'bool');
var lessConfig = {
kernelName: Less,
backendName: 'cpu',
kernelFunc: less$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var lessEqualImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a <= b ? 1 : 0;
});
var lessEqual$1 = binaryKernelFunc(LessEqual, lessEqualImpl, null
/* complexImpl */
, 'bool');
var lessEqualConfig = {
kernelName: LessEqual,
backendName: 'cpu',
kernelFunc: lessEqual$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function linSpaceImpl(start, stop, num) {
var step = (stop - start) / (num - 1);
var values = makeZerosTypedArray(num, 'float32');
values[0] = start;
for (var i = 1; i < values.length; i++) {
values[i] = values[i - 1] + step;
}
return values;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the License);
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an AS IS BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var logImpl = createSimpleUnaryImpl(function (xi) {
return Math.log(xi);
});
var log$b = unaryKernelFuncFromImpl(Log, logImpl);
var logConfig = {
kernelName: Log,
backendName: 'cpu',
kernelFunc: log$b
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxImpl(aVals, reduceSize, outShape, dtype) {
var vals = getTypedArrayFromDType(dtype, sizeFromShape(outShape));
for (var i = 0; i < vals.length; ++i) {
var offset = i * reduceSize;
var max = aVals[offset];
for (var j = 0; j < reduceSize; ++j) {
var value = aVals[offset + j];
if (Number.isNaN(value) || value > max) {
// comparison with NaN always returns false
max = value;
}
}
vals[i] = max;
}
return vals;
}
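// Usage sketch (illustrative only): maxImpl reduces contiguous windows of
// `reduceSize` elements, so callers are expected to move the reduced axes
// innermost first (cf. transposeImpl below).
//   maxImpl(new Float32Array([1, 5, 3, 2, 8, 4]), 3, [2], 'float32');
//   // -> Float32Array [5, 8]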
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var maximumImpl = createSimpleBinaryKernelImpl(function (aValue, bValue) {
return Math.max(aValue, bValue);
});
var maximum$3 = binaryKernelFunc(Maximum, maximumImpl);
var maximumConfig = {
kernelName: Maximum,
backendName: 'cpu',
kernelFunc: maximum$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var minimumImpl = createSimpleBinaryKernelImpl(function (aValue, bValue) {
return Math.min(aValue, bValue);
});
var minimum$3 = binaryKernelFunc(Minimum, minimumImpl);
var minimumConfig = {
kernelName: Minimum,
backendName: 'cpu',
kernelFunc: minimum$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var multiplyImpl = createSimpleBinaryKernelImpl(function (aValue, bValue) {
return aValue * bValue;
});
var multiplyComplexImpl = createComplexBinaryKernelImpl(function (aReal, aImag, bReal, bImag) {
return {
real: aReal * bReal - aImag * bImag,
imag: aReal * bImag + aImag * bReal
};
});
var multiply$3 = binaryKernelFunc(Multiply, multiplyImpl, multiplyComplexImpl);
var multiplyConfig = {
kernelName: Multiply,
backendName: 'cpu',
kernelFunc: multiply$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function negImpl(xVals, xShape, xDtype) {
var minusOne = createScalarValue(-1, xDtype);
return multiplyImpl([], xShape, minusOne, xVals, xDtype);
}
function neg$1(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
assertNotComplex(x, 'neg');
var xVals = backend.data.get(x.dataId).values;
var _negImpl = negImpl(xVals, x.shape, x.dtype),
res = _negImpl[0],
newShape = _negImpl[1];
return backend.makeTensorInfo(newShape, x.dtype, res);
}
var negConfig = {
kernelName: Neg,
backendName: 'cpu',
kernelFunc: neg$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var notEqualImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a !== b ? 1 : 0;
});
var notEqual$1 = binaryKernelFunc(NotEqual, notEqualImpl, null
/* complexOp */
, 'bool');
var notEqualConfig = {
kernelName: NotEqual,
backendName: 'cpu',
kernelFunc: notEqual$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function transposeImpl(xVals, xShape, dtype, perm, newShape) {
var xRank = xShape.length;
var xSize = sizeFromShape(xShape);
var xStrides = computeStrides(xShape);
var newStrides = computeStrides(newShape);
var result = getTypedArrayFromDType(dtype, sizeFromShape(newShape));
for (var i = 0; i < xSize; ++i) {
var loc = indexToLoc(i, xRank, xStrides); // Permute location.
var newLoc = new Array(loc.length);
for (var _i = 0; _i < newLoc.length; _i++) {
newLoc[_i] = loc[perm[_i]];
}
var newIndex = locToIndex(newLoc, xRank, newStrides);
result[newIndex] = xVals[i];
}
return result;
}
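// Usage sketch (illustrative only): transposing a 2x3 tensor with perm [1, 0]
// produces the row-major values of the 3x2 result.
//   transposeImpl(new Float32Array([1, 2, 3, 4, 5, 6]), [2, 3], 'float32',
//                 [1, 0], [3, 2]);
//   // -> Float32Array [1, 4, 2, 5, 3, 6]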
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function transpose$1(args) {
var inputs = args.inputs,
attrs = args.attrs,
backend = args.backend;
var x = inputs.x;
var perm = attrs.perm;
assertNotComplex(x, 'transpose');
var xRank = x.shape.length;
var newShape = new Array(xRank);
for (var i = 0; i < newShape.length; i++) {
newShape[i] = x.shape[perm[i]];
}
var values = backend.data.get(x.dataId).values;
var result = transposeImpl(values, x.shape, x.dtype, perm, newShape);
var dataId = backend.write(result, newShape, x.dtype);
return {
dataId: dataId,
shape: newShape,
dtype: x.dtype
};
}
var transposeConfig = {
kernelName: Transpose,
backendName: 'cpu',
kernelFunc: transpose$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function prodImpl(xShape, xDtype, xVals, reductionAxes) {
var _backend_util$compute = computeOutAndReduceShapes(xShape, reductionAxes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var outDtype = upcastType(xDtype, 'int32');
var outVals = makeZerosTypedArray(sizeFromShape(outShape), outDtype);
var reduceSize = sizeFromShape(reduceShape);
for (var i = 0; i < outVals.length; ++i) {
var offset = i * reduceSize;
var _prod = 1;
for (var j = 0; j < reduceSize; ++j) {
_prod *= xVals[offset + j];
}
outVals[i] = _prod;
}
return {
outVals: outVals,
outShape: outShape,
outDtype: outDtype
};
}
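// Usage sketch (illustrative only): prodImpl multiplies contiguous blocks, so
// prod$1 below transposes the input first when the reduction axes are not
// already innermost.
//   prodImpl([2, 3], 'int32', new Int32Array([1, 2, 3, 4, 5, 6]), [1]);
//   // -> { outVals: [6, 120], outShape: [2], outDtype: 'int32' }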
function prod$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
assertNotComplex(x, 'prod');
var xRank = x.shape.length;
var axes = parseAxisParam(axis, x.shape);
var permutation = getAxesPermutation(axes, xRank);
var reductionAxes = axes;
var permutedX = x;
var intermediateTensorInfos = [];
if (permutation != null) {
permutedX = transpose$1({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutation
}
});
intermediateTensorInfos.push(permutedX);
reductionAxes = getInnerMostAxes(reductionAxes.length, xRank);
}
var xVals = backend.data.get(permutedX.dataId).values;
var _prodImpl = prodImpl(permutedX.shape, permutedX.dtype, xVals, reductionAxes),
outVals = _prodImpl.outVals,
outShape = _prodImpl.outShape,
outDtype = _prodImpl.outDtype;
var resultShape = outShape;
if (keepDims) {
resultShape = expandShapeToKeepDim(outShape, axes);
}
intermediateTensorInfos.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return backend.makeTensorInfo(resultShape, outDtype, outVals);
}
var prodConfig = {
kernelName: Prod,
backendName: 'cpu',
kernelFunc: prod$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function rangeImpl(start, stop, step, dtype) {
var sameStartStop = start === stop;
var increasingRangeNegativeStep = start < stop && step < 0;
var decreasingRangePositiveStep = stop < start && step > 1;
if (sameStartStop || increasingRangeNegativeStep || decreasingRangePositiveStep) {
return makeZerosTypedArray(0, dtype);
}
var numElements = Math.abs(Math.ceil((stop - start) / step));
var values = makeZerosTypedArray(numElements, dtype);
if (stop < start && step === 1) {
// Auto adjust the step's sign if it hasn't been set
// (or was set to 1)
step = -1;
}
values[0] = start;
for (var i = 1; i < values.length; i++) {
values[i] = values[i - 1] + step;
}
return values;
}
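// Usage sketch (illustrative only):
//   rangeImpl(0, 5, 2, 'float32');  // -> Float32Array [0, 2, 4]
//   rangeImpl(5, 0, 1, 'float32');  // step auto-flips to -1 -> [5, 4, 3, 2, 1]
//   rangeImpl(3, 3, 1, 'float32');  // same start and stop -> empty array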
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var rsqrtImpl = createSimpleUnaryImpl(function (xi) {
return 1 / Math.sqrt(xi);
});
var rsqrt$1 = unaryKernelFuncFromImpl(Rsqrt, rsqrtImpl);
var rsqrtConfig = {
kernelName: Rsqrt,
backendName: 'cpu',
kernelFunc: rsqrt$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var sigmoidImpl = createSimpleUnaryImpl(function (xi) {
return 1 / (1 + Math.exp(-xi));
});
var sigmoid$1 = unaryKernelFunc(Sigmoid, function (xi) {
return 1 / (1 + Math.exp(-xi));
});
var sigmoidConfig = {
kernelName: Sigmoid,
backendName: 'cpu',
kernelFunc: sigmoid$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sliceImpl(vals, begin, size, shape, dtype) {
var isContinous = isSliceContinous(shape, begin, size);
var length = sizeFromShape(size);
var xStrides = computeStrides(shape);
if (isContinous) {
var flatOffset = computeFlatOffset(begin, xStrides);
if (dtype === 'string') {
return vals.slice(flatOffset, flatOffset + length);
}
return vals.subarray(flatOffset, flatOffset + length);
}
var decodedData = dtype === 'string' ? fromUint8ToStringArray(vals) : vals;
var inBuf = buffer(shape, dtype, decodedData);
var outBuf = buffer(size, dtype);
for (var i = 0; i < outBuf.size; ++i) {
var outLoc = outBuf.indexToLoc(i);
var inLoc = outLoc.map(function (idx, j) {
return idx + begin[j];
});
outBuf.set.apply(outBuf, [inBuf.get.apply(inBuf, inLoc)].concat(outLoc));
}
if (dtype === 'string') {
return fromStringArrayToUint8(outBuf.values);
}
return outBuf.values;
}
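// Usage sketch (illustrative only): a 2x2 window starting at [0, 1] of a 2x3
// tensor.
//   sliceImpl(new Float32Array([1, 2, 3, 4, 5, 6]), [0, 1], [2, 2], [2, 3], 'float32');
//   // -> [2, 3, 5, 6]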
function slice$3(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var begin = attrs.begin,
size = attrs.size;
assertNotComplex(x, 'slice');
var _slice_util$parseSlic = parseSliceParams(x, begin, size),
$begin = _slice_util$parseSlic[0],
$size = _slice_util$parseSlic[1];
assertParamsValid(x, $begin, $size);
var vals = backend.data.get(x.dataId).values;
var outVals = sliceImpl(vals, $begin, $size, x.shape, x.dtype);
return backend.makeTensorInfo($size, x.dtype, outVals);
}
var sliceConfig = {
kernelName: Slice,
backendName: 'cpu',
kernelFunc: slice$3
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
var indicesCount = indicesShape[0];
var denseRows = denseShape[0];
var emptyRowIndicator = new Array(denseRows);
var reverseIndexMap = new Array(indicesCount);
var rank = indicesShape[1];
if (denseRows === 0) {
if (indicesCount !== 0) {
throw new Error("Received SparseTensor with denseShape[0] = 0 but\n indices.shape[0] = " + indicesCount);
}
var outputIndices = getArrayFromDType(indicesDType, 0);
var outputValues = getArrayFromDType(valuesDType, 0);
return [outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap];
}
var rowsAreOrdered = true;
var lastIndicesRow = 0;
var csrOffset = new Array(denseRows).fill(0);
for (var i = 0; i < indicesCount; ++i) {
// indices is a 2d tensor with shape of [N, rank]
var row = indices[i * rank];
if (row < 0) {
throw new Error("indices(" + i + ", 0) is invalid: " + row + " < 0");
}
if (row >= denseRows) {
throw new Error("indices(" + i + ", 0) is invalid: " + row + " >= " + denseRows);
}
++csrOffset[row];
rowsAreOrdered = rowsAreOrdered && row >= lastIndicesRow;
lastIndicesRow = row;
}
var allRowsFull = true;
for (var _row = 0; _row < denseRows; ++_row) {
// csrOffset here describes the number of elements in this dense row
var rowEmpty = csrOffset[_row] === 0;
emptyRowIndicator[_row] = rowEmpty;
allRowsFull = allRowsFull && !rowEmpty; // In filled version, each row has at least one element.
csrOffset[_row] = Math.max(csrOffset[_row], 1); // Update csrOffset to represent the number of elements up to and
// including denseRows + 1:
// csrOffset[0] == #{elements of row 0}
// csrOffset[1] == #{elements of row 1} + #{elements of row 0}
// ..
// csrOffset[i] == starting index for elements in row i + 1.
if (_row > 0) {
csrOffset[_row] += csrOffset[_row - 1];
}
}
if (allRowsFull && rowsAreOrdered) {
var _outputIndices = indices;
var _outputValues = values;
for (var _i = 0; _i < indicesCount; ++_i) {
reverseIndexMap[_i] = _i;
}
return [_outputIndices, [indicesCount, rank], _outputValues, emptyRowIndicator, reverseIndexMap];
} else {
var fullIndicesCount = csrOffset[denseRows - 1];
var _outputIndices2 = getArrayFromDType(indicesDType, fullIndicesCount * rank);
var _outputValues2 = getArrayFromDType(valuesDType, fullIndicesCount);
var filledCount = new Array(denseRows).fill(0); // Fill in values for rows that are not missing
for (var _i2 = 0; _i2 < indicesCount; ++_i2) {
// indices is a 2d tensor with shape of [N, rank]
var _row2 = indices[_i2 * rank];
var offset = filledCount[_row2];
var outputI = (_row2 === 0 ? 0 : csrOffset[_row2 - 1]) + offset;
filledCount[_row2]++; // Increment the filled count for this row.
for (var j = 0; j < rank; ++j) {
// indices and outputIndices are 2d tensors with shape of [N, rank]
_outputIndices2[outputI * rank + j] = indices[_i2 * rank + j];
}
_outputValues2[outputI] = values[_i2]; // We'll need this reverse index map to backprop correctly.
reverseIndexMap[_i2] = outputI;
} // Fill in values for rows that are missing
for (var _row3 = 0; _row3 < denseRows; ++_row3) {
var rowCount = filledCount[_row3];
if (rowCount === 0) {
// We haven't filled this row
var startingIndex = _row3 === 0 ? 0 : csrOffset[_row3 - 1]; // Remaining index values were set to zero already.
// Just need to set the row index in the right location.
// outputIndices is a 2d tensor with shape of [N, rank]
_outputIndices2[startingIndex * rank + 0] = _row3;
for (var col = 1; col < rank; ++col) {
_outputIndices2[startingIndex * rank + col] = 0;
}
_outputValues2[startingIndex] = defaultValue;
}
}
return [_outputIndices2, [fullIndicesCount, rank], _outputValues2, emptyRowIndicator, reverseIndexMap];
}
}
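// Usage sketch (illustrative only): a sparse vector with denseShape [3] and
// entries at rows 0 and 2; row 1 gets filled with the default value 0.
//   sparseFillEmptyRowsImpl(
//       new Int32Array([0, 2]), [2, 1], 'int32',
//       new Float32Array([10, 20]), 'float32', [3], 0);
//   // -> [indices [0, 1, 2], indicesShape [3, 1], values [10, 0, 20],
//   //     emptyRowIndicator [false, true, false], reverseIndexMap [0, 2]]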
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
var denseSize = sizeFromShape(inputShape);
var nnz = inputIndicesShape[0];
var outputRank = targetShape.length; // Compute the output shape. Determine product of specified dimensions, and
// find the index of the unspecified one.
var outputShape = [];
var product = 1;
var unknownIndex = -1;
for (var d = 0; d < outputRank; ++d) {
var size = targetShape[d];
if (size === -1) {
if (unknownIndex !== -1) {
throw new Error("only one output dimension may be -1, not both " + unknownIndex + " and " + d);
}
unknownIndex = d;
outputShape.push(1);
} else {
if (size < 0) {
throw new Error("size " + d + " must be non-negative, not " + size);
}
product *= size;
outputShape.push(size);
}
}
if (unknownIndex !== -1) {
if (product <= 0) {
throw new Error('reshape cannot infer the missing ' + 'input size for an empty tensor unless all ' + 'specified input sizes are non-zero');
}
var missing = Math.trunc(denseSize / product);
if (product * missing !== denseSize) {
throw new Error("Input to reshape is a SparseTensor with " + denseSize + "\n dense values, but the requested shape requires a multiple of " + product + ". inputShape=" + inputShape + " outputShape= " + outputShape);
}
outputShape[unknownIndex] = missing;
}
var outputSize = sizeFromShape(outputShape);
if (outputSize !== denseSize) {
throw new Error("Input to reshape is a tensor with " + denseSize + " dense values, but the requested shape has " + outputSize + ". inputShape=" + inputShape + " outputShape=" + outputShape);
}
var inputRank = inputShape.length;
var inputStrides = [];
if (inputRank > 0) {
inputStrides[inputRank - 1] = 1;
for (var _d = inputRank - 2; _d >= 0; --_d) {
inputStrides[_d] = inputStrides[_d + 1] * inputShape[_d + 1];
}
}
var outputStrides = [];
if (outputRank > 0) {
outputStrides[outputRank - 1] = 1;
for (var _d2 = outputRank - 2; _d2 >= 0; --_d2) {
outputStrides[_d2] = outputStrides[_d2 + 1] * outputShape[_d2 + 1];
}
}
var newIndices = getArrayFromDType(inputDType, nnz * outputRank);
for (var i = 0; i < nnz; ++i) {
var id = 0;
for (var j = 0; j < inputRank; ++j) {
// inputIndices is a 2d tensor with shape of [nnz, inputRank]
id += inputIndices[i * inputRank + j] * inputStrides[j];
}
for (var _j = 0; _j < outputRank; ++_j) {
// newIndices is a 2d tensor with shape of [nnz, outputRank]
newIndices[i * outputRank + _j] = Math.trunc(id / outputStrides[_j]);
id %= outputStrides[_j];
}
}
return [newIndices, [nnz, outputRank], outputShape];
}
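// Usage sketch (illustrative only): reshaping a sparse [2, 2] tensor with
// entries at [0, 0] and [1, 1] into shape [4] maps them to flat positions 0
// and 3.
//   sparseReshapeImpl(new Int32Array([0, 0, 1, 1]), [2, 2], 'int32', [2, 2], [4]);
//   // -> [newIndices [0, 3], newIndicesShape [2, 1], outputShape [4]]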
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean, defaultValue) {
if (isMean === void 0) {
isMean = false;
}
if (defaultValue === void 0) {
defaultValue = 0;
}
var numIndices = indices.length;
if (numIndices !== segmentIds.length) {
throw new Error("segmentIds and indices should have same size.");
} // Flatten the array to two dimensions
var inputFlat = [inputShape[0], input.length / inputShape[0]];
var numCol = inputFlat[1]; // Note that the current implementation assumes that segmentIds values are
// sorted.
var lastSegmentIdPlusOne = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
var outputRows = lastSegmentIdPlusOne;
if (outputRows < 0) {
throw new Error("segment ids must be >= 0");
}
var outputShape = inputShape.slice();
outputShape[0] = outputRows;
var outputLength = outputShape.reduce(function (product, value) {
return product * value;
}, 1); // Output array is initialized with the value 0 by default.
var output = getArrayFromDType(inputDType, outputLength); // Note that the output buffer is zero-filled rather than filled with
// `defaultValue`, so missing segment rows must be set to it explicitly below.
if (numIndices === 0) {
if (outputRows > 0) {
output.fill(defaultValue);
}
return [output, outputShape];
}
if (outputRows <= 0) {
throw new Error("segment ids must be >= 0");
}
var start = 0,
end = 1; // Index from which the output is not initialized.
var uninitializedIndex = 0;
var outIndex = segmentIds[start];
while (true) {
// We initialize nextIndex to 0 to avoid a "may be uninitialized" warning.
var nextIndex = 0;
if (end < numIndices) {
nextIndex = segmentIds[end];
if (outIndex === nextIndex) {
++end;
continue;
} // We have a new segment here. Verify that the segment ids are growing.
if (outIndex >= nextIndex) {
throw new Error("segment ids are not increasing");
}
}
if (outIndex < 0 || outIndex >= outputRows) {
throw new Error("Segment id " + outIndex + " out of range [0, " + outputRows + "), possibly because segmentIds input is not sorted.");
} // If there is a gap between two indices, we need to set that gap to the
// default value.
if (outIndex > uninitializedIndex) {
output.fill(defaultValue, uninitializedIndex * numCol, outIndex * numCol);
}
for (var i = start; i < end; ++i) {
var index = indices[i];
if (index < 0 || index >= inputFlat[0]) {
throw new Error("Bad: indices[" + i + "] == " + indices[i] + " out of range [0, " + inputFlat[0] + ")");
}
for (var j = 0; j < numCol; j++) {
output[outIndex * numCol + j] += input[index * numCol + j];
}
}
if (isMean) {
for (var _j = 0; _j < numCol; _j++) {
output[outIndex * numCol + _j] /= end - start;
}
}
start = end;
++end;
uninitializedIndex = outIndex + 1;
outIndex = nextIndex;
if (end > numIndices) {
break;
}
} // Fill the gap at the end with the default value.
if (uninitializedIndex < outputRows) {
output.fill(defaultValue, uninitializedIndex * numCol, outputRows * numCol);
}
return [output, outputShape];
}
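// Usage sketch (illustrative only): gather rows 0 and 2 of a [3, 2] input and
// sum them into segments 0 and 1 (isMean defaults to false).
//   sparseSegmentReductionImpl(
//       new Float32Array([1, 2, 3, 4, 5, 6]), [3, 2], 'float32',
//       new Int32Array([0, 2]), new Int32Array([0, 1]));
//   // -> [output [1, 2, 5, 6], outputShape [2, 2]]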
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var sqrtImpl = createSimpleUnaryImpl(function (xi) {
return Math.sqrt(xi);
});
var sqrt$4 = unaryKernelFunc(Sqrt, function (xi) {
return Math.sqrt(xi);
});
var sqrtConfig = {
kernelName: Sqrt,
backendName: 'cpu',
kernelFunc: sqrt$4
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var squaredDifferenceImpl = createSimpleBinaryKernelImpl(function (a, b) {
var diff = a - b;
return diff * diff;
});
var squaredDifference$1 = binaryKernelFunc(SquaredDifference, squaredDifferenceImpl);
var squaredDifferenceConfig = {
kernelName: SquaredDifference,
backendName: 'cpu',
kernelFunc: squaredDifference$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function stridedSliceImpl(outShape, xBuf, strides, begin) {
var outBuf = buffer(outShape, xBuf.dtype);
for (var i = 0; i < outBuf.size; i++) {
var loc = outBuf.indexToLoc(i);
var newLoc = new Array(loc.length);
for (var j = 0; j < newLoc.length; j++) {
newLoc[j] = loc[j] * strides[j] + begin[j];
}
outBuf.set.apply(outBuf, [xBuf.get.apply(xBuf, newLoc)].concat(loc));
}
return outBuf;
}
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* The StringNGramsOp class creates ngrams from ragged string data.
* The constructor contains all attributes related to the operation such as
* padding widths and strings, and the compute function can be used to
* compute the ngrams for different ragged tensor inputs.
*/
var StringNGramsOp = /*#__PURE__*/function () {
function StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
this.separator = encodeString(separator);
this.nGramWidths = nGramWidths;
this.leftPad = encodeString(leftPad);
this.rightPad = encodeString(rightPad);
this.padWidth = padWidth;
this.preserveShort = preserveShortSequences;
}
var _proto = StringNGramsOp.prototype;
_proto.getPadWidth = function getPadWidth(nGramWidth) {
// Ngrams can be padded with either a fixed pad width or a dynamic pad
// width depending on the 'padWidth' arg, but in no case should the padding
// ever be wider than 'nGramWidth' - 1.
return Math.min(this.padWidth < 0 ? nGramWidth - 1 : this.padWidth, nGramWidth - 1);
};
_proto.getNumNGrams = function getNumNGrams(length, nGramWidth) {
var padWidth = this.getPadWidth(nGramWidth);
return Math.max(0, length + 2 * padWidth - nGramWidth + 1);
};
_proto.createNGrams = function createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) {
var _this = this;
var _loop = function _loop(nGramIndex) {
var padWidth = _this.getPadWidth(nGramWidth);
var leftPadding = Math.max(0, padWidth - nGramIndex);
var rightPadding = Math.max(0, padWidth - (numNGrams - (nGramIndex + 1)));
var numTokens = nGramWidth - (leftPadding + rightPadding);
var dataStartIndex = splitIndex + (leftPadding > 0 ? 0 : nGramIndex - padWidth); // Calculate the total expected size of the nGram so we can reserve the
// correct amount of space in the string.
var nGramSize = 0; // Size of the left padding.
nGramSize += leftPadding * _this.leftPad.length; // Size of the tokens.
for (var n = 0; n < numTokens; ++n) {
nGramSize += data[dataStartIndex + n].length;
} // Size of the right padding.
nGramSize += rightPadding * _this.rightPad.length; // Size of the separators.
var numSeparators = leftPadding + rightPadding + numTokens - 1;
nGramSize += numSeparators * _this.separator.length; // Build the nGram.
output[outputStartIndex + nGramIndex] = new Uint8Array(nGramSize);
var nGram = output[outputStartIndex + nGramIndex];
var nextNGramIndex = 0;
var appendToNGram = function appendToNGram(str) {
return str.forEach(function (value) {
return nGram[nextNGramIndex++] = value;
});
};
for (var _n = 0; _n < leftPadding; ++_n) {
appendToNGram(_this.leftPad);
appendToNGram(_this.separator);
} // Only output first numTokens - 1 pairs of data and separator
for (var _n2 = 0; _n2 < numTokens - 1; ++_n2) {
appendToNGram(data[dataStartIndex + _n2]);
appendToNGram(_this.separator);
} // Handle case when there are no tokens or no right padding as these
// can result in consecutive separators.
if (numTokens > 0) {
// If we have tokens, then output last and then pair each separator
// with the right padding that follows, to ensure nGram ends either with
// the token or with the right pad.
appendToNGram(data[dataStartIndex + numTokens - 1]);
for (var _n3 = 0; _n3 < rightPadding; ++_n3) {
appendToNGram(_this.separator);
appendToNGram(_this.rightPad);
}
} else {
// If we don't have tokens, then the last item inserted into the nGram
// has been the separator from the left padding loop above. Hence,
// output right pad and separator and make sure to finish with a
// padding, not a separator.
for (var _n4 = 0; _n4 < rightPadding - 1; ++_n4) {
appendToNGram(_this.rightPad);
appendToNGram(_this.separator);
}
appendToNGram(_this.rightPad);
}
};
for (var nGramIndex = 0; nGramIndex < numNGrams; ++nGramIndex) {
_loop(nGramIndex);
}
} // Data and splits together form the definition of the ragged tensor,
// where data is 1 dimensional and contains the values of the tensor
// and splits denotes the indices at which each row starts.
;
_proto.compute = function compute(data, splits) {
var _this2 = this;
// Validate that the splits are valid indices into data, only if there are
// splits specified.
var inputDataSize = data.length;
var splitsSize = splits.length;
if (splitsSize > 0) {
var prevSplit = splits[0];
if (prevSplit !== 0) {
throw new Error("First split value must be 0, got " + prevSplit);
}
for (var i = 1; i < splitsSize; ++i) {
var validSplits = splits[i] >= prevSplit;
validSplits = validSplits && splits[i] <= inputDataSize;
if (!validSplits) {
throw new Error("Invalid split value " + splits[i] + ", must be in [" + prevSplit + ", " + inputDataSize + "]");
}
prevSplit = splits[i];
}
if (prevSplit !== inputDataSize) {
throw new Error("Last split value must be data size. Expected " + inputDataSize + ", got " + prevSplit);
}
}
var numBatchItems = splitsSize - 1;
var nGramsSplits = getArrayFromDType('int32', splitsSize); // If there is no data or size, return an empty ragged tensor.
if (inputDataSize === 0 || splitsSize === 0) {
var empty = new Array(inputDataSize);
for (var _i = 0; _i <= numBatchItems; ++_i) {
nGramsSplits[_i] = 0;
}
return [empty, nGramsSplits];
}
nGramsSplits[0] = 0;
var _loop2 = function _loop2(_i2) {
var length = splits[_i2] - splits[_i2 - 1];
var numNGrams = 0;
_this2.nGramWidths.forEach(function (nGramWidth) {
numNGrams += _this2.getNumNGrams(length, nGramWidth);
});
if (_this2.preserveShort && length > 0 && numNGrams === 0) {
numNGrams = 1;
}
nGramsSplits[_i2] = nGramsSplits[_i2 - 1] + numNGrams;
};
for (var _i2 = 1; _i2 <= numBatchItems; ++_i2) {
_loop2(_i2);
}
var nGrams = new Array(nGramsSplits[numBatchItems]);
var _loop3 = function _loop3(_i3) {
var splitIndex = splits[_i3];
var outputStartIdx = nGramsSplits[_i3];
_this2.nGramWidths.forEach(function (nGramWidth) {
var length = splits[_i3 + 1] - splits[_i3];
var numNGrams = _this2.getNumNGrams(length, nGramWidth);
_this2.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
outputStartIdx += numNGrams;
}); // If we're preserving short sequences, check to see if no sequence was
// generated by comparing the current output start idx to the original
// one (the nGramsSplits entry). If no ngrams were generated, then they will
// be equal (since we increment outputStartIdx by numNGrams every
// time we create a set of ngrams.)
if (_this2.preserveShort && outputStartIdx === nGramsSplits[_i3]) {
var dataLength = splits[_i3 + 1] - splits[_i3]; // One legitimate reason to not have any ngrams when this.preserveShort
// is true is if the sequence itself is empty. In that case, move on.
if (dataLength === 0) {
return "continue";
} // We don't have to worry about dynamic padding sizes here: if padding
// was dynamic, every sequence would have had sufficient padding to
// generate at least one nGram.
var nGramWidth = dataLength + 2 * _this2.padWidth;
var numNGrams = 1;
_this2.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
}
};
for (var _i3 = 0; _i3 < numBatchItems; ++_i3) {
var _ret = _loop3(_i3);
if (_ret === "continue") continue;
}
return [nGrams, nGramsSplits];
};
return StringNGramsOp;
}();
function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
return new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences).compute(data, dataSplits);
}
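// Usage sketch (illustrative only): data/dataSplits describe a ragged string
// tensor whose values are already Uint8Array-encoded (e.g. via encodeString).
//   stringNGramsImpl([encodeString('a'), encodeString('b')], [0, 2],
//                    '_', [2], '', '', 0, false);
//   // -> [[encoded 'a_b'], nGramsSplits [0, 1]]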
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function split$4(str, delimiters, skipEmpty, result) {
if (!str.length) {
return;
} // When the delimiter is empty, the input is split into individual characters.
if (delimiters.length === 0) {
for (var i = 0; i < str.length; ++i) {
result.push(str.subarray(i, i + 1));
}
return;
} // When there is one delimiter, the input is split only at that delimiter.
if (delimiters.length === 1) {
var delimiter = delimiters[0];
var f = str.indexOf(delimiter);
while (f !== -1) {
var token = str.subarray(0, f);
if (!skipEmpty || token.length !== 0) {
result.push(token);
}
str = str.subarray(f + 1);
f = str.indexOf(delimiter);
}
if (!skipEmpty || str.length !== 0) {
result.push(str);
}
return;
} // When there are multiple delimiters, the input is split at every instance
// one of the delimiters appears.
var tokenStart = 0;
for (var _i = 0; _i < str.length + 1; _i++) {
if (_i === str.length || delimiters.indexOf(str[_i]) !== -1) {
var _token = str.subarray(tokenStart, _i);
if (!skipEmpty || _token.length !== 0) {
result.push(_token);
}
tokenStart = _i + 1;
}
}
}
function stringSplitImpl(input, delimiter, skipEmpty) {
var batchSize = input.length; // Empty delimiter means split the input character by character.
var tokens = [];
var outputSize = 0;
var maxNumEntries = 0;
var numIndices = new Array(batchSize);
for (var i = 0; i < batchSize; ++i) {
var prevTokensLength = tokens.length;
split$4(input[i], delimiter, skipEmpty, tokens);
var nEntries = tokens.length - prevTokensLength;
numIndices[i] = nEntries;
outputSize += nEntries;
maxNumEntries = Math.max(maxNumEntries, nEntries);
}
var indices = getArrayFromDType('int32', outputSize * 2);
var values = new Array(outputSize);
var shape = [batchSize, maxNumEntries];
var c = 0;
for (var _i2 = 0; _i2 < batchSize; ++_i2) {
for (var j = 0; j < numIndices[_i2]; ++j) {
// indices is a 2d tensor with shape of [outputSize, 2]
indices[c * 2] = _i2;
indices[c * 2 + 1] = j;
values[c] = tokens[c];
++c;
}
}
return [indices, values, shape];
}
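// Usage sketch (illustrative only), assuming Uint8Array-encoded inputs (e.g.
// via encodeString):
//   stringSplitImpl([encodeString('a b'), encodeString('c')], encodeString(' '), true);
//   // -> [indices [0, 0, 0, 1, 1, 0] (flattened [3, 2]),
//   //     values ['a', 'b', 'c'] (encoded), shape [2, 2]]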
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function stringToHashBucketFastImpl(input, numBuckets) {
var output = getArrayFromDType('int32', input.length);
for (var i = 0; i < input.length; ++i) {
output[i] = fingerPrint64(input[i]).modulo(numBuckets).getLowBitsUnsigned();
}
return output;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var subImpl = createSimpleBinaryKernelImpl(function (aValue, bValue) {
return aValue - bValue;
});
var subComplexImpl = createComplexBinaryKernelImpl(function (aReal, aImag, bReal, bImag) {
return {
real: aReal - bReal,
imag: aImag - bImag
};
});
var sub$1 = binaryKernelFunc(Sub, subImpl, subComplexImpl);
var subConfig = {
kernelName: Sub,
backendName: 'cpu',
kernelFunc: sub$1
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* An implementation of the tile kernel shared between webgl and cpu for string
* tensors only.
*/
function tileImpl(xBuf, reps) {
var newShape = new Array(xBuf.rank);
for (var i = 0; i < newShape.length; i++) {
newShape[i] = xBuf.shape[i] * reps[i];
}
var result = buffer(newShape, xBuf.dtype);
for (var _i = 0; _i < result.values.length; ++_i) {
var newLoc = result.indexToLoc(_i);
var originalLoc = new Array(xBuf.rank);
for (var j = 0; j < originalLoc.length; j++) {
originalLoc[j] = newLoc[j] % xBuf.shape[j];
}
var originalIndex = xBuf.locToIndex(originalLoc);
result.values[_i] = xBuf.values[originalIndex];
}
return result;
}
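// Usage sketch (illustrative only), assuming a string TensorBuffer built with
// the buffer() helper used above:
//   tileImpl(buffer([2], 'string', ['a', 'b']), [2]).values;
//   // -> ['a', 'b', 'a', 'b']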
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var comparePair = function comparePair(a, b) {
var valueDiff = b.value - a.value;
return valueDiff === 0 ? a.index - b.index : valueDiff;
};
/**
* Partitions array where all elements smaller than the (k+1) smallest element
* are found to the left of it, and all larger to the right of it.
* Based on the Floyd-Rivest Algorithm, ref:
* https://en.wikipedia.org/wiki/Floyd%E2%80%93Rivest_algorithm
* @param array: Array to partition
* @param left: Left index for the interval
* @param right: Right index for the interval
* @param k: Desired index value, where array[k] is the (k+1)th smallest element
* when left = 0
*/
function select(array, k, left, right) {
if (left === void 0) {
left = 0;
}
if (right === void 0) {
right = array.length - 1;
}
while (right > left) {
// Use select recursively to sample a smaller set of size s
// the arbitrary constants 600 and 0.5 are used in the original
// version to minimize execution time.
if (right - left > 600) {
var n = right - left + 1;
var _i = k - left + 1;
var z = Math.log(n);
var s = 0.5 * Math.exp(2 * z / 3);
var sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(_i - n / 2);
var newLeft = Math.max(left, Math.floor(k - _i * s / n + sd));
var newRight = Math.min(right, Math.floor(k + (n - _i) * s / n + sd));
select(array, k, newLeft, newRight);
} // partition the elements between left and right around t
var t = array[k];
var i = left;
var j = right;
swap(array, left, k);
if (comparePair(array[right], t) > 0) {
swap(array, left, right);
}
while (i < j) {
swap(array, i, j);
i++;
j--;
while (comparePair(array[i], t) < 0) {
i = i + 1;
}
while (comparePair(array[j], t) > 0) {
j = j - 1;
}
}
if (comparePair(array[left], t) === 0) {
swap(array, left, j);
} else {
j = j + 1;
swap(array, j, right);
} // Adjust left and right towards the boundaries of the subset
// containing the (k - left + 1)th smallest element.
if (j <= k) {
left = j + 1;
}
if (k <= j) {
right = j - 1;
}
}
}
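// Usage sketch (illustrative only): comparePair orders entries by descending
// value, so select(arr, k) moves the k largest values in front of position k.
//   var arr = [{value: 1, index: 0}, {value: 5, index: 1}, {value: 3, index: 2}];
//   select(arr, 1);
//   // arr is now [{value: 5, ...}, {value: 3, ...}, {value: 1, ...}]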
function topKImpl(x, xShape, xDtype, k, sorted) {
// Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
var lastDim = xShape[xShape.length - 1];
var batch = x.length / lastDim,
size = lastDim;
var allTopKVals = getTypedArrayFromDType(xDtype, batch * k);
var allTopKIndices = getTypedArrayFromDType('int32', batch * k);
var _loop = function _loop(b) {
var offset = b * size;
var vals = x.subarray(offset, offset + size);
var valAndInd = new Array(vals.length);
vals.forEach(function (value, index) {
return valAndInd[index] = {
value: value,
index: index
};
});
if (k < valAndInd.length) {
select(valAndInd, k);
valAndInd = valAndInd.slice(0, k);
}
if (sorted) {
valAndInd.sort(comparePair);
}
var outOffset = b * k;
var topKVals = allTopKVals.subarray(outOffset, outOffset + k);
var topKIndices = allTopKIndices.subarray(outOffset, outOffset + k);
for (var i = 0; i < k; i++) {
topKVals[i] = valAndInd[i].value;
topKIndices[i] = valAndInd[i].index;
}
};
for (var b = 0; b < batch; b++) {
_loop(b);
} // Reshape back to the original input shape, except that the last
// dimension is k.
var outputShape = xShape.slice();
outputShape[outputShape.length - 1] = k;
return [buffer(outputShape, xDtype, allTopKVals), buffer(outputShape, 'int32', allTopKIndices)];
}
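// Usage sketch (illustrative only): top-2 values of a flat [4] tensor.
//   topKImpl(new Float32Array([1, 3, 2, 4]), [4], 'float32', 2, true /* sorted */);
//   // -> [values buffer (shape [2]) [4, 3], indices buffer (shape [2]) [3, 1]]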
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function uniqueImpl(values, axis, shape, dtype) {
// Normalize and validate axis.
var $axis = parseAxisParam(axis, shape)[0]; // Calculate the new shape that is suitable for extracting data along the
// given axis.
//
// The rank is 3.
// The size of the 1st dimension is the size of all the axes < the given axis.
// The size of the 2nd dimension is the same as the size of the given axis.
// The size of the 3rd dimension is the size of all the axes > the given axis.
//
// For example, for a 4D tensor with shape=[2, 3, 5, 4] and axis=2, the
// newShape would be: [2*3, 5, 4].
//
// Note that this is not the final output shape. This will be the shape for an
// intermediate TensorBuffer (see inputBuffer below) to allow us to extract
// values along the given axis. To demonstrate how it works, consider the
// following example:
//
// Input: a 3D tensor, with shape [1, 2, 3]
// [
// [
// [1,2,3],
// [4,5,6]
// ]
// ]
// Axis: 2 (the last axis).
// Along axis 2, we expect to extract 3 tensors: [1,4], [2,5], [3,6].
//
// For this example, newShape would be: [2, 3, 1], where 2 is calculated from
// 1*2. The re-shaped data would look like:
//
// [
// [
// [1], [2], [3]
// ],
// [
// [4], [5], [6]
// ]
// ]
//
// Then, we can construct a 3-level nested loop by the following dimension
// order to extract the values along the axis (dimension1):
// i: dimension1 // 0,1,2 (newShape[1])
// m: dimension0 // 0,1 (newShape[0])
// n: dimension2 // 0 (newShape[2])
//
// m, i, n
// ---------
// Iteration 0: data at [0, 0, 0] => "1"
// Iteration 1: data at [1, 0, 0] => "4"
// We got [1,4].
// Iteration 2: data at [0, 1, 0] => "2"
// Iteration 3: data at [1, 1, 0] => "5"
// We got [2,5].
// Iteration 4: data at [0, 2, 0] => "3"
// Iteration 5: data at [1, 2, 0] => "6"
// We got [3,6].
var newShape = [1, shape[0], 1];
for (var i = 0; i < $axis; i++) {
newShape[0] *= shape[i];
}
newShape[1] = shape[$axis];
for (var _i = $axis + 1; _i < shape.length; _i++) {
newShape[2] *= shape[_i];
} // A map from unique elements (their string representations) to their values
// in "indices" (below).
var uniqueElements = {}; // The indices of each unique element in the original tensor along the given
// axis. It is 1D and has the same size as the given axis.
var indices = new Int32Array(shape[$axis]); // Create a buffer so we can easily extract value at a given location.
var inputBuffer = new TensorBuffer(newShape, dtype, values); // The indices along the given axis that have unique elements. This is a
// de-duped version of "indices" above.
var uniqueIndices = [];
var is1DTensor = newShape[0] === 1 && newShape[2] === 1;
for (var _i2 = 0; _i2 < shape[$axis]; _i2++) {
// Extract values along the axis.
var element = void 0;
if (is1DTensor) {
// Fast path for 1D tensor input.
element = values[_i2].toString();
} else {
var axisValues = [];
for (var m = 0; m < newShape[0]; m++) {
for (var n = 0; n < newShape[2]; n++) {
axisValues.push(inputBuffer.get(m, _i2, n));
}
}
element = axisValues.join(',');
} // Dedup and update various indices.
if (uniqueElements[element] !== undefined) {
indices[_i2] = uniqueElements[element];
} else {
var uniqueIndex = Object.keys(uniqueElements).length;
uniqueElements[element] = uniqueIndex;
indices[_i2] = uniqueIndex;
uniqueIndices.push(_i2);
}
} // Now we know where each of the unique elements are located along the axis
// (uniqueIndices). Extract them from input buffer and store them in the
// output buffer.
var outputTmpShape = newShape.slice();
outputTmpShape[1] = Object.keys(uniqueElements).length;
var outputBuffer = new TensorBuffer(outputTmpShape, dtype);
uniqueIndices.forEach(function (uniqueElementIndex, i) {
for (var _m = 0; _m < newShape[0]; _m++) {
for (var _n = 0; _n < newShape[2]; _n++) {
outputBuffer.set(inputBuffer.get(_m, uniqueElementIndex, _n), _m, i, _n);
}
}
}); // The output shape can be calculated from the input shape with the size of
// the given axis replaced by the number of unique elements along that axis.
var outputShape = shape.slice();
outputShape[$axis] = outputTmpShape[1];
return {
outputValues: outputBuffer.values,
outputShape: outputShape,
indices: indices
};
}
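// Usage sketch (illustrative only): unique elements of a flat [9] int32 tensor
// along axis 0.
//   uniqueImpl(new Int32Array([1, 1, 2, 4, 4, 4, 7, 8, 8]), 0, [9], 'int32');
//   // -> { outputValues: [1, 2, 4, 7, 8], outputShape: [5],
//   //      indices: [0, 0, 1, 2, 2, 2, 3, 4, 4] }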
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
var version$5 = '3.9.0';
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
registerBackend('cpu', function () {
return new MathBackendCPU();
}, 1
/* priority */
);
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var elu$3 = unaryKernelFunc(Elu, function (xi) {
return xi >= 0 ? xi : Math.exp(xi) - 1;
});
var eluConfig = {
kernelName: Elu,
backendName: 'cpu',
kernelFunc: elu$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function leakyRelu$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var alpha = attrs.alpha;
assertNotComplex([x], 'leakyRelu');
var xSize = sizeFromShape(x.shape);
var xVals = backend.data.get(x.dataId).values;
var outVals = getTypedArrayFromDType('float32', xSize);
for (var i = 0; i < xVals.length; i++) {
outVals[i] = xVals[i] < 0 ? alpha * xVals[i] : xVals[i];
}
return backend.makeTensorInfo(x.shape, 'float32', outVals);
}
var leakyReluConfig = {
kernelName: LeakyRelu,
backendName: 'cpu',
kernelFunc: leakyRelu$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var preluImpl = createSimpleBinaryKernelImpl(function (xValue, aValue) {
return xValue < 0 ? aValue * xValue : xValue;
});
function prelu$2(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x,
alpha = inputs.alpha;
assertNotComplex([x, alpha], 'prelu');
var aVals = backend.data.get(x.dataId).values;
var bVals = backend.data.get(alpha.dataId).values;
var _preluImpl = preluImpl(x.shape, alpha.shape, aVals, bVals, x.dtype),
resultData = _preluImpl[0],
resultShape = _preluImpl[1];
return backend.makeTensorInfo(resultShape, x.dtype, resultData);
}
var preluConfig = {
kernelName: Prelu,
backendName: 'cpu',
kernelFunc: prelu$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var relu$1 = unaryKernelFunc(Relu, function (xi) {
return Math.max(0, xi);
});
var reluConfig = {
kernelName: Relu,
backendName: 'cpu',
kernelFunc: relu$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var relu6$1 = unaryKernelFunc(Relu6, function (xi) {
return Math.min(Math.max(0, xi), 6);
});
var relu6Config = {
kernelName: Relu6,
backendName: 'cpu',
kernelFunc: relu6$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function applyActivation$1(backend, x, activation, preluActivationWeights, leakyreluAlpha) {
if (activation === 'linear') {
return identity$1({
inputs: {
x: x
},
backend: backend
});
} else if (activation === 'relu') {
return relu$1({
inputs: {
x: x
},
backend: backend
});
} else if (activation === 'elu') {
return elu$3({
inputs: {
x: x
},
backend: backend
});
} else if (activation === 'relu6') {
return relu6$1({
inputs: {
x: x
},
backend: backend
});
} else if (activation === 'prelu') {
return prelu$2({
inputs: {
x: x,
alpha: preluActivationWeights
},
backend: backend
});
} else if (activation === 'leakyrelu') {
return leakyRelu$1({
inputs: {
x: x
},
backend: backend,
attrs: {
alpha: leakyreluAlpha
}
});
} else if (activation === 'sigmoid') {
return sigmoid$1({
inputs: {
x: x
},
backend: backend
});
}
throw new Error("Activation " + activation + " has not been implemented for the CPU backend.");
}
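// Illustrative sketch (hypothetical helper, never invoked by this bundle):
// applyActivation$1 is the dispatcher used by the fused CPU kernels; from the
// public API the activation name arrives via tf.fused.matMul or tf.fused.conv2d.
function fusedActivationSketch(tf) {
  var a = tf.tensor2d([[1, -2], [3, 4]]);
  var b = tf.tensor2d([[1, 0], [0, 1]]);
  var out = tf.fused.matMul({a: a, b: b, activation: 'relu'}); // negatives clamped to 0
  out.print();
  tf.dispose([a, b, out]);
}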
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function reshape$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var shape = attrs.shape;
var xSize = sizeFromShape(x.shape);
var $shape = inferFromImplicitShape(shape, xSize);
var $xSize = sizeFromShape($shape);
assert(xSize === $xSize, function () {
return "The new shape (" + $shape + ") has " + $xSize + " elements and the old " + ("shape (" + x.shape + ") has " + xSize + " elements. The new shape and old ") + "shape must have the same number of elements.";
});
backend.incRef(x.dataId);
var xData = backend.data.get(x.dataId);
if (xData.complexTensorInfos != null) {
var real = xData.complexTensorInfos.real;
var imag = xData.complexTensorInfos.imag;
real.shape = $shape;
imag.shape = $shape;
}
return {
dataId: x.dataId,
shape: $shape,
dtype: x.dtype
};
}
var reshapeConfig = {
kernelName: Reshape,
backendName: 'cpu',
kernelFunc: reshape$2
};
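// Illustrative sketch (hypothetical helper, never invoked by this bundle):
// reshape on the CPU backend is metadata-only -- it reuses the existing dataId
// and only swaps the shape, which is why the element counts must match exactly.
function reshapeUsageSketch(tf) {
  var x = tf.tensor1d([0, 1, 2, 3, 4, 5]);
  var y = tf.reshape(x, [2, 3]); // same six values, viewed as a 2x3 matrix
  y.print();
  x.dispose();
  y.dispose();
}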
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function batchMatMul(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var a = inputs.a,
b = inputs.b;
var transposeA = attrs.transposeA,
transposeB = attrs.transposeB;
assertNotComplex([a, b], 'matMul');
var aRank = a.shape.length;
var bRank = b.shape.length;
var innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
var innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
var outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
var outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
var outerDimsA = a.shape.slice(0, -2);
var outerDimsB = b.shape.slice(0, -2);
var batchDimA = sizeFromShape(outerDimsA);
var batchDimB = sizeFromShape(outerDimsB);
var batchDimsCompatible = batchDimA === batchDimB || batchDimA === 1 || batchDimB === 1;
assert(aRank >= 2 && bRank >= 2 && batchDimsCompatible, function () {
return "Error in matMul: the input batch dimensions must either be the " + "same or at least one input batch dimension must be 1. Got input " + ("batch dimensions of (" + outerDimsA + ") and (" + outerDimsB + ").");
});
var outShapeOuterDims = batchDimA > batchDimB ? a.shape.slice(0, -2) : b.shape.slice(0, -2);
var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
assert(innerShapeA === innerShapeB, function () {
return "Error in matMul: inner shapes (" + innerShapeA + ") and (" + (innerShapeB + ") of Tensors with shapes " + a.shape + " and ") + (b.shape + " and transposeA=" + transposeA) + (" and transposeB=" + transposeB + " must match.");
});
var a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] : [batchDimA, outerShapeA, innerShapeA];
var b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] : [batchDimB, innerShapeB, outerShapeB]; // The rest of the implementation is designed to operate on rank-3 tensors
var a3d = reshape$2({
inputs: {
x: a
},
backend: backend,
attrs: {
shape: a3dShape
}
});
var b3d = reshape$2({
inputs: {
x: b
},
backend: backend,
attrs: {
shape: b3dShape
}
});
var sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
var leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
var rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
var batchDim = Math.max(batchDimA, batchDimB);
var a3dValues = backend.data.get(a3d.dataId).values;
var b3dValues = backend.data.get(b3d.dataId).values;
var a3dStrides = computeStrides(a3d.shape);
var b3dStrides = computeStrides(b3d.shape);
var _ref = transposeA ? [a3dStrides[0], 1, a3dStrides[1]] : [a3dStrides[0], a3dStrides[1], 1],
aBatch = _ref[0],
aOuterStep = _ref[1],
aInnerStep = _ref[2];
var _ref2 = transposeB ? [1, b3dStrides[1], b3dStrides[0]] : [b3dStrides[1], 1, b3dStrides[0]],
bInnerStep = _ref2[0],
bOuterStep = _ref2[1],
bBatch = _ref2[2];
var size = leftDim * rightDim;
var result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
var resVals = result.values;
var blockSize = backend.blockSize;
for (var bi = 0; bi < batchDim; bi++) {
for (var i0 = 0; i0 < leftDim; i0 += blockSize) {
for (var j0 = 0; j0 < rightDim; j0 += blockSize) {
for (var k0 = 0; k0 < sharedDim; k0 += blockSize) {
// for when blockSize doesn't evenly divide the input
var iBlock = Math.min(i0 + blockSize, leftDim);
var jBlock = Math.min(j0 + blockSize, rightDim);
var kBlock = Math.min(k0 + blockSize, sharedDim);
for (var i = i0; i < iBlock; i++) {
for (var j = j0; j < jBlock; j++) {
var sum = 0.0;
for (var k = k0; k < kBlock; k++) {
var batchOffsetA = Math.min(bi, batchDimA - 1) * aBatch;
var batchOffsetB = Math.min(bi, batchDimB - 1) * bBatch;
var aVal = a3dValues[batchOffsetA + i * aOuterStep + k * aInnerStep];
var bVal = b3dValues[k * bInnerStep + j * bOuterStep + batchOffsetB];
sum += aVal * bVal;
}
resVals[bi * size + (i * rightDim + j)] += sum;
}
}
}
}
}
}
backend.disposeIntermediateTensorInfo(a3d);
backend.disposeIntermediateTensorInfo(b3d); // set correct shape on output.
return backend.makeTensorInfo(outShape, result.dtype, result.values);
}
var batchMatMulConfig = {
kernelName: BatchMatMul,
backendName: 'cpu',
kernelFunc: batchMatMul
};
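// Illustrative sketch (hypothetical helper, never invoked by this bundle): the
// blocked loops above back tf.matMul; batch dimensions broadcast as long as one
// of them is 1.
function matMulUsageSketch(tf) {
  var a = tf.tensor2d([[1, 2], [3, 4]]);
  var b = tf.tensor2d([[5, 6], [7, 8]]);
  var c = tf.matMul(a, b); // [[19, 22], [43, 50]]
  c.print();
  tf.dispose([a, b, c]);
}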
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function _fusedMatMul(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var a = inputs.a,
b = inputs.b,
bias = inputs.bias,
preluActivationWeights = inputs.preluActivationWeights;
var transposeA = attrs.transposeA,
transposeB = attrs.transposeB,
activation = attrs.activation,
leakyreluAlpha = attrs.leakyreluAlpha;
var current;
var addRes;
var activationRes;
var intermediates = [];
var matMulRes = batchMatMul({
inputs: {
a: a,
b: b
},
attrs: {
transposeA: transposeA,
transposeB: transposeB
},
backend: backend
});
current = matMulRes;
if (bias) {
addRes = add$4({
inputs: {
a: current,
b: bias
},
backend: backend
});
intermediates.push(current);
current = addRes;
}
if (activation) {
activationRes = applyActivation$1(backend, current, activation, preluActivationWeights, leakyreluAlpha);
intermediates.push(current);
current = activationRes;
}
for (var _i = 0, _intermediates = intermediates; _i < _intermediates.length; _i++) {
var i = _intermediates[_i];
backend.disposeIntermediateTensorInfo(i);
}
return current;
}
var _fusedMatMulConfig = {
kernelName: _FusedMatMul,
backendName: 'cpu',
kernelFunc: _fusedMatMul
};
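// Illustrative sketch (hypothetical helper, never invoked by this bundle):
// _fusedMatMul folds the matrix multiply, the optional bias add and the optional
// activation into a single kernel; user code requests the fusion through
// tf.fused.matMul.
function fusedMatMulBiasSketch(tf) {
  var a = tf.tensor2d([[1, 2], [3, 4]]);
  var b = tf.tensor2d([[1, 0], [0, 1]]);
  var bias = tf.tensor1d([10, 20]);
  var out = tf.fused.matMul({a: a, b: b, bias: bias, activation: 'linear'});
  out.print(); // [[11, 22], [13, 24]]
  tf.dispose([a, b, bias, out]);
}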
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var acos$1 = unaryKernelFunc(Acos, function (xi) {
return Math.acos(xi);
});
var acosConfig = {
kernelName: Acos,
backendName: 'cpu',
kernelFunc: acos$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var acosh$1 = unaryKernelFunc(Acosh, function (xi) {
return Math.acosh(xi);
});
var acoshConfig = {
kernelName: Acosh,
backendName: 'cpu',
kernelFunc: acosh$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function addN$1(args) {
var inputs = args.inputs,
backend = args.backend;
var tensors = inputs;
assertNotComplex(inputs, 'addN');
var vals = tensors.map(function (t) {
return backend.data.get(t.dataId).values;
});
var outBuf = buffer(tensors[0].shape, tensors[0].dtype);
var outVals = outBuf.values;
for (var i = 0; i < tensors.length; i++) {
var currVals = vals[i];
for (var j = 0; j < outVals.length; j++) {
outVals[j] += currVals[j];
}
}
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
var addNConfig = {
kernelName: AddN,
backendName: 'cpu',
kernelFunc: addN$1
};
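// Illustrative sketch (hypothetical helper, never invoked by this bundle): addN
// sums a list of tensors of identical shape and dtype element-wise in one pass.
function addNUsageSketch(tf) {
  var a = tf.tensor1d([1, 2]);
  var b = tf.tensor1d([3, 4]);
  var c = tf.tensor1d([5, 6]);
  var sum = tf.addN([a, b, c]); // [9, 12]
  sum.print();
  tf.dispose([a, b, c, sum]);
}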
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function all$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
assertNotComplex(x, 'all');
var origAxes = parseAxisParam(axis, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, x.shape.length);
var $x = x;
if (permutedAxes != null) {
$x = transpose$1({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
axes = getInnerMostAxes(axes.length, x.shape.length);
}
assertAxesAreInnerMostDims('all', axes, $x.shape.length);
var _backend_util$compute = computeOutAndReduceShapes($x.shape, axes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var reduceSize = sizeFromShape(reduceShape);
var vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
var aVals = backend.data.get($x.dataId).values;
for (var i = 0; i < vals.length; ++i) {
var offset = i * reduceSize;
var _all = aVals[offset];
for (var j = 0; j < reduceSize; ++j) {
var value = aVals[offset + j];
_all = _all && value;
}
vals[i] = _all;
}
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo($x);
}
var result = backend.makeTensorInfo(outShape, $x.dtype, vals);
if (keepDims) {
var expandedShape = expandShapeToKeepDim(outShape, origAxes);
var reshapedResult = reshape$2({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: expandedShape
}
});
backend.disposeIntermediateTensorInfo(result);
return reshapedResult;
}
return result;
}
var allConfig = {
kernelName: All,
backendName: 'cpu',
kernelFunc: all$1
};
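// Illustrative sketch (hypothetical helper, never invoked by this bundle): tf.all
// reduces a boolean tensor with logical AND along the requested axes, optionally
// keeping the reduced dimensions; tf.any below is its logical-OR counterpart.
function allUsageSketch(tf) {
  var x = tf.tensor2d([[1, 1], [1, 0]], [2, 2], 'bool');
  var rows = tf.all(x, 1); // [true, false]
  rows.print();
  x.dispose();
  rows.dispose();
}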
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function any$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
assertNotComplex(x, 'any');
var origAxes = parseAxisParam(axis, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, x.shape.length);
var $x = x;
if (permutedAxes != null) {
$x = transpose$1({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
axes = getInnerMostAxes(axes.length, x.shape.length);
}
assertAxesAreInnerMostDims('any', axes, $x.shape.length);
var _backend_util$compute = computeOutAndReduceShapes($x.shape, axes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var reduceSize = sizeFromShape(reduceShape);
var vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
var aVals = backend.data.get($x.dataId).values;
for (var i = 0; i < vals.length; ++i) {
var offset = i * reduceSize;
var anyVal = aVals[offset];
for (var j = 0; j < reduceSize; ++j) {
var value = aVals[offset + j];
anyVal = anyVal || value;
}
vals[i] = anyVal;
}
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo($x);
}
var result = backend.makeTensorInfo(outShape, $x.dtype, vals);
if (keepDims) {
var expandedShape = expandShapeToKeepDim(outShape, origAxes);
var reshapedResult = reshape$2({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: expandedShape
}
});
backend.disposeIntermediateTensorInfo(result);
return reshapedResult;
}
return result;
}
var anyConfig = {
kernelName: Any,
backendName: 'cpu',
kernelFunc: any$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function argMax$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis;
assertNotComplex(x, 'argMax');
var axes = parseAxisParam(axis, x.shape);
var permutedAxes = getAxesPermutation(axes, x.shape.length);
var $x = x;
var intermediateTensorInfos = [];
if (permutedAxes != null) {
$x = transpose$1({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
intermediateTensorInfos.push($x);
axes = getInnerMostAxes(axes.length, $x.shape.length);
}
axes = [axes[0]];
assertAxesAreInnerMostDims('argMax', axes, $x.shape.length);
var _backend_util$compute = computeOutAndReduceShapes($x.shape, axes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var outSize = sizeFromShape(outShape);
var vals = makeZerosTypedArray(outSize, 'int32');
var reduceSize = sizeFromShape(reduceShape);
var aVals = backend.data.get($x.dataId).values;
for (var i = 0; i < vals.length; ++i) {
var offset = i * reduceSize;
var max = aVals[offset];
var maxIndex = 0;
for (var j = 0; j < reduceSize; ++j) {
var value = aVals[offset + j];
if (value > max) {
max = value;
maxIndex = j;
}
}
vals[i] = maxIndex;
}
intermediateTensorInfos.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return backend.makeTensorInfo(outShape, 'int32', vals);
}
var argMaxConfig = {
kernelName: ArgMax,
backendName: 'cpu',
kernelFunc: argMax$1
};
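// Illustrative sketch (hypothetical helper, never invoked by this bundle):
// argMax returns the int32 index of the largest value along one axis; ties keep
// the first occurrence, matching the strict '>' comparison above.
function argMaxUsageSketch(tf) {
  var x = tf.tensor2d([[1, 3, 2], [7, 4, 9]]);
  var idx = tf.argMax(x, 1); // [1, 2]
  idx.print();
  x.dispose();
  idx.dispose();
}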
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function argMin$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis;
assertNotComplex(x, 'argMin');
var axes = parseAxisParam(axis, x.shape);
var permutedAxes = getAxesPermutation(axes, x.shape.length);
var $x = x;
var intermediateTensorInfos = [];
if (permutedAxes != null) {
$x = transpose$1({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
intermediateTensorInfos.push($x);
axes = getInnerMostAxes(axes.length, $x.shape.length);
}
axes = [axes[0]];
assertAxesAreInnerMostDims('argMin', axes, $x.shape.length);
var _backend_util$compute = computeOutAndReduceShapes($x.shape, axes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var outSize = sizeFromShape(outShape);
var vals = makeZerosTypedArray(outSize, 'int32');
var reduceSize = sizeFromShape(reduceShape);
var aVals = backend.data.get($x.dataId).values;
for (var i = 0; i < vals.length; ++i) {
var offset = i * reduceSize;
var min = aVals[offset];
var minIndex = 0;
for (var j = 0; j < reduceSize; ++j) {
var value = aVals[offset + j];
if (value < min) {
min = value;
minIndex = j;
}
}
vals[i] = minIndex;
}
intermediateTensorInfos.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return backend.makeTensorInfo(outShape, 'int32', vals);
}
var argMinConfig = {
kernelName: ArgMin,
backendName: 'cpu',
kernelFunc: argMin$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var asin$1 = unaryKernelFunc(Asin, function (xi) {
return Math.asin(xi);
});
var asinConfig = {
kernelName: Asin,
backendName: 'cpu',
kernelFunc: asin$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var asinh$2 = unaryKernelFunc(Asinh, function (xi) {
return Math.asinh(xi);
});
var asinhConfig = {
kernelName: Asinh,
backendName: 'cpu',
kernelFunc: asinh$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var atan$1 = unaryKernelFunc(Atan, function (xi) {
return Math.atan(xi);
});
var atanConfig = {
kernelName: Atan,
backendName: 'cpu',
kernelFunc: atan$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var atan2Impl = createSimpleBinaryKernelImpl(function (aValue, bValue) {
return Math.atan2(aValue, bValue);
});
var atan2$1 = binaryKernelFunc(Atan2, atan2Impl);
var atan2Config = {
kernelName: Atan2,
backendName: 'cpu',
kernelFunc: atan2$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var atanh$1 = unaryKernelFunc(Atanh, function (xi) {
return Math.atanh(xi);
});
var atanhConfig = {
kernelName: Atanh,
backendName: 'cpu',
kernelFunc: atanh$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function pool$1(xValues, xShape, dtype, strides, convInfo, poolType) {
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var initialValue = poolType === 'max' ? Number.NEGATIVE_INFINITY : Number.POSITIVE_INFINITY;
var output = buffer(convInfo.outShape, dtype);
var outputVals = output.values;
var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
var outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
var outputColStrides = convInfo.outShape[3];
for (var b = 0; b < convInfo.batchSize; ++b) {
var outputBatchOffset = b * outputBatchStrides;
var inputBatchOffset = b * strides[0];
for (var d = 0; d < convInfo.inChannels; ++d) {
for (var yR = 0; yR < convInfo.outHeight; ++yR) {
var xRCorner = yR * strideHeight - padTop;
var xRMin = Math.max(0, xRCorner);
var xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
var outputRowOffset = outputBatchOffset + yR * outputRowStrides;
for (var yC = 0; yC < convInfo.outWidth; ++yC) {
var xCCorner = yC * strideWidth - padLeft;
var xCMin = Math.max(0, xCCorner);
var xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
var minMaxValue = initialValue;
var avgValue = 0;
var count = 0;
for (var xR = xRMin; xR < xRMax; xR += dilationHeight) {
var xROffset = inputBatchOffset + xR * strides[1];
for (var xC = xCMin; xC < xCMax; xC += dilationWidth) {
var xCOffset = xROffset + xC * strides[2];
var pixel = xValues[xCOffset + d];
if (poolType === 'max' && pixel > minMaxValue) {
minMaxValue = pixel;
} else if (poolType === 'avg') {
avgValue += pixel;
count++;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
var outputOffset = outputRowOffset + yC * outputColStrides + d;
outputVals[outputOffset] = poolType === 'avg' ? avgValue / count : minMaxValue;
}
}
}
}
return output;
}
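// Illustrative sketch (hypothetical helper, never invoked by this bundle):
// pool$1 above implements both 'max' and 'avg' 2D pooling over NHWC data; user
// code reaches it through tf.maxPool and tf.avgPool.
function maxPoolUsageSketch(tf) {
  var x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); // one batch, 2x2, one channel
  var y = tf.maxPool(x, 2, 1, 'valid'); // a single 2x2 window -> [[[[4]]]]
  y.print();
  x.dispose();
  y.dispose();
}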
function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions, includeBatchInIndex) {
if (flattenPositions === void 0) {
flattenPositions = false;
}
if (includeBatchInIndex === void 0) {
includeBatchInIndex = false;
}
var maxPositions = buffer(convInfo.outShape, 'int32');
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var xBuf = buffer(xShape, dtype, xValues);
for (var b = 0; b < convInfo.batchSize; ++b) {
for (var d = 0; d < convInfo.inChannels; ++d) {
for (var yR = 0; yR < convInfo.outHeight; ++yR) {
var xRCorner = yR * strideHeight - padTop;
var xRMin = xRCorner;
while (xRMin < 0) {
xRMin += dilationHeight;
} // const xRMin = Math.max(0, xRCorner);
var xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
for (var yC = 0; yC < convInfo.outWidth; ++yC) {
var xCCorner = yC * strideWidth - padLeft;
var xCMin = xCCorner;
while (xCMin < 0) {
xCMin += dilationWidth;
}
var xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
var maxValue = Number.NEGATIVE_INFINITY;
var maxPosition = -1;
for (var xR = xRMin; xR < xRMax; xR += dilationHeight) {
var wR = xR - xRCorner;
for (var xC = xCMin; xC < xCMax; xC += dilationWidth) {
var wC = xC - xCCorner;
var pixel = xBuf.get(b, xR, xC, d);
if (pixel > maxValue) {
maxValue = pixel;
if (flattenPositions) {
maxPosition = includeBatchInIndex ? ((b * convInfo.inHeight + xR) * convInfo.inWidth + xC) * convInfo.inChannels + d : (xR * convInfo.inWidth + xC) * convInfo.inChannels + d;
} else {
maxPosition = wR * effectiveFilterWidth + wC;
}
}
}
}
maxPositions.set(maxPosition, b, yR, yC, d);
}
}
}
}
return maxPositions;
}
function pool3d$1(xValues, xShape, dtype, strides, convInfo, poolType) {
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padFront = convInfo.padInfo.front;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var initialValue = poolType === 'max' ? Number.NEGATIVE_INFINITY : Number.POSITIVE_INFINITY;
var output = buffer(convInfo.outShape, dtype);
var outputVals = output.values;
var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
var outputDepthStrides = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
var outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4];
var outputColStrides = convInfo.outShape[4];
for (var batch = 0; batch < convInfo.batchSize; ++batch) {
var outputBatchOffset = batch * outputBatchStrides;
var inputBatchOffset = batch * strides[0];
for (var channel = 0; channel < convInfo.inChannels; ++channel) {
for (var yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
var xDepthCorner = yDepth * strideDepth - padFront;
var xDepthMin = xDepthCorner;
while (xDepthMin < 0) {
xDepthMin += dilationDepth;
}
var xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
var outputDepthOffset = outputBatchOffset + yDepth * outputDepthStrides;
for (var yRow = 0; yRow < convInfo.outHeight; ++yRow) {
var xRowCorner = yRow * strideHeight - padTop;
var xRowMin = xRowCorner;
while (xRowMin < 0) {
xRowMin += dilationHeight;
}
var xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
var outputRowOffset = outputDepthOffset + yRow * outputRowStrides;
for (var yCol = 0; yCol < convInfo.outWidth; ++yCol) {
var xColCorner = yCol * strideWidth - padLeft;
var xColMin = xColCorner;
while (xColMin < 0) {
xColMin += dilationWidth;
}
var xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner); // Shader code begins
var outputColOffset = outputRowOffset + yCol * outputColStrides;
var minMaxValue = initialValue;
var avgValue = 0;
var count = 0;
for (var xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
var xDepthOffset = inputBatchOffset + xDepth * strides[1];
for (var xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
var xRowOffset = xDepthOffset + xRow * strides[2];
for (var xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
var xColOffset = xRowOffset + xCol * strides[3];
var pixel = xValues[xColOffset + channel];
if (poolType === 'max' && pixel > minMaxValue) {
minMaxValue = pixel;
} else if (poolType === 'avg') {
avgValue += pixel;
count++;
}
if (isNaN(minMaxValue)) {
break;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
var outputOffset = outputColOffset + channel;
outputVals[outputOffset] = poolType === 'avg' ? avgValue / count : minMaxValue;
}
}
}
}
}
return output;
}
function maxPool3dPositions(xBuf, convInfo) {
var maxPositions = buffer(convInfo.outShape, 'int32');
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padFront = convInfo.padInfo.front;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
for (var batch = 0; batch < convInfo.batchSize; ++batch) {
for (var channel = 0; channel < convInfo.inChannels; ++channel) {
for (var yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
var xDepthCorner = yDepth * strideDepth - padFront;
var xDepthMin = xDepthCorner;
while (xDepthMin < 0) {
xDepthMin += dilationDepth;
}
var xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
for (var yRow = 0; yRow < convInfo.outHeight; ++yRow) {
var xRowCorner = yRow * strideHeight - padTop;
var xRowMin = xRowCorner;
while (xRowMin < 0) {
xRowMin += dilationHeight;
}
var xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
for (var yCol = 0; yCol < convInfo.outWidth; ++yCol) {
var xColCorner = yCol * strideWidth - padLeft;
var xColMin = xColCorner;
while (xColMin < 0) {
xColMin += dilationWidth;
}
var xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner); // Shader code begins
var maxValue = Number.NEGATIVE_INFINITY;
var maxPosition = -1;
for (var xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
var wDepth = xDepth - xDepthCorner;
for (var xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
var wRow = xRow - xRowCorner;
for (var xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
var wCol = xCol - xColCorner;
var pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
if (pixel >= maxValue) {
maxValue = pixel;
maxPosition = wDepth * effectiveFilterHeight * effectiveFilterWidth + wRow * effectiveFilterHeight + wCol;
}
}
}
}
maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
}
}
}
}
}
return maxPositions;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function avgPool$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
assertNotComplex(x, 'avgPool');
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
var dilations = 1;
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in avgPool: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
var res;
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape)) {
res = identity$1({
inputs: {
x: x
},
backend: backend
});
} else {
var xValues = backend.data.get(x.dataId).values;
var _strides = computeStrides(x.shape);
var buffer = pool$1(xValues, x.shape, x.dtype, _strides, convInfo, 'avg');
res = backend.makeTensorInfo(convInfo.outShape, x.dtype, buffer.values);
}
return res;
}
var avgPoolConfig = {
kernelName: AvgPool,
backendName: 'cpu',
kernelFunc: avgPool$1
};
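// Illustrative sketch (hypothetical helper, never invoked by this bundle):
// avgPool with a 1x1 filter and matching input/output shapes short-circuits to
// identity above; otherwise each window is averaged.
function avgPoolUsageSketch(tf) {
  var x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
  var y = tf.avgPool(x, 2, 1, 'valid'); // (1 + 2 + 3 + 4) / 4 = 2.5
  y.print();
  x.dispose();
  y.dispose();
}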
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function avgPool3D(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode,
dataFormat = attrs.dataFormat;
assertNotComplex(x, 'avgPool3d');
var convInfo = computePool3DInfo(x.shape, filterSize, strides, 1
/* dilations */
, pad, dimRoundingMode, dataFormat);
var xValues = backend.data.get(x.dataId).values;
var outBuf = pool3d$1(xValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'avg');
return backend.makeTensorInfo(outBuf.shape, 'float32', outBuf.values);
}
var avgPool3DConfig = {
kernelName: AvgPool3D,
backendName: 'cpu',
kernelFunc: avgPool3D
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function avgPool3DGrad(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
input = inputs.input;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
assertNotComplex([dy, input], 'avgPool3DGrad');
var convInfo = computePool3DInfo(input.shape, filterSize, strides, 1
/* dilations */
, pad, dimRoundingMode);
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var filterDepth = convInfo.filterDepth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var dx = buffer(input.shape, 'float32');
var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
var dyBuf = backend.bufferSync(dy);
for (var batch = 0; batch < convInfo.batchSize; ++batch) {
for (var channel = 0; channel < convInfo.inChannels; ++channel) {
for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
// Shader code begins.
var dyDepthCorner = dxDepth - padFront;
var dyRowCorner = dxRow - padTop;
var dyColCorner = dxCol - padLeft;
var dotProd = 0;
for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
var dyDepth = (dyDepthCorner + wDepth) / strideDepth;
if (dyDepth < 0 || dyDepth >= convInfo.outDepth || Math.floor(dyDepth) !== dyDepth) {
continue;
}
for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
var dyRow = (dyRowCorner + wRow) / strideHeight;
if (dyRow < 0 || dyRow >= convInfo.outHeight || Math.floor(dyRow) !== dyRow) {
continue;
}
for (var wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
var dyCol = (dyColCorner + wCol) / strideWidth;
if (dyCol < 0 || dyCol >= convInfo.outWidth || Math.floor(dyCol) !== dyCol) {
continue;
}
var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
dotProd += pixel;
}
}
}
dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel);
}
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var avgPool3DGradConfig$1 = {
kernelName: AvgPool3DGrad,
backendName: 'cpu',
kernelFunc: avgPool3DGrad
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function avgPoolGrad$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
input = inputs.input;
var x = input;
assertNotComplex([dy, input], 'avgPoolGrad');
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad;
var convInfo = computePool2DInfo(x.shape, filterSize, strides, 1
/* dilations */
, pad);
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var dx = buffer(x.shape, 'float32');
var avgMultiplier = 1 / (filterHeight * filterWidth);
var dyData = backend.data.get(dy.dataId).values;
var dyBuf = buffer(dy.shape, 'float32', dyData);
for (var b = 0; b < convInfo.batchSize; ++b) {
for (var d = 0; d < convInfo.inChannels; ++d) {
for (var dxR = 0; dxR < convInfo.inHeight; ++dxR) {
for (var dxC = 0; dxC < convInfo.inWidth; ++dxC) {
// Shader code begins.
var dyRCorner = dxR - padTop;
var dyCCorner = dxC - padLeft;
var dotProd = 0;
for (var wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
var dyR = (dyRCorner + wR) / strideHeight;
if (dyR < 0 || dyR >= convInfo.outHeight || Math.floor(dyR) !== dyR) {
continue;
}
for (var wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
var dyC = (dyCCorner + wC) / strideWidth;
if (dyC < 0 || dyC >= convInfo.outWidth || Math.floor(dyC) !== dyC) {
continue;
}
var pixel = dyBuf.get(b, dyR, dyC, d);
dotProd += pixel;
}
}
dx.set(dotProd * avgMultiplier, b, dxR, dxC, d);
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var avgPoolGradConfig$1 = {
kernelName: AvgPoolGrad,
backendName: 'cpu',
kernelFunc: avgPoolGrad$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function batchNorm$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
scale = inputs.scale,
offset = inputs.offset,
mean = inputs.mean,
variance = inputs.variance;
assert(mean.shape.length === variance.shape.length, function () {
return 'Batch normalization gradient requires mean and variance to have ' + 'equal ranks.';
});
assert(offset == null || mean.shape.length === offset.shape.length, function () {
return 'Batch normalization gradient requires mean and offset to have ' + 'equal ranks.';
});
assert(scale == null || mean.shape.length === scale.shape.length, function () {
return 'Batch normalization gradient requires mean and scale to have ' + 'equal ranks.';
});
assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');
var varianceEpsilon = attrs.varianceEpsilon;
if (varianceEpsilon == null) {
varianceEpsilon = 0.001;
}
var xVals = backend.data.get(x.dataId).values;
var mVals = backend.data.get(mean.dataId).values;
var varVals = backend.data.get(variance.dataId).values;
var sVals = scale ? backend.data.get(scale.dataId).values : new Float32Array([1]);
var offVals = offset ? backend.data.get(offset.dataId).values : new Float32Array([0]);
var outVals = new Float32Array(xVals.length);
var offValsLength = offVals.length;
var sValsLength = sVals.length;
var varValsLength = varVals.length;
var mValsLength = mVals.length;
var offi = 0;
var mi = 0;
var si = 0;
var vi = 0;
for (var i = 0; i < xVals.length; ++i) {
outVals[i] = offVals[offi++] + (xVals[i] - mVals[mi++]) * sVals[si++] / Math.sqrt(varVals[vi++] + varianceEpsilon);
if (offi >= offValsLength) {
offi = 0;
}
if (mi >= mValsLength) {
mi = 0;
}
if (si >= sValsLength) {
si = 0;
}
if (vi >= varValsLength) {
vi = 0;
}
}
return backend.makeTensorInfo(x.shape, x.dtype, outVals);
}
var batchNormConfig = {
kernelName: FusedBatchNorm,
backendName: 'cpu',
kernelFunc: batchNorm$1
};
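// Illustrative sketch (hypothetical helper, never invoked by this bundle):
// batchNorm computes offset + (x - mean) * scale / sqrt(variance + epsilon),
// cycling the per-channel statistics over x as in the loop above.
function batchNormUsageSketch(tf) {
  var x = tf.tensor2d([[2, 4], [6, 8]]);
  var mean = tf.tensor1d([1, 2]);
  var variance = tf.tensor1d([1, 1]);
  var y = tf.batchNorm(x, mean, variance); // ~[[1, 2], [5, 6]] with the default epsilon of 0.001
  y.print();
  tf.dispose([x, mean, variance, y]);
}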
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function batchToSpaceND$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var blockShape = attrs.blockShape,
crops = attrs.crops;
assertNotComplex([x], 'batchToSpaceND');
var prod = blockShape.reduce(function (a, b) {
return a * b;
});
var reshaped = getReshaped(x.shape, blockShape, prod);
var permuted = getPermuted(reshaped.length, blockShape.length);
var reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
var sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
var sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
var xReshaped = reshape$2({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: reshaped
}
});
var xTransposed = transpose$1({
inputs: {
x: xReshaped
},
backend: backend,
attrs: {
perm: permuted
}
});
var xTransposedReshaped = reshape$2({
inputs: {
x: xTransposed
},
backend: backend,
attrs: {
shape: reshapedPermuted
}
});
var result = slice$3({
inputs: {
x: xTransposedReshaped
},
backend: backend,
attrs: {
begin: sliceBeginCoords,
size: sliceSize
}
});
backend.disposeIntermediateTensorInfo(xReshaped);
backend.disposeIntermediateTensorInfo(xTransposed);
backend.disposeIntermediateTensorInfo(xTransposedReshaped);
return result;
}
var batchToSpaceNDConfig = {
kernelName: BatchToSpaceND,
backendName: 'cpu',
kernelFunc: batchToSpaceND$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function bincount$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
weights = inputs.weights;
var size = attrs.size;
var xVals = backend.data.get(x.dataId).values;
var weightsVals = backend.data.get(weights.dataId).values;
var outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
return backend.makeTensorInfo([size], weights.dtype, outVals);
}
var bincountConfig = {
kernelName: Bincount,
backendName: 'cpu',
kernelFunc: bincount$1
};
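// Illustrative sketch (hypothetical helper, never invoked by this bundle):
// bincount tallies how often each integer in [0, size) occurs in x; when weights
// is non-empty the matching weight is accumulated instead of 1.
function bincountUsageSketch(tf) {
  var x = tf.tensor1d([1, 1, 2, 3], 'int32');
  var weights = tf.tensor1d([]);
  var counts = tf.bincount(x, weights, 4); // [0, 2, 1, 1]
  counts.print();
  tf.dispose([x, weights, counts]);
}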
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function broadcastArgs$1(args) {
var inputs = args.inputs,
backend = args.backend;
var s0 = inputs.s0,
s1 = inputs.s1;
var s0Vals = backend.data.get(s0.dataId).values;
var s1Vals = backend.data.get(s1.dataId).values;
var broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
}
var broadcastArgsConfig = {
kernelName: BroadcastArgs,
backendName: 'cpu',
kernelFunc: broadcastArgs$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var clip = unaryKernelFunc(ClipByValue, function (xi, attrs) {
var clipAttrs = attrs;
if (xi > clipAttrs.clipValueMax) {
return clipAttrs.clipValueMax;
}
return xi < clipAttrs.clipValueMin ? clipAttrs.clipValueMin : xi;
});
var clipConfig = {
kernelName: ClipByValue,
backendName: 'cpu',
kernelFunc: clip
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var complexAbs = function complexAbs(args) {
var x = args.inputs.x;
var cpuBackend = args.backend;
var resultValues = new Float32Array(sizeFromShape(x.shape));
var complexVals = cpuBackend.data.get(x.dataId);
var real = complexVals.complexTensorInfos.real;
var imag = complexVals.complexTensorInfos.imag;
var realVals = cpuBackend.data.get(real.dataId).values;
var imagVals = cpuBackend.data.get(imag.dataId).values;
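// |a + bi| = sqrt(a^2 + b^2), e.g. 3 + 4i has magnitude 5.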
for (var i = 0; i < realVals.length; i++) {
var _real = realVals[i];
var _imag = imagVals[i];
resultValues[i] = Math.hypot(_real, _imag);
}
return cpuBackend.makeOutput(resultValues, x.shape, 'float32');
};
var complexAbsConfig = {
kernelName: ComplexAbs,
backendName: 'cpu',
kernelFunc: complexAbs
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function imag$1(args) {
var inputs = args.inputs,
backend = args.backend;
var input = inputs.input;
var imag = backend.data.get(input.dataId).complexTensorInfos.imag;
var imagVal = backend.data.get(imag.dataId).values; // When complex tensor is disposed, its underlying parts will be disposed too.
// Make new tensor out of the imag value of the complex. This makes sure the
// value is still accessible even if complex tensor is disposed.
return backend.makeTensorInfo(imag.shape, imag.dtype, imagVal);
}
var imagConfig = {
kernelName: Imag,
backendName: 'cpu',
kernelFunc: imag$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function concat$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var axis = attrs.axis;
var $axis = parseAxisParam(axis, inputs[0].shape)[0];
var outShape = computeOutShape$1(inputs.map(function (t) {
return t.shape;
}), $axis);
if (sizeFromShape(outShape) === 0) {
return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
} // Keep only non-empty tensors (ignore tensors with 0 in their shape).
var $inputs = inputs.filter(function (t) {
return sizeFromShape(t.shape) > 0;
});
if ($inputs.length === 1) {
return identity$1({
inputs: {
x: $inputs[0]
},
backend: backend
});
}
var shapes = $inputs.map(function (t) {
return t.shape;
});
assertParamsConsistent(shapes, $axis);
if ($inputs[0].dtype === 'complex64') {
var reals = $inputs.map(function (t) {
return real$1({
inputs: {
input: t
},
backend: backend
});
});
var imags = $inputs.map(function (t) {
return imag$1({
inputs: {
input: t
},
backend: backend
});
});
var realConcated = concat$1({
inputs: reals,
backend: backend,
attrs: {
axis: $axis
}
});
var imagConcated = concat$1({
inputs: imags,
backend: backend,
attrs: {
axis: $axis
}
});
var result = complex$1({
inputs: {
real: realConcated,
imag: imagConcated
},
backend: backend
});
reals.forEach(function (r) {
return backend.disposeIntermediateTensorInfo(r);
});
imags.forEach(function (i) {
return backend.disposeIntermediateTensorInfo(i);
});
backend.disposeIntermediateTensorInfo(realConcated);
backend.disposeIntermediateTensorInfo(imagConcated);
return result;
} // Any concat of n-dimensional tensors across any axis can be reduced to
// a concatenation of two-dimensional tensors across the axis 1 by first
// partitioning the axes of the original tensors into those less than the
// axis to be concatenated and the rest. Then reshape the tensors
// into a two-dimensional tensor by collapsing these two sets of axes and
// concatenate the resulting matrices across the axis 1, finally reshaping
// the result to have the proper shape.
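// For example, concatenating shapes [2, 2, 3] and [2, 2, 4] along axis 2
// reshapes the inputs to [4, 3] and [4, 4], concatenates those into
// [4, 7], and reshapes the result back to the final shape [2, 2, 7].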
var inputs2D = $inputs.map(function (t) {
var innerSize = sizeFromShape(t.shape.slice($axis));
var shape = [-1, innerSize];
return reshape$2({
inputs: {
x: t
},
backend: backend,
attrs: {
shape: shape
}
});
});
var inputsValShapes = inputs2D.map(function (t) {
return {
vals: backend.data.get(t.dataId).values,
shape: t.shape
};
}); // Concats 2d tensors along axis=1.
outShape = computeOutShape$1(inputs2D.map(function (t) {
return t.shape;
}), 1
/* axis */
);
var simplyConcat = inputs2D[0].shape[0] === 1;
var outVals = concatImpl(inputsValShapes, outShape, inputs[0].dtype, simplyConcat);
var finalOutShape = computeOutShape$1($inputs.map(function (t) {
return t.shape;
}), $axis);
var outInfo = backend.makeTensorInfo(finalOutShape, inputs[0].dtype, outVals);
inputs2D.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return outInfo;
}
var concatConfig = {
kernelName: Concat,
backendName: 'cpu',
kernelFunc: concat$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv2D(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter;
var strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dilations = attrs.dilations,
dimRoundingMode = attrs.dimRoundingMode;
assertNotComplex([x, filter], 'conv2d');
var $dataFormat = convertConv2DDataFormat(dataFormat);
var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false
/* depthwise */
, $dataFormat);
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var padLeft = convInfo.padInfo.left;
var padTop = convInfo.padInfo.top;
var isChannelsLast = convInfo.dataFormat === 'channelsLast';
var y = new TensorBuffer(convInfo.outShape, x.dtype);
var xStrides = computeStrides(x.shape);
var filterStrides = computeStrides(filter.shape);
var xBatchStride = xStrides[0];
var xRowStride = isChannelsLast ? xStrides[1] : xStrides[2];
var xColStride = isChannelsLast ? xStrides[2] : 1;
var xChannelStride = isChannelsLast ? 1 : xStrides[1];
var yBatchStride = y.strides[0];
var yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];
var yColStride = isChannelsLast ? y.strides[2] : 1;
var yChannelStride = isChannelsLast ? 1 : y.strides[1];
var xVals = backend.data.get(x.dataId).values;
var wVals = backend.data.get(filter.dataId).values;
var yVals = y.values;
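// Direct convolution: for every batch and output position, accumulate
// x * w over the filter window and the input channels. Out-of-bounds taps
// are skipped, which implements the implicit zero padding.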
for (var b = 0; b < convInfo.batchSize; ++b) {
var xOffset1 = b * xBatchStride;
var yOffset1 = b * yBatchStride;
for (var yR = 0; yR < convInfo.outHeight; ++yR) {
var yOffset2 = yOffset1 + yR * yRowStride;
var xRCorner = yR * convInfo.strideHeight - padTop;
for (var wR = 0; wR < filterHeight; ++wR) {
var xR = xRCorner + wR * dilationHeight;
if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}
var wOffset1 = wR * filterStrides[0];
var xOffset2 = xOffset1 + xR * xRowStride;
for (var yC = 0; yC < convInfo.outWidth; ++yC) {
var yOffset3 = yOffset2 + yC * yColStride;
var xCCorner = yC * convInfo.strideWidth - padLeft;
for (var wC = 0; wC < filterWidth; ++wC) {
var xC = xCCorner + wC * dilationWidth;
if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}
var wOffset2 = wOffset1 + wC * filterStrides[1];
var xOffset3 = xOffset2 + xC * xColStride;
var wOffset3 = wOffset2;
for (var d1 = 0; d1 < convInfo.inChannels; ++d1) {
var xVal = xVals[xOffset3 + d1 * xChannelStride];
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
yVals[yOffset3 + d2 * yChannelStride] += xVal * wVals[wOffset3 + d2];
}
wOffset3 += convInfo.outChannels;
}
}
}
}
}
}
return backend.makeTensorInfo(y.shape, y.dtype, yVals);
}
var conv2DConfig = {
kernelName: Conv2D,
backendName: 'cpu',
kernelFunc: conv2D
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv2DBackpropFilter$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
dy = inputs.dy;
var strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dimRoundingMode = attrs.dimRoundingMode,
filterShape = attrs.filterShape;
assertNotComplex([x, dy], 'conv2dBackpropFilter');
var $dataFormat = convertConv2DDataFormat(dataFormat);
var convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1
/* dilations */
, pad, dimRoundingMode, false
/* depthwise */
, $dataFormat);
var strideHeight = convInfo.strideHeight,
strideWidth = convInfo.strideWidth,
filterHeight = convInfo.filterHeight,
filterWidth = convInfo.filterWidth;
var isChannelsLast = convInfo.dataFormat === 'channelsLast';
var dW = new TensorBuffer(convInfo.filterShape, 'float32');
var leftPad = convInfo.padInfo.left;
var topPad = convInfo.padInfo.top;
var xVals = backend.data.get(x.dataId).values;
var dyVals = backend.data.get(dy.dataId).values;
var xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
var dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
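// dW[wR, wC, d1, d2] is the dot product, over batch and output positions,
// of the input values seen by that filter tap and the output gradients dy.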
for (var wR = 0; wR < filterHeight; ++wR) {
var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
for (var wC = 0; wC < filterWidth; ++wC) {
var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
for (var d1 = 0; d1 < convInfo.inChannels; ++d1) {
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
var dotProd = 0;
for (var b = 0; b < convInfo.batchSize; ++b) {
for (var yR = yRMin; yR < yRMax; ++yR) {
var xR = wR + yR * strideHeight - topPad;
for (var yC = yCMin; yC < yCMax; ++yC) {
var xC = wC + yC * strideWidth - leftPad;
if (isChannelsLast) {
dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
} else {
dotProd += xBuf.get(b, d1, xR, xC) * dyBuf.get(b, d2, yR, yC);
}
}
}
}
dW.set(dotProd, wR, wC, d1, d2);
}
}
}
}
return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
var conv2DBackpropFilterConfig = {
kernelName: Conv2DBackpropFilter,
backendName: 'cpu',
kernelFunc: conv2DBackpropFilter$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv2DBackpropInput$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
filter = inputs.filter;
var inputShape = attrs.inputShape,
strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dimRoundingMode = attrs.dimRoundingMode;
assertNotComplex([dy, filter], 'conv2dBackpropInput');
var filterStrides = computeStrides(filter.shape);
var dyStrides = computeStrides(dy.shape);
var $dataFormat = convertConv2DDataFormat(dataFormat);
var convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1
/* dilations */
, pad, dimRoundingMode, false, $dataFormat);
var dx = new TensorBuffer(convInfo.inShape, 'float32');
var dxValues = dx.values;
var dyValues = backend.data.get(dy.dataId).values;
var fltValues = backend.data.get(filter.dataId).values;
var fltS0 = filterStrides[0],
fltS1 = filterStrides[1],
fltS2 = filterStrides[2];
var batchSize = convInfo.batchSize,
filterHeight = convInfo.filterHeight,
filterWidth = convInfo.filterWidth,
inChannels = convInfo.inChannels,
inHeight = convInfo.inHeight,
inWidth = convInfo.inWidth,
outChannels = convInfo.outChannels,
outHeight = convInfo.outHeight,
outWidth = convInfo.outWidth,
strideHeight = convInfo.strideHeight,
strideWidth = convInfo.strideWidth;
$dataFormat = convInfo.dataFormat;
var topPad = filterHeight - 1 - convInfo.padInfo.top;
var leftPad = filterWidth - 1 - convInfo.padInfo.left;
var isChannelsLast = $dataFormat === 'channelsLast';
var xBatchStride = dx.strides[0];
var xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];
var xColStride = isChannelsLast ? dx.strides[2] : 1;
var xChannelStride = isChannelsLast ? 1 : dx.strides[1];
var yBatchStride = dyStrides[0];
var yRowStride = isChannelsLast ? dyStrides[1] : dyStrides[2];
var yColStride = isChannelsLast ? dyStrides[2] : 1;
var yChannelStride = isChannelsLast ? 1 : dyStrides[1];
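// dx[b, xR, xC, d1] sums, over every output position whose window covers
// (xR, xC), the product of dy and the spatially flipped filter (the usual
// transposed-convolution formulation).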
for (var b = 0; b < batchSize; ++b) {
for (var d1 = 0; d1 < inChannels; ++d1) {
for (var xR = 0; xR < inHeight; ++xR) {
var xRCorner = xR - topPad;
var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
for (var xC = 0; xC < inWidth; ++xC) {
var xCCorner = xC - leftPad;
var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
var dotProd = 0;
for (var yR = xRMin; yR < yRMax; ++yR) {
var wR = yR * strideHeight - xRCorner;
for (var yC = xCMin; yC < yCMax; ++yC) {
var wC = yC * strideWidth - xCCorner;
var dyOffset = yBatchStride * b + yRowStride * yR + yColStride * yC;
var fltOffset = fltS0 * (filterHeight - 1 - wR) + fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
for (var d2 = 0; d2 < outChannels; ++d2) {
var pixel = dyValues[dyOffset + yChannelStride * d2];
var weight = fltValues[fltOffset + d2];
dotProd += pixel * weight;
}
}
}
var dxOffset = xBatchStride * b + xRowStride * xR + xColStride * xC + xChannelStride * d1;
dxValues[dxOffset] = dotProd;
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var conv2DBackpropInputConfig = {
kernelName: Conv2DBackpropInput,
backendName: 'cpu',
kernelFunc: conv2DBackpropInput$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv3D(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter;
var strides = attrs.strides,
pad = attrs.pad,
dilations = attrs.dilations;
assertNotComplex([x, filter], 'conv3d');
var convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
var filterDepth = convInfo.filterDepth,
filterHeight = convInfo.filterHeight,
filterWidth = convInfo.filterWidth,
dilationDepth = convInfo.dilationDepth,
dilationHeight = convInfo.dilationHeight,
dilationWidth = convInfo.dilationWidth,
padInfo = convInfo.padInfo;
var padFront = padInfo.front;
var padLeft = padInfo.left;
var padTop = padInfo.top;
var y = new TensorBuffer(convInfo.outShape, x.dtype);
var xVals = backend.data.get(x.dataId).values;
var wVals = backend.data.get(filter.dataId).values;
var yVals = y.values;
var xStrides = computeStrides(x.shape);
var filterStrides = computeStrides(filter.shape);
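// 3-D analogue of the direct conv2D loop above, with an extra depth
// (frame) dimension; the layout here is channels-last (NDHWC).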
for (var b = 0; b < convInfo.batchSize; ++b) {
var xOffset1 = b * xStrides[0];
var yOffset1 = b * y.strides[0];
for (var yF = 0; yF < convInfo.outDepth; ++yF) {
var yOffset2 = yOffset1 + yF * y.strides[1];
var xFCorner = yF * convInfo.strideDepth - padFront;
for (var wF = 0; wF < filterDepth; ++wF) {
var xF = xFCorner + wF * dilationDepth;
if (xF < 0 || xF >= convInfo.inDepth) {
continue;
}
var wOffset1 = wF * filterStrides[0];
var xOffset2 = xOffset1 + xF * xStrides[1];
for (var yR = 0; yR < convInfo.outHeight; ++yR) {
var yOffset3 = yOffset2 + yR * y.strides[2];
var xRCorner = yR * convInfo.strideHeight - padTop;
for (var wR = 0; wR < filterHeight; ++wR) {
var xR = xRCorner + wR * dilationHeight;
if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}
var wOffset2 = wOffset1 + wR * filterStrides[1];
var xOffset3 = xOffset2 + xR * xStrides[2];
for (var yC = 0; yC < convInfo.outWidth; ++yC) {
var yOffset4 = yOffset3 + yC * convInfo.outChannels;
var xCCorner = yC * convInfo.strideWidth - padLeft;
for (var wC = 0; wC < filterWidth; ++wC) {
var xC = xCCorner + wC * dilationWidth;
if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}
var wOffset3 = wOffset2 + wC * filterStrides[2];
var xOffset4 = xOffset3 + xC * convInfo.inChannels;
var wOffset4 = wOffset3;
for (var d1 = 0; d1 < convInfo.inChannels; ++d1) {
var xVal = xVals[xOffset4 + d1];
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2];
}
wOffset4 += convInfo.outChannels;
}
}
}
}
}
}
}
}
return backend.makeTensorInfo(y.shape, y.dtype, y.values);
}
var conv3DConfig = {
kernelName: Conv3D,
backendName: 'cpu',
kernelFunc: conv3D
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv3DBackpropFilterV2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
dy = inputs.dy;
var strides = attrs.strides,
pad = attrs.pad,
filterShape = attrs.filterShape;
assertNotComplex([x, dy], 'conv3dBackpropFilterV2');
var xStrides = computeStrides(x.shape);
var dyStrides = computeStrides(dy.shape);
var convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1
/* dilations */
, pad);
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var filterDepth = convInfo.filterDepth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var dw = new TensorBuffer(convInfo.filterShape, 'float32');
var dwValues = dw.values;
var _dw$strides = dw.strides,
dwS0 = _dw$strides[0],
dwS1 = _dw$strides[1],
dwS2 = _dw$strides[2],
dwS3 = _dw$strides[3];
var dyValues = backend.data.get(dy.dataId).values;
var dyS0 = dyStrides[0],
dyS1 = dyStrides[1],
dyS2 = dyStrides[2],
dyS3 = dyStrides[3];
var xValues = backend.data.get(x.dataId).values;
var xS0 = xStrides[0],
xS1 = xStrides[1],
xS2 = xStrides[2],
xS3 = xStrides[3];
var frontPad = convInfo.padInfo.front;
var leftPad = convInfo.padInfo.left;
var topPad = convInfo.padInfo.top;
for (var wF = 0; wF < filterDepth; ++wF) {
var yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth));
var yFMax = Math.min(convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth);
var wOffset1 = wF * dwS0;
for (var wR = 0; wR < filterHeight; ++wR) {
var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
var wOffset2 = wR * dwS1 + wOffset1;
for (var wC = 0; wC < filterWidth; ++wC) {
var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
var wOffset3 = wC * dwS2 + wOffset2;
for (var d1 = 0; d1 < convInfo.inChannels; ++d1) {
var wOffset4 = d1 * dwS3 + wOffset3;
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
var dotProd = 0;
for (var b = 0; b < convInfo.batchSize; ++b) {
var xOffset1 = b * xS0;
var yOffset1 = b * dyS0;
for (var yF = yFMin; yF < yFMax; ++yF) {
var xF = wF + yF * strideDepth - frontPad;
var xOffset2 = xF * xS1 + xOffset1;
var yOffset2 = yF * dyS1 + yOffset1;
for (var yR = yRMin; yR < yRMax; ++yR) {
var xR = wR + yR * strideHeight - topPad;
var xOffset3 = xR * xS2 + xOffset2;
var yOffset3 = yR * dyS2 + yOffset2;
for (var yC = yCMin; yC < yCMax; ++yC) {
var xC = wC + yC * strideWidth - leftPad;
var xOffset4 = xC * xS3 + xOffset3;
var yOffset4 = yC * dyS3 + yOffset3;
dotProd += xValues[xOffset4 + d1] * dyValues[yOffset4 + d2];
}
}
}
}
dwValues[wOffset4 + d2] = dotProd;
}
}
}
}
}
return backend.makeTensorInfo(dw.shape, dw.dtype, dw.values);
}
var conv3DBackpropFilterV2Config = {
kernelName: Conv3DBackpropFilterV2,
backendName: 'cpu',
kernelFunc: conv3DBackpropFilterV2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv3DBackpropInputV2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
filter = inputs.filter;
var pad = attrs.pad,
strides = attrs.strides,
inputShape = attrs.inputShape;
assertNotComplex([dy], 'conv3dBackpropInputV2');
var dyStrides = computeStrides(dy.shape);
var filterStrides = computeStrides(filter.shape);
var convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1
/* dilations */
, pad);
var dx = new TensorBuffer(convInfo.inShape, 'float32');
var dxValues = dx.values;
var _dx$strides = dx.strides,
dxS0 = _dx$strides[0],
dxS1 = _dx$strides[1],
dxS2 = _dx$strides[2],
dxS3 = _dx$strides[3];
var dyValues = backend.data.get(dy.dataId).values;
var dyS0 = dyStrides[0],
dyS1 = dyStrides[1],
dyS2 = dyStrides[2],
dyS3 = dyStrides[3];
var fltValues = backend.data.get(filter.dataId).values;
var fltS0 = filterStrides[0],
fltS1 = filterStrides[1],
fltS2 = filterStrides[2],
fltS3 = filterStrides[3];
var batchSize = convInfo.batchSize,
filterDepth = convInfo.filterDepth,
filterHeight = convInfo.filterHeight,
filterWidth = convInfo.filterWidth,
inChannels = convInfo.inChannels,
inDepth = convInfo.inDepth,
inHeight = convInfo.inHeight,
inWidth = convInfo.inWidth,
outChannels = convInfo.outChannels,
outDepth = convInfo.outDepth,
outHeight = convInfo.outHeight,
outWidth = convInfo.outWidth,
strideDepth = convInfo.strideDepth,
strideHeight = convInfo.strideHeight,
strideWidth = convInfo.strideWidth;
var frontPad = filterDepth - 1 - convInfo.padInfo.front;
var topPad = filterHeight - 1 - convInfo.padInfo.top;
var leftPad = filterWidth - 1 - convInfo.padInfo.left;
for (var b = 0; b < batchSize; ++b) {
for (var d1 = 0; d1 < inChannels; ++d1) {
// Frames of depth
for (var xF = 0; xF < inDepth; ++xF) {
var xFCorner = xF - frontPad;
var xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));
var yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth); // Rows as per standard 2d matrix notation
for (var xR = 0; xR < inHeight; ++xR) {
var xRCorner = xR - topPad;
var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight); // Columns as per standard 2d matrix notation
for (var xC = 0; xC < inWidth; ++xC) {
var xCCorner = xC - leftPad;
var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
var dotProd = 0;
for (var yF = xFMin; yF < yFMax; ++yF) {
var wF = yF * strideDepth - xFCorner;
for (var yR = xRMin; yR < yRMax; ++yR) {
var wR = yR * strideHeight - xRCorner;
for (var yC = xCMin; yC < yCMax; ++yC) {
var wC = yC * strideWidth - xCCorner;
var dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;
var fltOffset = fltS0 * (filterDepth - 1 - wF) + fltS1 * (filterHeight - 1 - wR) + fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;
for (var d2 = 0; d2 < outChannels; ++d2) {
var pixel = dyValues[dyOffset + d2];
var weight = fltValues[fltOffset + d2];
dotProd += pixel * weight;
}
}
}
}
dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] = dotProd;
}
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var conv3DBackpropInputV2Config = {
kernelName: Conv3DBackpropInputV2,
backendName: 'cpu',
kernelFunc: conv3DBackpropInputV2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var cos$1 = unaryKernelFunc(Cos, function (xi) {
return Math.cos(xi);
});
var cosConfig = {
kernelName: Cos,
backendName: 'cpu',
kernelFunc: cos$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var cosh$1 = unaryKernelFunc(Cosh, function (xi) {
return Math.cosh(xi);
});
var coshConfig = {
kernelName: Cosh,
backendName: 'cpu',
kernelFunc: cosh$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function cropAndResize$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var image = inputs.image,
boxes = inputs.boxes,
boxInd = inputs.boxInd;
var cropSize = attrs.cropSize,
method = attrs.method,
extrapolationValue = attrs.extrapolationValue;
var _image$shape = image.shape,
batch = _image$shape[0],
imageHeight = _image$shape[1],
imageWidth = _image$shape[2],
numChannels = _image$shape[3];
var numBoxes = boxes.shape[0];
var cropHeight = cropSize[0],
cropWidth = cropSize[1];
var output = buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');
var boxVals = backend.data.get(boxes.dataId).values;
var boxIndVals = backend.data.get(boxInd.dataId).values;
var imageVals = backend.data.get(image.dataId).values;
var inStride = computeStrides(image.shape); // to calculate flat indexes into image
var outStride = computeStrides(output.shape); // to calculate flat indexes into output
// Reference implementation
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc
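// Each box is [y1, x1, y2, x2] in normalized coordinates of the source
// image, and boxIndVals[b] selects which image in the batch the crop is
// taken from; boxes whose batch index is >= the batch size are skipped.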
for (var b = 0; b < numBoxes; b++) {
var startInd = b * 4;
var y1 = boxVals[startInd];
var x1 = boxVals[startInd + 1];
var y2 = boxVals[startInd + 2];
var x2 = boxVals[startInd + 3];
var bInd = boxIndVals[b];
if (bInd >= batch) {
continue;
}
var heightScale = cropHeight > 1 ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : 0;
var widthScale = cropWidth > 1 ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;
for (var y = 0; y < cropHeight; y++) {
var yInd = cropHeight > 1 ? y1 * (imageHeight - 1) + y * heightScale : 0.5 * (y1 + y2) * (imageHeight - 1);
if (yInd < 0 || yInd > imageHeight - 1) {
for (var x = 0; x < cropWidth; x++) {
for (var c = 0; c < numChannels; c++) {
var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[ind] = extrapolationValue;
}
}
continue;
}
if (method === 'bilinear') {
var topInd = Math.floor(yInd);
var bottomInd = Math.ceil(yInd);
var yLerp = yInd - topInd;
for (var _x = 0; _x < cropWidth; _x++) {
var xInd = cropWidth > 1 ? x1 * (imageWidth - 1) + _x * widthScale : 0.5 * (x1 + x2) * (imageWidth - 1);
if (xInd < 0 || xInd > imageWidth - 1) {
for (var _c = 0; _c < numChannels; _c++) {
var _ind = _c + _x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[_ind] = extrapolationValue;
}
continue;
}
var leftInd = Math.floor(xInd);
var rightInd = Math.ceil(xInd);
var xLerp = xInd - leftInd;
for (var _c2 = 0; _c2 < numChannels; _c2++) {
var _ind2 = _c2 + leftInd * inStride[2] + topInd * inStride[1] + bInd * inStride[0];
var topLeft = imageVals[_ind2];
_ind2 = _c2 + rightInd * inStride[2] + topInd * inStride[1] + bInd * inStride[0];
var topRight = imageVals[_ind2];
_ind2 = _c2 + leftInd * inStride[2] + bottomInd * inStride[1] + bInd * inStride[0];
var bottomLeft = imageVals[_ind2];
_ind2 = _c2 + rightInd * inStride[2] + bottomInd * inStride[1] + bInd * inStride[0];
var bottomRight = imageVals[_ind2];
var top = topLeft + (topRight - topLeft) * xLerp;
var bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;
_ind2 = _c2 + _x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[_ind2] = top + (bottom - top) * yLerp;
}
}
} else {
// method == "nearest"
for (var _x2 = 0; _x2 < cropWidth; ++_x2) {
var _xInd = cropWidth > 1 ? x1 * (imageWidth - 1) + _x2 * widthScale : 0.5 * (x1 + x2) * (imageWidth - 1);
if (_xInd < 0 || _xInd > imageWidth - 1) {
for (var _c3 = 0; _c3 < numChannels; _c3++) {
var _ind3 = _c3 + _x2 * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[_ind3] = extrapolationValue;
}
continue;
}
var closestX = Math.round(_xInd);
var closestY = Math.round(yInd);
for (var _c4 = 0; _c4 < numChannels; _c4++) {
var inInd = _c4 + closestX * inStride[2] + closestY * inStride[1] + bInd * inStride[0];
var outInd = _c4 + _x2 * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[outInd] = imageVals[inInd];
}
}
}
}
}
return backend.makeTensorInfo(output.shape, output.dtype, output.values);
}
var cropAndResizeConfig = {
kernelName: CropAndResize,
backendName: 'cpu',
kernelFunc: cropAndResize$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function cumsum$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
exclusive = attrs.exclusive,
reverse = attrs.reverse;
assertNotComplex(x, 'cumsum');
var permutation = getAxesPermutation([axis], x.shape.length);
var $x = x;
if (permutation != null) {
$x = transpose$1({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutation
}
});
}
var permutedAxis = getInnerMostAxes(1, x.shape.length)[0];
if (permutedAxis !== $x.shape.length - 1) {
throw new Error("backend.cumsum in CPU expects an inner-most " + ("axis=" + ($x.shape.length - 1) + " but got axis=" + permutedAxis));
}
var resultDtype = upcastType($x.dtype, 'int32');
var vals = makeZerosTypedArray(sizeFromShape($x.shape), resultDtype);
var aVals = backend.data.get($x.dataId).values;
var finalDim = $x.shape[$x.shape.length - 1];
var indexAdjuster = reverse ? function (i, j) {
return i + finalDim - j - 1;
} : function (i, j) {
return i + j;
};
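// e.g. along the last axis of [1, 2, 3]: the running sum is [1, 3, 6],
// exclusive mode gives [0, 1, 3], and reverse mode gives [6, 5, 3].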
for (var i = 0; i < aVals.length; i += finalDim) {
for (var j = 0; j < finalDim; j++) {
var idx = indexAdjuster(i, j);
if (j === 0) {
vals[idx] = exclusive ? 0 : aVals[idx];
} else {
var prevIdx = indexAdjuster(i, j - 1);
vals[idx] = exclusive ? aVals[prevIdx] + vals[prevIdx] : aVals[idx] + vals[prevIdx];
}
}
}
var result = backend.makeTensorInfo($x.shape, resultDtype, vals);
if (permutation != null) {
var reversePermutation = getUndoAxesPermutation(permutation);
var reverseTransposedResult = transpose$1({
inputs: {
x: result
},
backend: backend,
attrs: {
perm: reversePermutation
}
});
backend.disposeIntermediateTensorInfo(result);
backend.disposeIntermediateTensorInfo($x);
return reverseTransposedResult;
}
return result;
}
var cumsumConfig = {
kernelName: Cumsum,
backendName: 'cpu',
kernelFunc: cumsum$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function denseBincount$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
weights = inputs.weights;
var size = attrs.size,
binaryOutput = attrs.binaryOutput;
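// Rank-1 inputs follow the same path as Bincount; for rank-2 inputs each
// row is bincounted independently, and binaryOutput records presence (0/1)
// rather than counts, per the DenseBincount op semantics.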
if (x.shape.length === 1) {
var xVals = backend.data.get(x.dataId).values;
var weightsVals = backend.data.get(weights.dataId).values;
var outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
return backend.makeTensorInfo([size], weights.dtype, outVals);
} else if (x.shape.length === 2) {
var xBuf = backend.bufferSync(x);
var weightsBuf = backend.bufferSync(weights);
var outBuf = bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput);
return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
}
throw new Error("Error in denseBincount: input must be at most rank 2, but got rank" + (x.shape.length + "."));
}
var denseBincountConfig = {
kernelName: DenseBincount,
backendName: 'cpu',
kernelFunc: denseBincount$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function depthToSpace$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var blockSize = attrs.blockSize,
dataFormat = attrs.dataFormat;
assert(dataFormat === 'NHWC', function () {
return "Only NHWC dataFormat supported on CPU for depthToSpace. Got " + dataFormat;
});
assert(blockSize > 1, function () {
return "blockSize should be > 1 for depthToSpace, but was: " + blockSize;
});
var batchSize = x.shape[0];
var inputHeight = x.shape[1];
var inputWidth = x.shape[2];
var inputDepth = x.shape[3];
var outputHeight = inputHeight * blockSize;
var outputWidth = inputWidth * blockSize;
var outputDepth = inputDepth / (blockSize * blockSize);
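// With blockSize 2, each input pixel's depth is rearranged into a 2x2
// spatial block: e.g. a [1, 1, 1, 4] input [a, b, c, d] becomes the
// [1, 2, 2, 1] output [[a, b], [c, d]].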
var xValues = backend.data.get(x.dataId).values;
var result = new Float32Array(batchSize * outputHeight * outputWidth * outputDepth);
var outputIdx = 0;
for (var b = 0; b < batchSize; ++b) {
for (var h = 0; h < outputHeight; ++h) {
var inH = Math.floor(h / blockSize);
var offsetH = h % blockSize;
for (var w = 0; w < outputWidth; ++w) {
var inW = Math.floor(w / blockSize);
var offsetW = w % blockSize;
var offsetD = (offsetH * blockSize + offsetW) * outputDepth;
for (var d = 0; d < outputDepth; ++d) {
var inD = d + offsetD;
var inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));
result[outputIdx++] = xValues[inputIdx];
}
}
}
}
return backend.makeTensorInfo([batchSize, outputHeight, outputWidth, outputDepth], x.dtype, result);
}
var depthToSpaceConfig = {
kernelName: DepthToSpace,
backendName: 'cpu',
kernelFunc: depthToSpace$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function depthwiseConv2dNative(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter;
var strides = attrs.strides,
pad = attrs.pad,
dilations = attrs.dilations,
dimRoundingMode = attrs.dimRoundingMode;
assertNotComplex([x, filter], 'depthwiseConv2DNative');
var xStrides = computeStrides(x.shape);
var filterStrides = computeStrides(filter.shape);
var $dilations = dilations;
if ($dilations == null) {
$dilations = [1, 1];
}
assert(eitherStridesOrDilationsAreOne(strides, $dilations), function () {
return 'Error in depthwiseConv2d: Either strides or dilations must be ' + ("1. Got strides " + strides + " and dilations '" + $dilations + "'");
});
var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true
/* depthwise */
);
var filterHeight = convInfo.filterHeight,
filterWidth = convInfo.filterWidth,
dilationHeight = convInfo.dilationHeight,
dilationWidth = convInfo.dilationWidth,
padInfo = convInfo.padInfo;
var padLeft = padInfo.left;
var padTop = padInfo.top;
var chMul = convInfo.outChannels / convInfo.inChannels;
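// chMul is the channel multiplier: input channel d1 produces output
// channels d1 * chMul .. d1 * chMul + chMul - 1.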
var y = new TensorBuffer(convInfo.outShape, x.dtype);
var xVals = backend.data.get(x.dataId).values;
var wVals = backend.data.get(filter.dataId).values;
var yVals = y.values;
for (var b = 0; b < convInfo.batchSize; ++b) {
var xOffset1 = b * xStrides[0];
var yOffset1 = b * y.strides[0];
for (var yR = 0; yR < convInfo.outHeight; ++yR) {
var yOffset2 = yOffset1 + yR * y.strides[1];
var xRCorner = yR * convInfo.strideHeight - padTop;
for (var wR = 0; wR < filterHeight; ++wR) {
var xR = xRCorner + wR * dilationHeight;
if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}
var wOffset1 = wR * filterStrides[0];
var xOffset2 = xOffset1 + xR * xStrides[1];
for (var yC = 0; yC < convInfo.outWidth; ++yC) {
var yOffset3 = yOffset2 + yC * y.strides[2];
var xCCorner = yC * convInfo.strideWidth - padLeft;
for (var wC = 0; wC < filterWidth; ++wC) {
var xC = xCCorner + wC * dilationWidth;
if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}
var wOffset2 = wOffset1 + wC * filterStrides[1];
var xOffset3 = xOffset2 + xC * convInfo.inChannels;
var yOffset4 = yOffset3;
var wOffset3 = wOffset2;
for (var d1 = 0; d1 < convInfo.inChannels; ++d1) {
var xVal = xVals[xOffset3 + d1];
for (var q = 0; q < chMul; ++q) {
yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q];
}
yOffset4 += chMul;
wOffset3 += chMul;
}
}
}
}
}
}
return backend.makeTensorInfo(y.shape, y.dtype, y.values);
}
var depthwiseConv2dNativeConfig = {
kernelName: DepthwiseConv2dNative,
backendName: 'cpu',
kernelFunc: depthwiseConv2dNative
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function depthwiseConv2dNativeBackpropFilter$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
dy = inputs.dy;
var strides = attrs.strides,
dilations = attrs.dilations,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode,
filterShape = attrs.filterShape;
assertNotComplex([x, dy], 'depthwiseConv2dNativeBackpropFilter');
var convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true
/* depthwise */
);
var strideHeight = convInfo.strideHeight,
strideWidth = convInfo.strideWidth,
filterHeight = convInfo.filterHeight,
filterWidth = convInfo.filterWidth;
var dW = new TensorBuffer(convInfo.filterShape, 'float32');
var leftPad = convInfo.padInfo.left;
var topPad = convInfo.padInfo.top;
var chMul = convInfo.outChannels / convInfo.inChannels;
var xVals = backend.data.get(x.dataId).values;
var xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
var dyVals = backend.data.get(dy.dataId).values;
var dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
for (var wR = 0; wR < filterHeight; ++wR) {
var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
for (var wC = 0; wC < filterWidth; ++wC) {
var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
for (var d2 = 0; d2 < convInfo.outChannels; ++d2) {
var d1 = Math.trunc(d2 / chMul);
var dm = d2 % chMul;
var dotProd = 0;
for (var b = 0; b < convInfo.batchSize; ++b) {
for (var yR = yRMin; yR < yRMax; ++yR) {
var xR = wR + yR * strideHeight - topPad;
for (var yC = yCMin; yC < yCMax; ++yC) {
var xC = wC + yC * strideWidth - leftPad;
dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2);
}
}
}
dW.set(dotProd, wR, wC, d1, dm);
}
}
}
return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
var depthwiseConv2dNativeBackpropFilterConfig = {
kernelName: DepthwiseConv2dNativeBackpropFilter,
backendName: 'cpu',
kernelFunc: depthwiseConv2dNativeBackpropFilter$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function depthwiseConv2dNativeBackpropInput$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
filter = inputs.filter;
var strides = attrs.strides,
dilations = attrs.dilations,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode,
inputShape = attrs.inputShape;
assertNotComplex([dy, filter], 'depthwiseConv2DNativeBackpropInput');
var dyStrides = computeStrides(dy.shape);
var filterStrides = computeStrides(filter.shape);
var convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true
/* depthwise */
);
var dx = new TensorBuffer(convInfo.inShape, 'float32');
var dxValues = dx.values;
var _dx$strides = dx.strides,
dxS0 = _dx$strides[0],
dxS1 = _dx$strides[1],
dxS2 = _dx$strides[2];
var dyValues = backend.data.get(dy.dataId).values;
var dyS0 = dyStrides[0],
dyS1 = dyStrides[1],
dyS2 = dyStrides[2];
var fltValues = backend.data.get(filter.dataId).values;
var fltS0 = filterStrides[0],
fltS1 = filterStrides[1],
fltS2 = filterStrides[2];
var batchSize = convInfo.batchSize,
filterHeight = convInfo.filterHeight,
filterWidth = convInfo.filterWidth,
inChannels = convInfo.inChannels,
inHeight = convInfo.inHeight,
inWidth = convInfo.inWidth,
outChannels = convInfo.outChannels,
outHeight = convInfo.outHeight,
outWidth = convInfo.outWidth,
strideHeight = convInfo.strideHeight,
strideWidth = convInfo.strideWidth;
var topPad = filterHeight - 1 - convInfo.padInfo.top;
var leftPad = filterWidth - 1 - convInfo.padInfo.left;
var chMul = outChannels / inChannels;
for (var b = 0; b < batchSize; ++b) {
for (var d1 = 0; d1 < inChannels; ++d1) {
for (var xR = 0; xR < inHeight; ++xR) {
var xRCorner = xR - topPad;
var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
for (var xC = 0; xC < inWidth; ++xC) {
var xCCorner = xC - leftPad;
var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
var dotProd = 0;
for (var yR = xRMin; yR < yRMax; ++yR) {
var wR = yR * strideHeight - xRCorner;
for (var yC = xCMin; yC < yCMax; ++yC) {
var wC = yC * strideWidth - xCCorner;
var dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC;
var fltOffset = fltS0 * (filterHeight - 1 - wR) + fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
for (var dm = 0; dm < chMul; ++dm) {
var d2 = d1 * chMul + dm;
var pixel = dyValues[dyOffset + d2];
var weight = fltValues[fltOffset + dm];
dotProd += pixel * weight;
}
}
}
dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd;
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var depthwiseConv2dNativeBackpropInputConfig = {
kernelName: DepthwiseConv2dNativeBackpropInput,
backendName: 'cpu',
kernelFunc: depthwiseConv2dNativeBackpropInput$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function diag$1(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
var xSize = sizeFromShape(x.shape);
var xVals = backend.data.get(x.dataId).values;
var outBuf = buffer([xSize, xSize], x.dtype);
var vals = outBuf.values;
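// Place the flattened values of x on the main diagonal of an
// [xSize, xSize] buffer; the reported output shape is x.shape
// concatenated with itself.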
for (var i = 0; i < xVals.length; i++) {
vals[i * xSize + i] = xVals[i];
}
var outShape = [].concat(x.shape, x.shape);
return backend.makeTensorInfo(outShape, outBuf.dtype, outBuf.values);
}
var diagConfig = {
kernelName: Diag,
backendName: 'cpu',
kernelFunc: diag$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var dilation2dConfig = {
kernelName: Dilation2D,
backendName: 'cpu',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
backend = _ref.backend,
attrs = _ref.attrs;
var x = inputs.x,
filter = inputs.filter;
var strides = attrs.strides,
pad = attrs.pad,
dilations = attrs.dilations;
var cpuBackend = backend;
var xVals = cpuBackend.data.get(x.dataId).values;
var xRank = x.shape.length;
var filterVals = cpuBackend.data.get(filter.dataId).values;
var filterRank = filter.shape.length;
var _backend_util$compute = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC'
/* dataFormat */
, dilations),
batchSize = _backend_util$compute.batchSize,
inHeight = _backend_util$compute.inHeight,
inWidth = _backend_util$compute.inWidth,
inChannels = _backend_util$compute.inChannels,
outHeight = _backend_util$compute.outHeight,
outWidth = _backend_util$compute.outWidth,
padInfo = _backend_util$compute.padInfo,
strideHeight = _backend_util$compute.strideHeight,
strideWidth = _backend_util$compute.strideWidth,
filterHeight = _backend_util$compute.filterHeight,
filterWidth = _backend_util$compute.filterWidth,
dilationHeight = _backend_util$compute.dilationHeight,
dilationWidth = _backend_util$compute.dilationWidth,
outShape = _backend_util$compute.outShape;
var outSize = sizeFromShape(outShape);
var outRank = outShape.length;
var outputVals = getArrayFromDType(x.dtype, outSize); // Upsampling the input by filling in `dilation size - 1` values between each
// input value.
// This implementation follows the TF c++ implementation:
// https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
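// Grayscale morphological dilation: each output value is the maximum of
// x + filter over the (dilated) filter window at that position.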
for (var b = 0; b < batchSize; ++b) {
for (var hOut = 0; hOut < outHeight; ++hOut) {
var hBeg = hOut * strideHeight - padInfo.top;
for (var wOut = 0; wOut < outWidth; ++wOut) {
var wBeg = wOut * strideWidth - padInfo.left;
for (var d = 0; d < inChannels; ++d) {
var curVal = Number.MIN_SAFE_INTEGER;
for (var h = 0; h < filterHeight; ++h) {
var hIn = hBeg + h * dilationHeight;
if (hIn >= 0 && hIn < inHeight) {
for (var w = 0; w < filterWidth; ++w) {
var wIn = wBeg + w * dilationWidth;
if (wIn >= 0 && wIn < inWidth) {
var xIndex = locToIndex([b, hIn, wIn, d], xRank, computeStrides(x.shape));
var filterIndex = locToIndex([h, w, d], filterRank, computeStrides(filter.shape));
var val = xVals[xIndex] + filterVals[filterIndex];
if (val > curVal) {
curVal = val;
}
}
}
}
}
var outputIndex = locToIndex([b, hOut, wOut, d], outRank, computeStrides(outShape));
outputVals[outputIndex] = curVal;
}
}
}
}
var dataId = cpuBackend.write(toTypedArray(outputVals, x.dtype), outShape, x.dtype);
return {
dataId: dataId,
shape: outShape,
dtype: x.dtype
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var dilation2dBackpropFilterConfig = {
kernelName: Dilation2DBackpropFilter,
backendName: 'cpu',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
backend = _ref.backend,
attrs = _ref.attrs;
var x = inputs.x,
filter = inputs.filter,
dy = inputs.dy;
var strides = attrs.strides,
pad = attrs.pad,
dilations = attrs.dilations;
var cpuBackend = backend;
var $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
var $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
var _backend_util$compute = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC'
/* dataFormat */
, dilations),
batchSize = _backend_util$compute.batchSize,
inHeight = _backend_util$compute.inHeight,
inWidth = _backend_util$compute.inWidth,
inChannels = _backend_util$compute.inChannels,
outHeight = _backend_util$compute.outHeight,
outWidth = _backend_util$compute.outWidth,
padInfo = _backend_util$compute.padInfo,
strideHeight = _backend_util$compute.strideHeight,
strideWidth = _backend_util$compute.strideWidth,
filterHeight = _backend_util$compute.filterHeight,
filterWidth = _backend_util$compute.filterWidth,
dilationHeight = _backend_util$compute.dilationHeight,
dilationWidth = _backend_util$compute.dilationWidth,
outShape = _backend_util$compute.outShape;
assert(dy.rank === outShape.length, function () {
return "Error in " + Dilation2DBackpropFilter + ", dy " + ("must have the same rank as output " + outShape.length + ", but got ") + ("" + dy.rank);
});
var $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values); // The computed filter gradient has the same dimensions as the filter:
// [filterHeight, filterWidth, depth]
var gradients = makeZerosNestedTypedArray(filter.shape, filter.dtype); // When several positions tie for the argmax, we only back-propagate along the
// last branch, i.e., the one with the largest value of `h * filter_cols + w`,
// similarly to the max-pooling backward routines (see the routing sketch after this kernel config).
// This implementation follows the TF C++ implementation:
// https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
for (var b = 0; b < batchSize; ++b) {
for (var hOut = 0; hOut < outHeight; ++hOut) {
var hBeg = hOut * strideHeight - padInfo.top;
for (var wOut = 0; wOut < outWidth; ++wOut) {
var wBeg = wOut * strideWidth - padInfo.left;
for (var d = 0; d < inChannels; ++d) {
var curVal = Number.MIN_SAFE_INTEGER;
var hMax = 0;
var wMax = 0;
for (var h = 0; h < filterHeight; ++h) {
var hIn = hBeg + h * dilationHeight;
if (hIn >= 0 && hIn < inHeight) {
for (var w = 0; w < filterWidth; ++w) {
var wIn = wBeg + w * dilationWidth;
if (wIn >= 0 && wIn < inWidth) {
var val = $x[b][hIn][wIn][d] + $filter[h][w][d];
if (val > curVal) {
curVal = val;
hMax = h;
wMax = w;
}
}
}
}
}
gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d];
}
}
}
}
var dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), filter.shape, filter.dtype);
return {
dataId: dataId,
shape: filter.shape,
dtype: filter.dtype
};
}
};
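// Gradient-routing sketch for the kernel above: each output position re-runs the forward
// max-plus search, remembers the winning filter tap (hMax, wMax), and accumulates the incoming
// gradient there, i.e.
//
//   gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d]
//
// so only the tap that produced the forward maximum receives gradient, with ties resolved in
// favor of the last (largest h, then w) tap.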
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var dilation2dBackpropInputConfig = {
kernelName: Dilation2DBackpropInput,
backendName: 'cpu',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
backend = _ref.backend,
attrs = _ref.attrs;
var x = inputs.x,
filter = inputs.filter,
dy = inputs.dy;
var strides = attrs.strides,
pad = attrs.pad,
dilations = attrs.dilations;
var cpuBackend = backend;
var $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
var $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
var _backend_util$compute = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC'
/* dataFormat */
, dilations),
batchSize = _backend_util$compute.batchSize,
inHeight = _backend_util$compute.inHeight,
inWidth = _backend_util$compute.inWidth,
inChannels = _backend_util$compute.inChannels,
outHeight = _backend_util$compute.outHeight,
outWidth = _backend_util$compute.outWidth,
padInfo = _backend_util$compute.padInfo,
strideHeight = _backend_util$compute.strideHeight,
strideWidth = _backend_util$compute.strideWidth,
filterHeight = _backend_util$compute.filterHeight,
filterWidth = _backend_util$compute.filterWidth,
dilationHeight = _backend_util$compute.dilationHeight,
dilationWidth = _backend_util$compute.dilationWidth,
outShape = _backend_util$compute.outShape;
assert(dy.rank === outShape.length, function () {
return "Error in " + Dilation2DBackpropInput + ", dy " + ("must have the same rank as output " + outShape.length + ", but got ") + ("" + dy.rank);
});
var $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values); // The computed gradient has the same dimensions as the input:
// [batch, inputHeight, inputCols, inChannel]
var gradients = makeZerosNestedTypedArray(x.shape, x.dtype); // When several positions tie for the argmax, we only back-propagate along the
// last branch, i.e., the one with the largest value of `h * filter_cols + w`,
// similarly to the max-pooling backward routines.
// This implementation follows the TF C++ implementation:
// https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc
for (var b = 0; b < batchSize; ++b) {
for (var hOut = 0; hOut < outHeight; ++hOut) {
var hBeg = hOut * strideHeight - padInfo.top;
for (var wOut = 0; wOut < outWidth; ++wOut) {
var wBeg = wOut * strideWidth - padInfo.left;
for (var d = 0; d < inChannels; ++d) {
var curVal = Number.MIN_SAFE_INTEGER;
var hInMax = hBeg < 0 ? 0 : hBeg;
var wInMax = wBeg < 0 ? 0 : wBeg;
for (var h = 0; h < filterHeight; ++h) {
var hIn = hBeg + h * dilationHeight;
if (hIn >= 0 && hIn < inHeight) {
for (var w = 0; w < filterWidth; ++w) {
var wIn = wBeg + w * dilationWidth;
if (wIn >= 0 && wIn < inWidth) {
var val = $x[b][hIn][wIn][d] + $filter[h][w][d];
if (val > curVal) {
curVal = val;
hInMax = hIn;
wInMax = wIn;
}
}
}
}
}
gradients[b][hInMax][wInMax][d] += $dy[b][hOut][wOut][d];
}
}
}
}
var dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), x.shape, x.dtype);
return {
dataId: dataId,
shape: x.shape,
dtype: x.dtype
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sum$3(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
assertNotComplex(x, 'sum');
var $x;
if (x.dtype === 'bool') {
$x = cast$2({
inputs: {
x: x
},
backend: backend,
attrs: {
dtype: 'int32'
}
});
} else {
$x = identity$1({
inputs: {
x: x
},
backend: backend
});
}
var xRank = $x.shape.length;
var axes = parseAxisParam(axis, $x.shape);
var permutation = getAxesPermutation(axes, xRank);
var reductionAxes = axes;
var permutedX = $x;
if (permutation != null) {
permutedX = transpose$1({
inputs: {
x: $x
},
backend: backend,
attrs: {
perm: permutation
}
});
reductionAxes = getInnerMostAxes(reductionAxes.length, xRank);
}
assertAxesAreInnerMostDims('sum', reductionAxes, permutedX.shape.length);
var _backend_util$compute = computeOutAndReduceShapes(permutedX.shape, reductionAxes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var resultDtype = upcastType(permutedX.dtype, 'int32');
var result = zeros$2(backend, outShape, resultDtype);
var reduceSize = sizeFromShape(reduceShape);
var vals = backend.data.get(result.dataId).values;
var aVals = backend.data.get(permutedX.dataId).values;
for (var i = 0; i < vals.length; ++i) {
var offset = i * reduceSize;
var _sum = 0;
for (var j = 0; j < reduceSize; ++j) {
_sum += aVals[offset + j];
}
vals[i] = _sum;
}
if (keepDims) {
var newShape = expandShapeToKeepDim(result.shape, axes);
var oldResult = result;
result = reshape$2({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: newShape
}
});
backend.disposeIntermediateTensorInfo(oldResult);
}
backend.disposeIntermediateTensorInfo($x);
if (permutation != null) {
backend.disposeIntermediateTensorInfo(permutedX);
}
return result;
}
var sumConfig = {
kernelName: Sum,
backendName: 'cpu',
kernelFunc: sum$3
};
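// Shape bookkeeping example for sum$3 above (values chosen for illustration): for an input of
// shape [2, 3, 4] with axis = 1, the kernel transposes axis 1 to the innermost position,
// computeOutAndReduceShapes then yields outShape = [2, 4] and reduceShape = [3], so each output
// element sums reduceSize = 3 consecutive values of the transposed input; keepDims = true would
// reshape the result to [2, 1, 4].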
function einsum$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var equation = attrs.equation;
var tensors = inputs;
var _backend_util$decodeE = decodeEinsumEquation(equation, tensors.length),
allDims = _backend_util$decodeE.allDims,
summedDims = _backend_util$decodeE.summedDims,
idDims = _backend_util$decodeE.idDims;
checkEinsumDimSizes(allDims.length, idDims, tensors);
var _backend_util$getEins = getEinsumComputePath(summedDims, idDims),
path = _backend_util$getEins.path,
steps = _backend_util$getEins.steps;
var nSteps = steps.length;
var out = null;
var numDimsRemaining = allDims.length;
var tensorsToDispose = [];
for (var i = 0; i < nSteps; ++i) {
for (var _iterator = _createForOfIteratorHelperLoose(steps[i]), _step; !(_step = _iterator()).done;) {
var idTerm = _step.value;
var _backend_util$getEins2 = getEinsumPermutation(numDimsRemaining, idDims[idTerm]),
perm = _backend_util$getEins2.permutationIndices,
dimsToExpand = _backend_util$getEins2.expandDims;
var x = void 0;
if (isIdentityPermutation(perm)) {
x = tensors[idTerm];
} else {
x = transpose$1({
inputs: {
x: tensors[idTerm]
},
backend: backend,
attrs: {
perm: perm
}
});
tensorsToDispose.push(x);
}
var targetShape = x.shape.slice();
for (var k = 0; k < dimsToExpand.length; ++k) {
targetShape.splice(dimsToExpand[k], 0, 1);
}
if (!arraysEqual(x.shape, targetShape)) {
x = reshape$2({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: targetShape
}
});
tensorsToDispose.push(x);
}
if (out === null) {
out = x;
} else {
// tslint:disable-next-line: no-unnecessary-type-assertion
out = multiply$3({
inputs: {
a: x,
b: out
},
backend: backend
});
tensorsToDispose.push(out);
}
}
if (i < nSteps - 1) {
if (path[i] >= 0) {
out = sum$3({
inputs: {
x: out
},
backend: backend,
attrs: {
axis: path[i] - (allDims.length - numDimsRemaining),
keepDims: false
}
});
tensorsToDispose.push(out);
}
numDimsRemaining--;
}
} // Clean up intermediate tensors.
for (var _i = 0, _tensorsToDispose = tensorsToDispose; _i < _tensorsToDispose.length; _i++) {
var tensorInfo = _tensorsToDispose[_i];
if (tensorInfo === out) {
continue;
}
backend.disposeIntermediateTensorInfo(tensorInfo);
}
return out;
}
var einsumConfig = {
kernelName: Einsum,
backendName: 'cpu',
kernelFunc: einsum$1
};
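// einsum$1 above reduces an equation to broadcast multiplies followed by sums over the shared
// dimensions. For example, the equation 'ij,jk->ik' is evaluated as
//
//   C[i, k] = sum over j of A[i, j] * B[j, k]
//
// i.e. the operands are transposed/expanded to a common layout, multiplied elementwise with
// multiply$3, and the summed dimension j is then removed with sum$3.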
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function eluGrad(args) {
var inputs = args.inputs,
backend = args.backend;
var dy = inputs.dy,
y = inputs.y;
assertNotComplex([dy, y], 'eluGrad');
var resultValues = new Float32Array(sizeFromShape(y.shape));
var values = backend.data.get(y.dataId).values;
var dyValues = backend.data.get(dy.dataId).values;
for (var i = 0; i < values.length; ++i) {
var v = values[i];
if (v >= 1) {
resultValues[i] = dyValues[i];
} else {
resultValues[i] = dyValues[i] * (v + 1);
}
}
return backend.makeTensorInfo(y.shape, 'float32', resultValues);
}
var eluGradConfig$1 = {
kernelName: EluGrad,
backendName: 'cpu',
kernelFunc: eluGrad
};
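// Why the ELU gradient can be reconstructed from the output alone (sketch, assuming alpha = 1):
// for x <= 0, elu(x) = exp(x) - 1, so d/dx elu(x) = exp(x) = y + 1, which is the
// dyValues[i] * (v + 1) branch above computed directly from the stored output y; for positive
// activations the derivative is 1 and dy is passed through unchanged.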
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var p = ERF_P;
var a1 = ERF_A1;
var a2 = ERF_A2;
var a3 = ERF_A3;
var a4 = ERF_A4;
var a5 = ERF_A5;
var erf$1 = unaryKernelFunc(Erf, function (xi) {
var sign = Math.sign(xi);
var v = Math.abs(xi);
var t = 1.0 / (1.0 + p * v);
return sign * (1.0 - ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t * Math.exp(-v * v));
});
var erfConfig = {
kernelName: Erf,
backendName: 'cpu',
kernelFunc: erf$1
};
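// The kernel above uses the Abramowitz & Stegun 7.1.26 rational approximation of erf, with the
// coefficients ERF_P..ERF_A5 defined earlier in this bundle. As a sanity check (assuming the
// usual published coefficient values): erf(0) = 0, erf(1) is about 0.8427, erf(-1) about
// -0.8427, and the absolute error of the approximation stays below roughly 1.5e-7.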
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function expandDims$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var input = inputs.input;
var dim = attrs.dim;
var inputRank = input.shape.length;
var newShape = input.shape.slice();
var $dim = dim;
if (dim < 0) {
// Negative value is counted from the tail of rank.
assert(-(inputRank + 1) <= dim, function () {
return "Axis must be in the interval [" + -(inputRank + 1) + ", " + inputRank + "]";
});
$dim = inputRank + dim + 1;
}
newShape.splice($dim, 0, 1);
return reshape$2({
inputs: {
x: input
},
backend: backend,
attrs: {
shape: newShape
}
});
}
var expandDimsConfig = {
kernelName: ExpandDims,
backendName: 'cpu',
kernelFunc: expandDims$2
};
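// Index arithmetic example for a negative dim (values chosen for illustration): with
// input.shape = [2, 3] and dim = -1, inputRank = 2, so $dim = 2 + (-1) + 1 = 2 and the new
// shape becomes [2, 3, 1]; dim = 0 would instead produce [1, 2, 3].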
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var realDivImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a / b;
});
var div$1 = binaryKernelFunc(RealDiv, realDivImpl);
var realDivConfig = {
kernelName: RealDiv,
backendName: 'cpu',
kernelFunc: div$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
 * Calculates the FFT over the innermost dimension of a batched 2D tensor.
*/
function fftBatch(input, inverse, cpuBackend) {
var inputShape = input.shape;
var batch = inputShape[0];
var innerDim = inputShape[1];
var inputVals = cpuBackend.data.get(input.dataId);
var real2D = inputVals.complexTensorInfos.real;
var imag2D = inputVals.complexTensorInfos.imag; // Collects real and imaginary values separately.
var resultShape = [batch, innerDim];
var resultSize = sizeFromShape(resultShape);
var resultReal = getTypedArrayFromDType('float32', resultSize);
var resultImag = getTypedArrayFromDType('float32', resultSize);
for (var b = 0; b < batch; b++) {
// TODO: Support slice ops for complex type.
var r = slice$3({
inputs: {
x: real2D
},
backend: cpuBackend,
attrs: {
begin: [b, 0],
size: [1, innerDim]
}
});
var i = slice$3({
inputs: {
x: imag2D
},
backend: cpuBackend,
attrs: {
begin: [b, 0],
size: [1, innerDim]
}
});
var _input = complex$1({
inputs: {
real: r,
imag: i
},
backend: cpuBackend
}); // Run FFT by batch element.
var _fftImpl = fftImpl(_input, inverse, cpuBackend),
_real = _fftImpl.real,
_imag = _fftImpl.imag;
var res = mergeRealAndImagArrays(_real, _imag);
for (var d = 0; d < innerDim; d++) {
var c = getComplexWithIndex(res, d);
resultReal[b * innerDim + d] = c.real;
resultImag[b * innerDim + d] = c.imag;
}
cpuBackend.disposeIntermediateTensorInfo(r);
cpuBackend.disposeIntermediateTensorInfo(i);
cpuBackend.disposeIntermediateTensorInfo(_input);
}
var $realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultReal);
var $imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImag);
var result = complex$1({
inputs: {
real: $realInfo,
imag: $imagInfo
},
backend: cpuBackend
});
cpuBackend.disposeIntermediateTensorInfo($realInfo);
cpuBackend.disposeIntermediateTensorInfo($imagInfo);
return result;
}
function fftImpl(input, inverse, cpuBackend) {
var inputSize = sizeFromShape(input.shape);
var inputVals = cpuBackend.data.get(input.dataId);
var realVals = cpuBackend.data.get(inputVals.complexTensorInfos.real.dataId).values;
var imagVals = cpuBackend.data.get(inputVals.complexTensorInfos.imag.dataId).values;
if (isExponentOf2(inputSize)) {
var result = fftRadix2(realVals, imagVals, inputSize, inverse, cpuBackend);
var resultShape = [input.shape[0], input.shape[1]];
if (inverse) {
var realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.real);
var imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.imag);
var sizeInfo = cpuBackend.makeTensorInfo([], 'float32', createScalarValue(inputSize, 'float32'));
var sizeInfoCopy = identity$1({
inputs: {
x: sizeInfo
},
backend: cpuBackend
});
var divRealInfo = realDivConfig.kernelFunc({
inputs: {
a: realInfo,
b: sizeInfo
},
backend: cpuBackend
});
var divImagInfo = realDivConfig.kernelFunc({
inputs: {
a: imagInfo,
b: sizeInfoCopy
},
backend: cpuBackend
});
var divRealVals = cpuBackend.data.get(divRealInfo.dataId).values;
var divImagVals = cpuBackend.data.get(divImagInfo.dataId).values;
cpuBackend.disposeIntermediateTensorInfo(realInfo);
cpuBackend.disposeIntermediateTensorInfo(imagInfo);
cpuBackend.disposeIntermediateTensorInfo(sizeInfo);
cpuBackend.disposeIntermediateTensorInfo(sizeInfoCopy);
cpuBackend.disposeIntermediateTensorInfo(divRealInfo);
cpuBackend.disposeIntermediateTensorInfo(divImagInfo);
return {
real: divRealVals,
imag: divImagVals
};
}
return result;
} else {
var data = mergeRealAndImagArrays(realVals, imagVals);
var rawOutput = fourierTransformByMatmul(data, inputSize, inverse);
return splitRealAndImagArrays(rawOutput);
}
}
function isExponentOf2(size) {
return (size & size - 1) === 0;
} // (size & (size - 1)) === 0 is true when `size` is a power of two, e.g. 8 & 7 === 0.
// FFT using the Cooley-Tukey radix-2 algorithm; the input length must be a power of two.
function fftRadix2(realVals, imagVals, size, inverse, cpuBackend) {
if (size === 1) {
return {
real: realVals,
imag: imagVals
};
}
var data = mergeRealAndImagArrays(realVals, imagVals);
var half = size / 2;
var evenComplex = complexWithEvenIndex(data);
var evenRealVals = evenComplex.real;
var evenImagVals = evenComplex.imag;
var evenShape = [evenRealVals.length];
var evenRealInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenRealVals);
var evenImagInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenImagVals);
var evenTensorInfo = complex$1({
inputs: {
real: evenRealInfo,
imag: evenImagInfo
},
backend: cpuBackend
});
var oddComplex = complexWithOddIndex(data);
var oddRealVals = oddComplex.real;
var oddImagVals = oddComplex.imag;
var oddShape = [oddRealVals.length];
var oddRealInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddRealVals);
var oddImagInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddImagVals);
var oddTensorInfo = complex$1({
inputs: {
real: oddRealInfo,
imag: oddImagInfo
},
backend: cpuBackend
}); // Recursive call for half part of original input.
var $evenComplex = fftRadix2(evenRealVals, evenImagVals, half, inverse, cpuBackend);
var $evenRealVals = $evenComplex.real;
var $evenImagVals = $evenComplex.imag;
var $evenShape = [$evenRealVals.length];
var $evenRealInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenRealVals);
var $evenImagInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenImagVals);
var $evenTensorInfo = complex$1({
inputs: {
real: $evenRealInfo,
imag: $evenImagInfo
},
backend: cpuBackend
});
var $oddComplex = fftRadix2(oddRealVals, oddImagVals, half, inverse, cpuBackend);
var $oddRealVals = $oddComplex.real;
var $oddImagVals = $oddComplex.imag;
var $oddShape = [$oddRealVals.length];
var $oddRealInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddRealVals);
var $oddImagInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddImagVals);
var $oddTensorInfo = complex$1({
inputs: {
real: $oddRealInfo,
imag: $oddImagInfo
},
backend: cpuBackend
});
var e = exponents(size, inverse);
var eShape = [e.real.length];
var eRealInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.real);
var eImagInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.imag);
var complexInfo = complex$1({
inputs: {
real: eRealInfo,
imag: eImagInfo
},
backend: cpuBackend
});
var exponentInfo = multiply$3({
inputs: {
a: complexInfo,
b: $oddTensorInfo
},
backend: cpuBackend
});
var addPart = add$4({
inputs: {
a: $evenTensorInfo,
b: exponentInfo
},
backend: cpuBackend
});
var subPart = sub$1({
inputs: {
a: $evenTensorInfo,
b: exponentInfo
},
backend: cpuBackend
});
var addPartReal = real$1({
inputs: {
input: addPart
},
backend: cpuBackend
});
var subPartReal = real$1({
inputs: {
input: subPart
},
backend: cpuBackend
});
var addPartImag = imag$1({
inputs: {
input: addPart
},
backend: cpuBackend
});
var subPartImag = imag$1({
inputs: {
input: subPart
},
backend: cpuBackend
});
var $real = concat$1({
inputs: [addPartReal, subPartReal],
backend: cpuBackend,
attrs: {
axis: 0
}
});
var $imag = concat$1({
inputs: [addPartImag, subPartImag],
backend: cpuBackend,
attrs: {
axis: 0
}
});
var $realVals = cpuBackend.data.get($real.dataId).values;
var $imagVals = cpuBackend.data.get($imag.dataId).values;
cpuBackend.disposeIntermediateTensorInfo(evenRealInfo);
cpuBackend.disposeIntermediateTensorInfo(evenImagInfo);
cpuBackend.disposeIntermediateTensorInfo(evenTensorInfo);
cpuBackend.disposeIntermediateTensorInfo(oddRealInfo);
cpuBackend.disposeIntermediateTensorInfo(oddImagInfo);
cpuBackend.disposeIntermediateTensorInfo(oddTensorInfo);
cpuBackend.disposeIntermediateTensorInfo($evenRealInfo);
cpuBackend.disposeIntermediateTensorInfo($evenImagInfo);
cpuBackend.disposeIntermediateTensorInfo($evenTensorInfo);
cpuBackend.disposeIntermediateTensorInfo($oddRealInfo);
cpuBackend.disposeIntermediateTensorInfo($oddImagInfo);
cpuBackend.disposeIntermediateTensorInfo($oddTensorInfo);
cpuBackend.disposeIntermediateTensorInfo(eRealInfo);
cpuBackend.disposeIntermediateTensorInfo(eImagInfo);
cpuBackend.disposeIntermediateTensorInfo(complexInfo);
cpuBackend.disposeIntermediateTensorInfo(exponentInfo);
cpuBackend.disposeIntermediateTensorInfo(addPart);
cpuBackend.disposeIntermediateTensorInfo(subPart);
cpuBackend.disposeIntermediateTensorInfo(addPartReal);
cpuBackend.disposeIntermediateTensorInfo(addPartImag);
cpuBackend.disposeIntermediateTensorInfo(subPartReal);
cpuBackend.disposeIntermediateTensorInfo(subPartImag);
cpuBackend.disposeIntermediateTensorInfo($real);
cpuBackend.disposeIntermediateTensorInfo($imag);
return {
real: $realVals,
imag: $imagVals
};
} // Calculates the Fourier transform by multiplying the input by the sinusoid (DFT) matrix.
function fourierTransformByMatmul(data, size, inverse) {
var ret = new Float32Array(size * 2); // TODO: Use matmul instead once it supports complex64 type.
for (var r = 0; r < size; r++) {
var _real2 = 0.0;
var _imag2 = 0.0;
for (var c = 0; c < size; c++) {
var e = exponent(r * c, size, inverse);
var term = getComplexWithIndex(data, c);
_real2 += term.real * e.real - term.imag * e.imag;
_imag2 += term.real * e.imag + term.imag * e.real;
}
if (inverse) {
_real2 /= size;
_imag2 /= size;
}
assignToTypedArray(ret, _real2, _imag2, r);
}
return ret;
}
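// The double loop above evaluates the naive O(size^2) discrete Fourier transform
//
//   X[r] = sum over c of x[c] * exp(-2 * pi * i * r * c / size)
//
// using the sinusoid terms returned by exponent(); when `inverse` is true the sign of the
// exponent is flipped and the result is divided by `size`, the standard inverse-DFT scaling.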
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fft$1(args) {
var inputs = args.inputs,
backend = args.backend;
var input = inputs.input;
var inputSize = sizeFromShape(input.shape); // Collapse all outer dimensions to a single batch dimension.
var innerDimensionSize = input.shape[input.shape.length - 1];
var batch = inputSize / innerDimensionSize;
var input2D = reshape$2({
inputs: {
x: input
},
backend: backend,
attrs: {
shape: [batch, innerDimensionSize]
}
});
var result = fftBatch(input2D, false, backend);
var resultReshaped = reshape$2({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: input.shape
}
});
backend.disposeIntermediateTensorInfo(input2D);
backend.disposeIntermediateTensorInfo(result);
return resultReshaped;
}
var fftConfig = {
kernelName: FFT,
backendName: 'cpu',
kernelFunc: fft$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fill$1(args) {
var backend = args.backend,
attrs = args.attrs;
var shape = attrs.shape,
value = attrs.value,
dtype = attrs.dtype;
var $dtype = dtype || inferDtype(value);
var values = getArrayFromDType($dtype, sizeFromShape(shape));
fillValues(values, value, $dtype);
return backend.makeTensorInfo(shape, $dtype, values);
}
var fillConfig = {
kernelName: Fill,
backendName: 'cpu',
kernelFunc: fill$1
};
function fillValues(values, value, dtype) {
  // Both branches perform the same `fill`; the dtype check mirrors the string-array vs.
  // typed-array distinction in the TypeScript source, where only the casts differ.
  if (dtype === 'string') {
    values.fill(value);
  } else {
    values.fill(value);
  }
}
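// Usage sketch (illustrative values): fill$1({ backend: backend, attrs: { shape: [2, 2],
// value: 3, dtype: 'float32' } }) allocates a Float32Array of length 4 and fills it with 3,
// producing the tensor [[3, 3], [3, 3]].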
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var flipLeftRightConfig = {
kernelName: FlipLeftRight,
backendName: 'cpu',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
attrs = _ref.attrs,
backend = _ref.backend;
var image = inputs.image;
var cpuBackend = backend;
var output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
var _image$shape = image.shape,
batch = _image$shape[0],
imageHeight = _image$shape[1],
imageWidth = _image$shape[2],
numChannels = _image$shape[3];
var imageVals = cpuBackend.data.get(image.dataId).values;
for (var batchIdx = 0; batchIdx < batch; batchIdx++) {
var batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
for (var row = 0; row < imageHeight; row++) {
var rowOffset = row * (imageWidth * numChannels);
for (var col = 0; col < imageWidth; col++) {
var colOffset = col * numChannels;
for (var channel = 0; channel < numChannels; channel++) {
var coordX = Math.round(imageWidth - col - 1);
var outIdx = batchOffset + rowOffset + colOffset + channel;
var outputValue = imageVals[outIdx]; // If the coordinate position falls within the image boundaries...
if (coordX >= 0 && coordX < imageWidth) {
// set the output to the image value at the coordinate position.
var rotatedColOffset = coordX * numChannels;
var imageIdx = batchOffset + rowOffset + rotatedColOffset + channel;
outputValue = imageVals[imageIdx];
}
output[outIdx] = outputValue;
}
}
}
}
var dataId = cpuBackend.write(output, image.shape, image.dtype);
return {
dataId: dataId,
shape: image.shape,
dtype: image.dtype
};
}
};
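// Index example for the horizontal flip above (illustrative): with imageWidth = 4, the output
// at column col = 0 is read from coordX = 4 - 0 - 1 = 3, col = 1 from 2, and so on, so every
// row comes out in reversed column order while batch, row and channel offsets are unchanged.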
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var floorDivImpl = createSimpleBinaryKernelImpl(function (a, b) {
return Math.floor(a / b);
});
var floorDiv$1 = binaryKernelFunc(FloorDiv, floorDivImpl, null
/* complexImpl */
, 'int32');
var floorDivConfig = {
kernelName: FloorDiv,
backendName: 'cpu',
kernelFunc: floorDiv$1
};
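// Behavior sketch: floorDiv rounds toward negative infinity and produces int32 output, so
// floorDiv(7, 2) gives Math.floor(3.5) = 3 while floorDiv(-7, 2) gives Math.floor(-3.5) = -4
// (truncating division would give -3 instead).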
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fusedConv2D(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter,
bias = inputs.bias,
preluActivationWeights = inputs.preluActivationWeights;
var strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dilations = attrs.dilations,
dimRoundingMode = attrs.dimRoundingMode,
activation = attrs.activation,
leakyreluAlpha = attrs.leakyreluAlpha;
var result = conv2D({
inputs: {
x: x,
filter: filter
},
backend: backend,
attrs: {
strides: strides,
pad: pad,
dataFormat: dataFormat,
dilations: dilations,
dimRoundingMode: dimRoundingMode
}
});
if (bias) {
var resultOld = result;
result = add$4({
inputs: {
a: result,
b: bias
},
backend: backend
});
backend.disposeIntermediateTensorInfo(resultOld);
}
if (activation) {
var _resultOld = result;
result = applyActivation$1(backend, result, activation, preluActivationWeights, leakyreluAlpha);
backend.disposeIntermediateTensorInfo(_resultOld);
}
return result;
}
var fusedConv2DConfig = {
kernelName: FusedConv2D,
backendName: 'cpu',
kernelFunc: fusedConv2D
};
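// Equivalence sketch: on this CPU backend the fused kernel above simply chains the existing
// kernels (conv2D, then add$4 with the bias when one is provided, then the requested
// activation), disposing each intermediate along the way, so the numerical result matches
// running the unfused ops separately.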
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fusedDepthwiseConv2D(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter,
bias = inputs.bias,
preluActivationWeights = inputs.preluActivationWeights;
var strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dilations = attrs.dilations,
dimRoundingMode = attrs.dimRoundingMode,
activation = attrs.activation,
leakyreluAlpha = attrs.leakyreluAlpha;
var result = depthwiseConv2dNative({
inputs: {
x: x,
filter: filter
},
backend: backend,
attrs: {
strides: strides,
pad: pad,
dataFormat: dataFormat,
dilations: dilations,
dimRoundingMode: dimRoundingMode
}
});
if (bias) {
var oldResult = result;
result = add$4({
inputs: {
a: result,
b: bias
},
backend: backend
});
backend.disposeIntermediateTensorInfo(oldResult);
}
if (activation) {
var _oldResult = result;
result = applyActivation$1(backend, result, activation, preluActivationWeights, leakyreluAlpha);
backend.disposeIntermediateTensorInfo(_oldResult);
}
return result;
}
var fusedDepthwiseConv2DConfig = {
kernelName: FusedDepthwiseConv2D,
backendName: 'cpu',
kernelFunc: fusedDepthwiseConv2D
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function gatherNd(args) {
var inputs = args.inputs,
backend = args.backend;
var params = inputs.params,
indices = inputs.indices;
var paramsSize = sizeFromShape(params.shape);
var indicesShape = indices.shape;
var sliceRank = indicesShape[indicesShape.length - 1];
var _backend_util$prepare = prepareAndValidate(params, indices),
resultShape = _backend_util$prepare[0],
numSlices = _backend_util$prepare[1],
sliceSize = _backend_util$prepare[2],
strides = _backend_util$prepare[3];
if (numSlices === 0) {
return backend.makeTensorInfo(resultShape, params.dtype, []);
}
var indicesData = backend.data.get(indices.dataId).values;
var paramsBuf = backend.bufferSync(params);
var outBuf = gatherNdImpl(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
return backend.makeTensorInfo(resultShape, params.dtype, outBuf.values);
}
var gatherNdConfig = {
kernelName: GatherNd,
backendName: 'cpu',
kernelFunc: gatherNd
};
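// Worked example (illustrative values): with params of shape [2, 2] holding [[1, 2], [3, 4]]
// and indices of shape [2, 2] equal to [[0, 1], [1, 0]], the slice rank is 2, numSlices = 2 and
// sliceSize = 1, so the output has shape [2] and contains params[0][1] = 2 and params[1][0] = 3.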
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function gatherV2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
indices = inputs.indices;
var axis = attrs.axis,
batchDims = attrs.batchDims;
assertNotComplex([x, indices], 'gatherV2');
var $batchDims = batchDims;
if (batchDims == null) {
$batchDims = 0;
}
var indicesSize = sizeFromShape(indices.shape);
var parsedAxis = parseAxisParam(axis, x.shape)[0];
var shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, $batchDims);
var flattenX = reshape$2({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: [shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize, shapeInfo.sliceSize]
}
});
var flattenIndex = reshape$2({
inputs: {
x: indices
},
backend: backend,
attrs: {
shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize]
}
});
var flattenOutputShape = [shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize, shapeInfo.sliceSize];
var indicesBuf = backend.bufferSync(flattenIndex);
var xBuf = backend.bufferSync(flattenX);
var outBuf = gatherV2Impl(xBuf, indicesBuf, flattenOutputShape);
backend.disposeIntermediateTensorInfo(flattenX);
backend.disposeIntermediateTensorInfo(flattenIndex);
return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
}
var gatherV2Config = {
kernelName: GatherV2,
backendName: 'cpu',
kernelFunc: gatherV2
};
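// Shape example (illustrative values): with x of shape [3, 4], indices = [0, 2], axis = 0 and
// the default batchDims = 0, the output shape is [2, 4] (rows 0 and 2 of x); the same indices
// with axis = 1 would instead select columns, giving shape [3, 2].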
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function ifft$1(args) {
var inputs = args.inputs,
backend = args.backend;
var input = inputs.input;
var inputSize = sizeFromShape(input.shape); // Collapse all outer dimensions to a single batch dimension.
var innerDimensionSize = input.shape[input.shape.length - 1];
var batch = inputSize / innerDimensionSize;
var input2D = reshape$2({
inputs: {
x: input
},
backend: backend,
attrs: {
shape: [batch, innerDimensionSize]
}
});
var result = fftBatch(input2D, true, backend);
var resultReshaped = reshape$2({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: input.shape
}
});
backend.disposeIntermediateTensorInfo(input2D);
backend.disposeIntermediateTensorInfo(result);
return resultReshaped;
}
var ifftConfig = {
kernelName: IFFT,
backendName: 'cpu',
kernelFunc: ifft$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var isFinite$2 = unaryKernelFunc(IsFinite, function (xi) {
return Number.isFinite(xi) ? 1 : 0;
}, 'bool');
var isFiniteConfig = {
kernelName: IsFinite,
backendName: 'cpu',
kernelFunc: isFinite$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var isInf$1 = unaryKernelFunc(IsInf, function (xi) {
return Math.abs(xi) === Infinity ? 1 : 0;
}, 'bool');
var isInfConfig = {
kernelName: IsInf,
backendName: 'cpu',
kernelFunc: isInf$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var isNaN$2 = unaryKernelFunc(IsNan, function (xi) {
return Number.isNaN(xi) ? 1 : 0;
}, 'bool');
var isNaNConfig = {
kernelName: IsNan,
backendName: 'cpu',
kernelFunc: isNaN$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function linSpace(args) {
var backend = args.backend,
attrs = args.attrs;
var start = attrs.start,
stop = attrs.stop,
num = attrs.num;
var outVals = linSpaceImpl(start, stop, num);
return backend.makeTensorInfo([outVals.length], 'float32', outVals);
}
var linSpaceConfig = {
kernelName: LinSpace,
backendName: 'cpu',
kernelFunc: linSpace
};
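// Example (illustrative values): start = 0, stop = 10, num = 5 produces [0, 2.5, 5, 7.5, 10],
// i.e. num evenly spaced values including both endpoints, with step (stop - start) / (num - 1).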
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var log1p$1 = unaryKernelFunc(Log1p, function (xi) {
return Math.log1p(xi);
});
var log1pConfig = {
kernelName: Log1p,
backendName: 'cpu',
kernelFunc: log1p$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var logicalAndImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a && b;
});
var logicalAnd$1 = binaryKernelFunc(LogicalAnd, logicalAndImpl, null
/* complexImpl */
, 'bool');
var logicalAndConfig = {
kernelName: LogicalAnd,
backendName: 'cpu',
kernelFunc: logicalAnd$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var logicalNot$1 = unaryKernelFunc(LogicalNot, function (xi) {
return xi ? 0 : 1;
}, 'bool');
var logicalNotConfig = {
kernelName: LogicalNot,
backendName: 'cpu',
kernelFunc: logicalNot$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var logicalOrImpl = createSimpleBinaryKernelImpl(function (a, b) {
return a || b;
});
var logicalOr$1 = binaryKernelFunc(LogicalOr, logicalOrImpl, null
/* complexImpl */
, 'bool');
var logicalOrConfig = {
kernelName: LogicalOr,
backendName: 'cpu',
kernelFunc: logicalOr$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function lRN(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var depthRadius = attrs.depthRadius,
bias = attrs.bias,
alpha = attrs.alpha,
beta = attrs.beta;
assertNotComplex(x, 'LRN');
var channels = x.shape[3];
var maxD = channels - 1;
var xValues = backend.data.get(x.dataId).values;
var size = sizeFromShape(x.shape);
var result = new Float32Array(size);
function sumAcrossChannels(offset) {
var currentChannel = offset % channels;
var beginSumOffset = offset - currentChannel + Math.max(0, currentChannel - depthRadius);
var endSumOffset = offset - currentChannel + Math.min(currentChannel + depthRadius, maxD);
var sum = 0.0;
for (; beginSumOffset <= endSumOffset; beginSumOffset++) {
var z = xValues[beginSumOffset];
sum += z * z;
}
return sum;
}
for (var offset = 0; offset < size; offset++) {
var sum = sumAcrossChannels(offset);
var val = xValues[offset] * Math.pow(bias + alpha * sum, -beta);
result[offset] = val;
}
return backend.makeTensorInfo(x.shape, x.dtype, result);
}
var lRNConfig = {
kernelName: LRN,
backendName: 'cpu',
kernelFunc: lRN
};
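// Formula sketch for the kernel above: for every position with channel index c,
//
//   out[c] = x[c] * Math.pow(bias + alpha * sum of x[j]^2 for j in [c - depthRadius, c + depthRadius], -beta)
//
// with the channel window clipped to the valid range, i.e. local response normalization across
// the channel (last) dimension.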
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function lRNGrad(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
y = inputs.y,
dy = inputs.dy;
var depthRadius = attrs.depthRadius,
bias = attrs.bias,
alpha = attrs.alpha,
beta = attrs.beta;
assertNotComplex(dy, 'LRNGrad');
var dySize = sizeFromShape(dy.shape);
var channels = dy.shape[3];
var dyValues = backend.data.get(dy.dataId).values;
var xValues = backend.data.get(x.dataId).values;
var yValues = backend.data.get(y.dataId).values;
var result = new Float32Array(dySize);
var size = dySize;
for (var offset = 0; offset < size; offset++) {
var currentChannel = offset % channels;
var depthBegin = offset - currentChannel + Math.max(0, currentChannel - depthRadius);
var depthEnd = offset - currentChannel + Math.min(channels, currentChannel + depthRadius + 1);
var norm = 0;
for (var k = depthBegin; k < depthEnd; k++) {
norm += Math.pow(xValues[k], 2);
}
norm = alpha * norm + bias;
for (var _k = depthBegin; _k < depthEnd; _k++) {
var dyi = -2 * alpha * beta * xValues[_k] * yValues[offset] / norm;
if (offset === _k) {
dyi += Math.pow(norm, -beta);
}
dyi *= dyValues[offset];
result[_k] += dyi;
}
}
return backend.makeTensorInfo(dy.shape, x.dtype, result);
}
var lRNGradConfig = {
kernelName: LRNGrad,
backendName: 'cpu',
kernelFunc: lRNGrad
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function max$7(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var reductionIndices = attrs.reductionIndices,
keepDims = attrs.keepDims;
var cpuBackend = backend;
var xShape = x.shape;
var xRank = xShape.length;
var origAxes = parseAxisParam(reductionIndices, xShape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, xRank);
var xVals = cpuBackend.data.get(x.dataId).values;
if (permutedAxes != null) {
var newShape = new Array(xRank);
for (var i = 0; i < newShape.length; i++) {
newShape[i] = xShape[permutedAxes[i]];
}
xVals = transposeImpl(xVals, xShape, x.dtype, permutedAxes, newShape);
axes = getInnerMostAxes(axes.length, xRank);
xShape = newShape;
}
assertNotComplex(x, 'max');
assertAxesAreInnerMostDims('max', axes, xRank);
var _backend_util$compute = computeOutAndReduceShapes(xShape, axes),
maxOutShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var reduceSize = sizeFromShape(reduceShape);
var result = maxImpl(xVals, reduceSize, maxOutShape, x.dtype);
var dataId = cpuBackend.write(result, maxOutShape, x.dtype);
var outShape = maxOutShape;
if (keepDims) {
// reshape
var _newShape = expandShapeToKeepDim(maxOutShape, origAxes);
outShape = _newShape;
}
return {
dataId: dataId,
shape: outShape,
dtype: x.dtype
};
}
var maxConfig = {
kernelName: Max,
backendName: 'cpu',
kernelFunc: max$7
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxPool$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
assertNotComplex(x, 'maxPool');
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
var dilations = 1;
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in maxPool: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
var res;
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape)) {
res = identity$1({
inputs: {
x: x
},
backend: backend
});
} else {
var xValues = backend.data.get(x.dataId).values;
var _strides = computeStrides(x.shape);
var buffer = pool$1(xValues, x.shape, x.dtype, _strides, convInfo, 'max');
res = backend.makeTensorInfo(convInfo.outShape, x.dtype, buffer.values);
}
return res;
}
var maxPoolConfig = {
kernelName: MaxPool,
backendName: 'cpu',
kernelFunc: maxPool$1
};
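// Shape arithmetic example (illustrative values): an input of shape [1, 4, 4, 1] with
// filterSize = 2, strides = 2 and pad = 'valid' yields convInfo.outShape = [1, 2, 2, 1]; when
// the filter is 1x1 and the output shape equals the input shape, the kernel above
// short-circuits to identity$1 instead of pooling.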
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxPool3D(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode,
dataFormat = attrs.dataFormat;
assertNotComplex(x, 'maxPool3d');
var convInfo = computePool3DInfo(x.shape, filterSize, strides, 1
/* dilations */
, pad, dimRoundingMode, dataFormat);
var xValues = backend.data.get(x.dataId).values;
var outBuf = pool3d$1(xValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'max');
return backend.makeTensorInfo(outBuf.shape, 'float32', outBuf.values);
}
var maxPool3DConfig = {
kernelName: MaxPool3D,
backendName: 'cpu',
kernelFunc: maxPool3D
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxPool3DGrad(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
input = inputs.input;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
assertNotComplex([dy, input], 'maxPool3DGrad');
  var convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
var inputBuf = backend.bufferSync(input);
var maxPosBuf = maxPool3dPositions(inputBuf, convInfo);
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var dx = buffer(input.shape, 'float32');
var dyBuf = backend.bufferSync(dy);
for (var batch = 0; batch < convInfo.batchSize; ++batch) {
for (var channel = 0; channel < convInfo.inChannels; ++channel) {
for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
// Shader code begins
var dyDepthCorner = dxDepth - padFront;
var dyRowCorner = dxRow - padTop;
var dyColCorner = dxCol - padLeft;
var dotProd = 0;
for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
var dyDepth = (dyDepthCorner + wDepth) / strideDepth;
if (dyDepth < 0 || dyDepth >= convInfo.outDepth || Math.floor(dyDepth) !== dyDepth) {
continue;
}
for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
var dyRow = (dyRowCorner + wRow) / strideHeight;
if (dyRow < 0 || dyRow >= convInfo.outHeight || Math.floor(dyRow) !== dyRow) {
continue;
}
for (var wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
var dyCol = (dyColCorner + wCol) / strideWidth;
if (dyCol < 0 || dyCol >= convInfo.outWidth || Math.floor(dyCol) !== dyCol) {
continue;
}
var maxPos = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1 - maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);
var curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth + wRow * effectiveFilterWidth + wCol;
var mask = maxPos === curPos ? 1 : 0;
if (mask === 0) {
continue;
}
var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
dotProd += pixel * mask;
}
}
}
dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);
}
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var maxPool3DGradConfig$1 = {
kernelName: MaxPool3DGrad,
backendName: 'cpu',
kernelFunc: maxPool3DGrad
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxPoolGrad$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
input = inputs.input,
output = inputs.output;
var x = input;
assertNotComplex([input, output], 'maxPoolGrad');
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
  var convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
var xValues = backend.data.get(x.dataId).values;
var maxPosBuf = buffer(convInfo.outShape, x.dtype, maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values);
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var dx = buffer(x.shape, 'float32');
var dyData = backend.data.get(dy.dataId).values;
var dyBuf = buffer(dy.shape, 'float32', dyData);
for (var b = 0; b < convInfo.batchSize; ++b) {
for (var d = 0; d < convInfo.inChannels; ++d) {
for (var dxR = 0; dxR < convInfo.inHeight; ++dxR) {
for (var dxC = 0; dxC < convInfo.inWidth; ++dxC) {
// Shader code begins.
var dyRCorner = dxR - padTop;
var dyCCorner = dxC - padLeft;
var dotProd = 0;
for (var wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
var dyR = (dyRCorner + wR) / strideHeight;
if (dyR < 0 || dyR >= convInfo.outHeight || Math.floor(dyR) !== dyR) {
continue;
}
for (var wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
var dyC = (dyCCorner + wC) / strideWidth;
if (dyC < 0 || dyC >= convInfo.outWidth || Math.floor(dyC) !== dyC) {
continue;
}
var maxPos = effectiveFilterHeight * effectiveFilterWidth - 1 - maxPosBuf.get(b, dyR, dyC, d);
var curPos = wR * effectiveFilterWidth + wC;
var mask = maxPos === curPos ? 1 : 0;
if (mask === 0) {
continue;
}
var pixel = dyBuf.get(b, dyR, dyC, d);
dotProd += pixel * mask;
}
}
dx.set(dotProd, b, dxR, dxC, d);
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
var maxPoolGradConfig$1 = {
kernelName: MaxPoolGrad,
backendName: 'cpu',
kernelFunc: maxPoolGrad$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxPoolWithArgmaxImpl(xValues, xShape, dtype, includeBatchInIndex, convInfo) {
var strides = computeStrides(xShape);
var maxPools = pool$1(xValues, xShape, dtype, strides, convInfo, 'max');
var maxPositions = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex);
return [maxPools.values, maxPositions.values];
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var maxPoolWithArgmaxConfig = {
kernelName: MaxPoolWithArgmax,
backendName: 'cpu',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
attrs = _ref.attrs,
backend = _ref.backend;
var x = inputs.x;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
includeBatchInIndex = attrs.includeBatchInIndex;
var cpuBackend = backend;
assertNotComplex(x, 'MaxPoolWithArgmax');
var values = cpuBackend.data.get(x.dataId).values;
var convInfo = computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad);
var _maxPoolWithArgmaxImp = maxPoolWithArgmaxImpl(values, x.shape, x.dtype, includeBatchInIndex, convInfo),
pooled = _maxPoolWithArgmaxImp[0],
indexes = _maxPoolWithArgmaxImp[1];
var pooledDataId = cpuBackend.write(pooled, convInfo.outShape, x.dtype);
var indexesDataId = cpuBackend.write(indexes, convInfo.outShape, x.dtype);
return [{
dataId: pooledDataId,
shape: convInfo.outShape,
dtype: x.dtype
}, {
dataId: indexesDataId,
shape: convInfo.outShape,
dtype: 'int32'
}];
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function mean$3(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
var axes = parseAxisParam(axis, x.shape);
var shapes = computeOutAndReduceShapes(x.shape, axes);
var reduceShape = shapes[1];
var reduceSize = sizeFromShape(reduceShape);
var toDispose = [];
var reduceSizeScalar = backend.makeTensorInfo([], 'float32', new Float32Array([reduceSize]));
toDispose.push(reduceSizeScalar);
var $x = cast$2({
inputs: {
x: x
},
backend: backend,
attrs: {
dtype: 'float32'
}
});
toDispose.push($x);
var res = div$1({
inputs: {
a: $x,
b: reduceSizeScalar
},
backend: backend
});
toDispose.push(res);
var result = sum$3({
inputs: {
x: res
},
backend: backend,
attrs: {
axis: axis,
keepDims: keepDims
}
});
toDispose.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return result;
}
var meanConfig = {
kernelName: Mean,
backendName: 'cpu',
kernelFunc: mean$3
};
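/**
 * A minimal usage sketch for the Mean CPU kernel above, which computes the
 * mean as sum(x / n) over the reduced axes. It assumes the public `tf.mean`
 * op on the global `tf` namespace; the helper is illustrative only and is
 * never invoked.
 */
function exampleMeanUsage() {
  var x = tf.tensor2d([[1, 2], [3, 4]]);
  tf.mean(x).print();          // 2.5 over all elements
  tf.mean(x, 1).print();       // [1.5, 3.5], reduced along axis 1
  tf.mean(x, 1, true).print(); // [[1.5], [3.5]] with keepDims
}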
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function min$b(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
assertNotComplex(x, 'min');
var origAxes = parseAxisParam(axis, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, x.shape.length);
var $x = x;
if (permutedAxes != null) {
$x = transpose$1({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
axes = getInnerMostAxes(axes.length, x.shape.length);
}
assertAxesAreInnerMostDims('min', axes, $x.shape.length);
var _backend_util$compute = computeOutAndReduceShapes($x.shape, axes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var reduceSize = sizeFromShape(reduceShape);
var vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
var aVals = backend.data.get($x.dataId).values;
for (var i = 0; i < vals.length; ++i) {
var offset = i * reduceSize;
var _min = aVals[offset];
for (var j = 0; j < reduceSize; ++j) {
var value = aVals[offset + j];
if (Number.isNaN(value) || value < _min) {
        // a comparison with NaN always returns false, so NaN is checked explicitly
_min = value;
}
}
vals[i] = _min;
}
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo($x);
}
var result = backend.makeTensorInfo(outShape, $x.dtype, vals);
if (keepDims) {
var expandedShape = expandShapeToKeepDim(outShape, origAxes);
var reshapedResult = reshape$2({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: expandedShape
}
});
backend.disposeIntermediateTensorInfo(result);
return reshapedResult;
}
return result;
}
var minConfig = {
kernelName: Min,
backendName: 'cpu',
kernelFunc: min$b
};
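/**
 * A minimal usage sketch for the Min CPU kernel above, assuming the public
 * `tf.min` op. Non-innermost reduction axes are handled by the transpose and
 * reshape steps in the kernel. Illustrative only; the helper is never invoked.
 */
function exampleMinUsage() {
  var x = tf.tensor2d([[3, 7], [2, 5]]);
  tf.min(x).print();          // 2, the global minimum
  tf.min(x, 1).print();       // [3, 2], minimum of each row
  tf.min(x, 0, true).print(); // [[2, 5]] with keepDims
}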
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function mirrorPad$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var paddings = attrs.paddings,
mode = attrs.mode;
assertNotComplex(x, 'mirrorPad');
  var outShape = paddings.map(function (p, i) {
    return p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */;
  });
var start = paddings.map(function (p) {
return p[0];
});
var end = paddings.map(function (p, i) {
return p[0] + x.shape[i];
});
var offset = mode === 'reflect' ? 0 : 1;
var xVals = backend.data.get(x.dataId).values;
var xRank = x.shape.length;
var xStrides = computeStrides(x.shape);
var resultSize = sizeFromShape(outShape);
var resultRank = outShape.length;
var resultStrides = computeStrides(outShape);
var resVals = getTypedArrayFromDType(x.dtype, resultSize);
for (var i = 0; i < resultSize; i++) {
var coords = indexToLoc(i, resultRank, resultStrides);
for (var _i = 0; _i < resultRank; _i++) {
if (coords[_i] < start[_i]) {
coords[_i] = start[_i] * 2 - coords[_i] - offset;
} else if (coords[_i] >= end[_i]) {
coords[_i] = (end[_i] - 1) * 2 - coords[_i] + offset;
}
}
coords = coords.map(function (c, i) {
return c - start[i];
});
var inIndex = locToIndex(coords, xRank, xStrides);
resVals[i] = xVals[inIndex];
}
var outId = backend.write(resVals, outShape, x.dtype);
return {
dataId: outId,
shape: outShape,
dtype: x.dtype
};
}
var mirrorPadConfig = {
kernelName: MirrorPad,
backendName: 'cpu',
kernelFunc: mirrorPad$1
};
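/**
 * A minimal usage sketch for the MirrorPad CPU kernel above, assuming the
 * public `tf.mirrorPad` op. As in the kernel, 'reflect' (offset 0) does not
 * repeat the border value while 'symmetric' (offset 1) does. Illustrative
 * only; the helper is never invoked.
 */
function exampleMirrorPadUsage() {
  var x = tf.tensor1d([1, 2, 3]);
  tf.mirrorPad(x, [[2, 2]], 'reflect').print();   // [3, 2, 1, 2, 3, 2, 1]
  tf.mirrorPad(x, [[2, 2]], 'symmetric').print(); // [2, 1, 1, 2, 3, 3, 2]
}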
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var modImpl = createSimpleBinaryKernelImpl(function (aValue, bValue) {
var rem = aValue % bValue;
if (aValue < 0 && bValue < 0 || aValue >= 0 && bValue >= 0) {
return rem;
} else {
return (rem + bValue) % bValue;
}
});
var mod$1 = binaryKernelFunc(Mod, modImpl);
var modConfig = {
kernelName: Mod,
backendName: 'cpu',
kernelFunc: mod$1
};
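/**
 * A minimal usage sketch for the Mod CPU kernel above, assuming the public
 * `tf.mod` op. The implementation adjusts the raw remainder when the operands
 * have different signs, so the result follows the sign of the divisor
 * (floored modulo). Illustrative only; the helper is never invoked.
 */
function exampleModUsage() {
  var a = tf.tensor1d([5, -5, 5, -5]);
  var b = tf.tensor1d([3, 3, -3, -3]);
  tf.mod(a, b).print(); // [2, 1, -1, -2]
}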
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function softmax$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var logits = inputs.logits;
var dim = attrs.dim;
var logitsRank = logits.shape.length;
var $dim = dim;
if ($dim === -1) {
$dim = logitsRank - 1;
}
if ($dim !== logitsRank - 1) {
throw Error('Softmax along a non-last dimension is not yet supported. ' + ("Logits was rank " + logitsRank + " and dim was " + $dim));
}
var axes = parseAxisParam([$dim], logits.shape);
var maxLogit = max$7({
inputs: {
x: logits
},
backend: backend,
attrs: {
reductionIndices: axes,
keepDims: false
}
});
var expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
var maxLogitReshaped = reshape$2({
inputs: {
x: maxLogit
},
backend: backend,
attrs: {
shape: expandedShape
}
});
var a = sub$1({
inputs: {
a: logits,
b: maxLogitReshaped
},
backend: backend
});
var b = exp$4({
inputs: {
x: a
},
backend: backend
});
var sumExp = sum$3({
inputs: {
x: b
},
backend: backend,
attrs: {
axis: axes,
keepDims: false
}
});
var sumReshaped = reshape$2({
inputs: {
x: sumExp
},
backend: backend,
attrs: {
shape: expandedShape
}
});
var result = div$1({
inputs: {
a: b,
b: sumReshaped
},
backend: backend
});
backend.disposeIntermediateTensorInfo(maxLogit);
backend.disposeIntermediateTensorInfo(maxLogitReshaped);
backend.disposeIntermediateTensorInfo(a);
backend.disposeIntermediateTensorInfo(b);
backend.disposeIntermediateTensorInfo(sumExp);
backend.disposeIntermediateTensorInfo(sumReshaped);
return result;
}
var softmaxConfig = {
kernelName: Softmax,
backendName: 'cpu',
kernelFunc: softmax$2
};
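/**
 * A minimal usage sketch for the Softmax CPU kernel above, assuming the public
 * `tf.softmax` op. As in the kernel, the max logit is subtracted before
 * exponentiation for numerical stability, and only the last dimension is
 * supported. Illustrative only; the helper is never invoked.
 */
function exampleSoftmaxUsage() {
  var logits = tf.tensor1d([1, 2, 3]);
  tf.softmax(logits).print(); // approximately [0.0900, 0.2447, 0.6652]
}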
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function multinomial$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var logits = inputs.logits;
var numSamples = attrs.numSamples,
seed = attrs.seed,
normalized = attrs.normalized;
assertNotComplex(logits, 'multinomial');
var probabilities = normalized ? logits : softmax$2({
inputs: {
logits: logits
},
backend: backend,
attrs: {
dim: -1
}
});
var batchSize = probabilities.shape[0];
var numEvents = probabilities.shape[1];
var probVals = backend.data.get(probabilities.dataId).values;
var resShape = [batchSize, numSamples];
var resVals = makeZerosTypedArray(sizeFromShape(resShape), 'int32');
for (var b = 0; b < batchSize; ++b) {
var offset = b * numEvents; // The cdf won't include the last event. It will be implicit if no other
// event happened.
var cdf = new Float32Array(numEvents - 1);
cdf[0] = probVals[offset];
for (var event = 1; event < cdf.length; ++event) {
cdf[event] = cdf[event - 1] + probVals[offset + event];
}
var random = seedrandom_1(seed.toString());
var outOffset = b * numSamples;
for (var sampleId = 0; sampleId < numSamples; ++sampleId) {
var r = random(); // Assume last event happened by default.
resVals[outOffset + sampleId] = cdf.length;
for (var _event = 0; _event < cdf.length; _event++) {
if (r < cdf[_event]) {
resVals[outOffset + sampleId] = _event;
break;
}
}
}
}
if (!normalized) {
backend.disposeIntermediateTensorInfo(probabilities);
}
return backend.makeTensorInfo(resShape, 'int32', resVals);
}
var multinomialConfig = {
kernelName: Multinomial,
backendName: 'cpu',
kernelFunc: multinomial$1
};
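/**
 * A minimal usage sketch for the Multinomial CPU kernel above, assuming the
 * public `tf.multinomial` op. With unnormalized logits the kernel first
 * applies softmax, then samples each draw from the per-row CDF using the
 * seeded RNG. Illustrative only; the helper is never invoked.
 */
function exampleMultinomialUsage() {
  var logits = tf.tensor1d([0, 0, 2]); // class 2 is the most likely
  var samples = tf.multinomial(logits, 5, 42 /* seed */);
  samples.print(); // 5 int32 class indices, mostly 2s
}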
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var nonMaxSuppressionV3Impl$1 = nonMaxSuppressionV3Impl;
function nonMaxSuppressionV3(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var boxes = inputs.boxes,
scores = inputs.scores;
var maxOutputSize = attrs.maxOutputSize,
iouThreshold = attrs.iouThreshold,
scoreThreshold = attrs.scoreThreshold;
assertNotComplex(boxes, 'NonMaxSuppression');
var boxesVals = backend.data.get(boxes.dataId).values;
var scoresVals = backend.data.get(scores.dataId).values;
var _nonMaxSuppressionV3I = nonMaxSuppressionV3Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold),
selectedIndices = _nonMaxSuppressionV3I.selectedIndices;
return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
}
var nonMaxSuppressionV3Config = {
kernelName: NonMaxSuppressionV3,
backendName: 'cpu',
kernelFunc: nonMaxSuppressionV3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var nonMaxSuppressionV4Impl$1 = nonMaxSuppressionV4Impl;
function nonMaxSuppressionV4(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var boxes = inputs.boxes,
scores = inputs.scores;
var maxOutputSize = attrs.maxOutputSize,
iouThreshold = attrs.iouThreshold,
scoreThreshold = attrs.scoreThreshold,
padToMaxOutputSize = attrs.padToMaxOutputSize;
assertNotComplex(boxes, 'NonMaxSuppressionPadded');
var boxesVals = backend.data.get(boxes.dataId).values;
var scoresVals = backend.data.get(scores.dataId).values;
var _nonMaxSuppressionV4I = nonMaxSuppressionV4Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize),
selectedIndices = _nonMaxSuppressionV4I.selectedIndices,
validOutputs = _nonMaxSuppressionV4I.validOutputs;
return [backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))];
}
var nonMaxSuppressionV4Config = {
kernelName: NonMaxSuppressionV4,
backendName: 'cpu',
kernelFunc: nonMaxSuppressionV4
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var nonMaxSuppressionV5Impl$1 = nonMaxSuppressionV5Impl;
function nonMaxSuppressionV5(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var boxes = inputs.boxes,
scores = inputs.scores;
var maxOutputSize = attrs.maxOutputSize,
iouThreshold = attrs.iouThreshold,
scoreThreshold = attrs.scoreThreshold,
softNmsSigma = attrs.softNmsSigma;
assertNotComplex(boxes, 'NonMaxSuppressionWithScore');
var boxesVals = backend.data.get(boxes.dataId).values;
var scoresVals = backend.data.get(scores.dataId).values;
var maxOutputSizeVal = maxOutputSize;
var iouThresholdVal = iouThreshold;
var scoreThresholdVal = scoreThreshold;
var softNmsSigmaVal = softNmsSigma;
var _nonMaxSuppressionV5I = nonMaxSuppressionV5Impl$1(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal),
selectedIndices = _nonMaxSuppressionV5I.selectedIndices,
selectedScores = _nonMaxSuppressionV5I.selectedScores;
return [backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores))];
}
var nonMaxSuppressionV5Config = {
kernelName: NonMaxSuppressionV5,
backendName: 'cpu',
kernelFunc: nonMaxSuppressionV5
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function oneHot$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var indices = inputs.indices;
var depth = attrs.depth,
onValue = attrs.onValue,
offValue = attrs.offValue;
assertNotComplex(indices, 'oneHot');
var indicesSize = sizeFromShape(indices.shape);
var res = new Float32Array(indicesSize * depth);
res.fill(offValue);
var indicesVal = backend.data.get(indices.dataId).values;
for (var event = 0; event < indicesSize; ++event) {
if (indicesVal[event] >= 0 && indicesVal[event] < depth) {
res[event * depth + indicesVal[event]] = onValue;
}
}
return backend.makeTensorInfo([].concat(indices.shape, [depth]), 'int32', res);
}
var oneHotConfig = {
kernelName: OneHot,
backendName: 'cpu',
kernelFunc: oneHot$2
};
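/**
 * A minimal usage sketch for the OneHot CPU kernel above, assuming the public
 * `tf.oneHot` op. Indices outside [0, depth) yield an all-offValue row, per
 * the bounds check in the kernel. Illustrative only; the helper is never
 * invoked.
 */
function exampleOneHotUsage() {
  var indices = tf.tensor1d([0, 2, 1], 'int32');
  tf.oneHot(indices, 3).print(); // [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
}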
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function zerosLike$2(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
if (x.dtype === 'string') {
throw new Error('zerosLike is not supported for string tensors');
} else if (x.dtype === 'complex64') {
var realPart = real$1({
inputs: {
input: x
},
backend: backend
});
var r = zerosLike$2({
inputs: {
x: realPart
},
backend: backend
});
var imagPart = imag$1({
inputs: {
input: x
},
backend: backend
});
var i = zerosLike$2({
inputs: {
x: imagPart
},
backend: backend
});
var result = complex$1({
inputs: {
real: r,
imag: i
},
backend: backend
});
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(r);
backend.disposeIntermediateTensorInfo(imagPart);
backend.disposeIntermediateTensorInfo(i);
return result;
} else {
return fill$1({
backend: backend,
attrs: {
shape: x.shape,
value: 0,
dtype: x.dtype
}
});
}
}
var zerosLikeConfig = {
kernelName: ZerosLike,
backendName: 'cpu',
kernelFunc: zerosLike$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function onesLike$2(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
if (x.dtype === 'string') {
throw new Error('onesLike is not supported for string tensors');
} else if (x.dtype === 'complex64') {
var realPart = real$1({
inputs: {
input: x
},
backend: backend
});
var r = onesLike$2({
inputs: {
x: realPart
},
backend: backend
});
var imagPart = imag$1({
inputs: {
input: x
},
backend: backend
});
var i = zerosLike$2({
inputs: {
x: imagPart
},
backend: backend
});
var result = complex$1({
inputs: {
real: r,
imag: i
},
backend: backend
});
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(r);
backend.disposeIntermediateTensorInfo(imagPart);
backend.disposeIntermediateTensorInfo(i);
return result;
} else {
return fill$1({
backend: backend,
attrs: {
shape: x.shape,
value: 1,
dtype: x.dtype
}
});
}
}
var onesLikeConfig = {
kernelName: OnesLike,
backendName: 'cpu',
kernelFunc: onesLike$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function pack$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var axis = attrs.axis;
if (inputs.length === 1) {
return expandDims$2({
inputs: {
input: inputs[0]
},
backend: backend,
attrs: {
dim: axis
}
});
}
var shape = inputs[0].shape;
var dtype = inputs[0].dtype;
inputs.forEach(function (t) {
assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
assert(dtype === t.dtype, function () {
return 'All tensors passed to stack must have matching dtypes';
});
});
var intermediateTensorInfos = [];
var expandedTensors = inputs.map(function (t) {
var expandedT = expandDims$2({
inputs: {
input: t
},
backend: backend,
attrs: {
dim: axis
}
});
intermediateTensorInfos.push(expandedT);
return expandedT;
});
var result = concat$1({
inputs: expandedTensors,
backend: backend,
attrs: {
axis: axis
}
});
intermediateTensorInfos.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return result;
}
var packConfig = {
kernelName: Pack,
backendName: 'cpu',
kernelFunc: pack$1
};
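/**
 * A minimal usage sketch for the Pack CPU kernel above, assuming the public
 * `tf.stack` op. The kernel expands each input along `axis` and concatenates
 * the results, so all inputs must share the same shape and dtype.
 * Illustrative only; the helper is never invoked.
 */
function exampleStackUsage() {
  var a = tf.tensor1d([1, 2]);
  var b = tf.tensor1d([3, 4]);
  tf.stack([a, b]).print();    // [[1, 2], [3, 4]]
  tf.stack([a, b], 1).print(); // [[1, 3], [2, 4]]
}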
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function padV2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var paddings = attrs.paddings,
constantValue = attrs.constantValue;
assertNotComplex(x, 'pad');
  var outShape = paddings.map(function (p, i) {
    return p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */;
  });
var start = paddings.map(function (p) {
return p[0];
});
var xVals = backend.data.get(x.dataId).values;
var xSize = sizeFromShape(x.shape);
var xRank = x.shape.length;
var xStrides = computeStrides(x.shape);
var resultSize = sizeFromShape(outShape);
var resultRank = outShape.length;
var resultStrides = computeStrides(outShape);
var resVals = getTypedArrayFromDType(x.dtype, resultSize);
if (constantValue !== 0) {
resVals.fill(constantValue);
}
for (var i = 0; i < xSize; i++) {
var coords = indexToLoc(i, xRank, xStrides);
var outCoords = coords.map(function (c, i) {
return c + start[i];
});
var outIndex = locToIndex(outCoords, resultRank, resultStrides);
resVals[outIndex] = xVals[i];
}
var outId = backend.write(resVals, outShape, x.dtype);
return {
dataId: outId,
shape: outShape,
dtype: x.dtype
};
}
var padV2Config = {
kernelName: PadV2,
backendName: 'cpu',
kernelFunc: padV2
};
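/**
 * A minimal usage sketch for the PadV2 CPU kernel above, assuming the public
 * `tf.pad` op. The kernel pre-fills the output with `constantValue` and then
 * copies the input into the offset region. Illustrative only; the helper is
 * never invoked.
 */
function examplePadUsage() {
  var x = tf.tensor2d([[1, 2], [3, 4]]);
  // One row of padding above and below, one column on the right, filled with 9.
  tf.pad(x, [[1, 1], [0, 1]], 9).print();
  // [[9, 9, 9], [1, 2, 9], [3, 4, 9], [9, 9, 9]]
}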
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var powImpl = createSimpleBinaryKernelImpl(function (a, b) {
return Math.pow(a, b);
});
var pow$7 = binaryKernelFunc(Pow, powImpl);
var powConfig = {
kernelName: Pow,
backendName: 'cpu',
kernelFunc: pow$7
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function range$2(args) {
var backend = args.backend,
attrs = args.attrs;
var start = attrs.start,
stop = attrs.stop,
dtype = attrs.dtype,
step = attrs.step;
var values = rangeImpl(start, stop, step, dtype);
return backend.makeTensorInfo([values.length], dtype, values);
}
var rangeConfig = {
kernelName: Range,
backendName: 'cpu',
kernelFunc: range$2
};
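/**
 * A minimal usage sketch for the Range CPU kernel above, assuming the public
 * `tf.range` op (the stop value is exclusive). Illustrative only; the helper
 * is never invoked.
 */
function exampleRangeUsage() {
  tf.range(0, 5).print();              // [0, 1, 2, 3, 4]
  tf.range(0, 10, 2).print();          // [0, 2, 4, 6, 8]
  tf.range(3, 0, -1, 'int32').print(); // [3, 2, 1]
}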
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var reciprocal$1 = unaryKernelFunc(Reciprocal, function (xi) {
return 1 / xi;
});
var reciprocalConfig = {
kernelName: Reciprocal,
backendName: 'cpu',
kernelFunc: reciprocal$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function resizeBilinear$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var images = inputs.images;
var alignCorners = attrs.alignCorners,
halfPixelCenters = attrs.halfPixelCenters,
size = attrs.size;
assertNotComplex(images, 'resizeBilinear');
var imagesStrides = computeStrides(images.shape);
var newHeight = size[0],
newWidth = size[1];
var _images$shape = images.shape,
batch = _images$shape[0],
oldHeight = _images$shape[1],
oldWidth = _images$shape[2],
numChannels = _images$shape[3];
var xValues = backend.data.get(images.dataId).values;
var result = new Float32Array(sizeFromShape([batch, newHeight, newWidth, numChannels]));
var effectiveInputSize = [alignCorners && newHeight > 1 ? oldHeight - 1 : oldHeight, alignCorners && newWidth > 1 ? oldWidth - 1 : oldWidth];
var effectiveOutputSize = [alignCorners && newHeight > 1 ? newHeight - 1 : newHeight, alignCorners && newWidth > 1 ? newWidth - 1 : newWidth];
var outputIdx = 0;
var effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
var effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
for (var b = 0; b < batch; b++) {
for (var r = 0; r < newHeight; r++) {
var sourceFracRow = void 0;
if (halfPixelCenters) {
sourceFracRow = effectiveRowSizeRatio * (r + 0.5) - 0.5;
} else {
sourceFracRow = effectiveRowSizeRatio * r;
}
var sourceRowFloor = Math.max(0, Math.floor(sourceFracRow));
var rowFrac = sourceFracRow - sourceRowFloor;
var sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow));
var topRowOffset = b * imagesStrides[0] + sourceRowFloor * imagesStrides[1];
var botRowOffset = b * imagesStrides[0] + sourceRowCeil * imagesStrides[1];
for (var c = 0; c < newWidth; c++) {
var sourceFracCol = void 0;
if (halfPixelCenters) {
sourceFracCol = effectiveColSizeRatio * (c + 0.5) - 0.5;
} else {
sourceFracCol = effectiveColSizeRatio * c;
}
var sourceColFloor = Math.max(0, Math.floor(sourceFracCol));
var colFrac = sourceFracCol - sourceColFloor;
var sourceColCeil = Math.min(oldWidth - 1, Math.ceil(sourceFracCol));
        var topLeftOffset = topRowOffset + sourceColFloor * imagesStrides[2];
        var botLeftOffset = botRowOffset + sourceColFloor * imagesStrides[2];
        var topRightOffset = topRowOffset + sourceColCeil * imagesStrides[2];
        var botRightOffset = botRowOffset + sourceColCeil * imagesStrides[2];
        for (var d = 0; d < numChannels; d++) {
          // Begin shader.
          // Compute the fractional index of the source.
          var topLeft = xValues[topLeftOffset + d];
          var bottomLeft = xValues[botLeftOffset + d];
          var topRight = xValues[topRightOffset + d];
          var bottomRight = xValues[botRightOffset + d];
var top = topLeft + (topRight - topLeft) * colFrac;
var bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac;
var newValue = top + (bottom - top) * rowFrac;
result[outputIdx++] = newValue;
}
}
}
}
return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], 'float32', result);
}
var resizeBilinearConfig = {
kernelName: ResizeBilinear,
backendName: 'cpu',
kernelFunc: resizeBilinear$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function resizeBilinearGrad(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var images = inputs.images,
dy = inputs.dy;
var alignCorners = attrs.alignCorners;
assertNotComplex([dy, images], 'resizeBilinearGrad');
var imagesStrides = computeStrides(images.shape);
var _images$shape = images.shape,
batch = _images$shape[0],
xHeight = _images$shape[1],
xWidth = _images$shape[2],
depth = _images$shape[3];
var _dy$shape = dy.shape,
yHeight = _dy$shape[1],
yWidth = _dy$shape[2];
  var output = new Float32Array(batch * xHeight * xWidth * depth);
  // In the backwards pass, we want to find the pixels that were generated for
  // each pixel in the input image in the forward pass and add the
  // corresponding coefficient from dy to the gradient (with some
  // interpolation).
var effectiveXSize = [alignCorners && yHeight > 1 ? xHeight - 1 : xHeight, alignCorners && yWidth > 1 ? xWidth - 1 : xWidth];
var effectiveYSize = [alignCorners && yHeight > 1 ? yHeight - 1 : yHeight, alignCorners && yWidth > 1 ? yWidth - 1 : yWidth];
var heightScale = effectiveXSize[0] / effectiveYSize[0];
var widthScale = effectiveXSize[1] / effectiveYSize[1]; // Reference implementation
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/3039375c86a5bbc9610c7725dcaa95d635f87ba2/tensorflow/core/kernels/resize_bilinear_op.cc#L275
var dyValues = backend.data.get(dy.dataId).values;
var offset = 0;
for (var b = 0; b < batch; b++) {
var bOffset = b * imagesStrides[0];
for (var r = 0; r < yHeight; r++) {
var dxR = r * heightScale;
var topDxRIndex = Math.floor(dxR);
var bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);
var topDxROffset = bOffset + topDxRIndex * imagesStrides[1];
var bottomDxROffset = bOffset + bottomDxRIndex * imagesStrides[1];
var dxRLerp = dxR - topDxRIndex;
var inverseDxRLerp = 1.0 - dxRLerp;
for (var c = 0; c < yWidth; c++) {
var dxC = c * widthScale;
var leftDxCIndex = Math.floor(dxC);
var rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);
var dxCLerp = dxC - leftDxCIndex;
var inverseDxCLerp = 1.0 - dxCLerp;
var topLeftRCOffset = topDxROffset + leftDxCIndex * imagesStrides[2];
var topRightRCOffset = topDxROffset + rightDxCIndex * imagesStrides[2];
var bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * imagesStrides[2];
var bottomRightRCOffset = bottomDxROffset + rightDxCIndex * imagesStrides[2];
var inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp;
var inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;
var dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;
var dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;
for (var d = 0; d < depth; d++) {
var dyVal = dyValues[offset++];
output[topLeftRCOffset + d] += dyVal * inverseDxRLerpTimesInverseDxCLerp;
output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;
output[bottomLeftRCOffset + d] += dyVal * dxRLerpTimesInverseDxCLerp;
output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;
}
}
}
}
return backend.makeTensorInfo([batch, xWidth, xHeight, depth], 'float32', output);
}
var resizeBilinearGradConfig$1 = {
kernelName: ResizeBilinearGrad,
backendName: 'cpu',
kernelFunc: resizeBilinearGrad
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function resizeNearestNeighbor$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var images = inputs.images;
var alignCorners = attrs.alignCorners,
halfPixelCenters = attrs.halfPixelCenters,
size = attrs.size;
assertNotComplex(images, 'resizeNearestNeighbor');
var imagesStrides = computeStrides(images.shape);
var newHeight = size[0],
newWidth = size[1];
var _images$shape = images.shape,
batch = _images$shape[0],
oldHeight = _images$shape[1],
oldWidth = _images$shape[2],
numChannels = _images$shape[3];
var xValues = backend.data.get(images.dataId).values;
var output = new Float32Array(batch * newHeight * newWidth * numChannels);
var effectiveInputSize = [alignCorners && newHeight > 1 ? oldHeight - 1 : oldHeight, alignCorners && newWidth > 1 ? oldWidth - 1 : oldWidth];
var effectiveOutputSize = [alignCorners && newHeight > 1 ? newHeight - 1 : newHeight, alignCorners && newWidth > 1 ? newWidth - 1 : newWidth];
var effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
var effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
var outputOffset = 0;
for (var b = 0; b < batch; b++) {
var batchOffset = b * imagesStrides[0];
for (var r = 0; r < newHeight; r++) {
var sourceFracRow = halfPixelCenters ? effectiveRowSizeRatio * (r + 0.5) : effectiveRowSizeRatio * r;
var sourceNearestRow = Math.min(oldHeight - 1, alignCorners ? Math.round(sourceFracRow) : Math.floor(sourceFracRow));
if (halfPixelCenters) {
sourceNearestRow = Math.max(0, sourceNearestRow);
}
var rowOffset = batchOffset + sourceNearestRow * imagesStrides[1];
for (var c = 0; c < newWidth; c++) {
var sourceFracCol = halfPixelCenters ? effectiveColSizeRatio * (c + 0.5) : effectiveColSizeRatio * c;
var sourceNearestCol = Math.min(oldWidth - 1, alignCorners ? Math.round(sourceFracCol) : Math.floor(sourceFracCol));
if (halfPixelCenters) {
sourceNearestCol = Math.max(0, sourceNearestCol);
}
var colOffset = rowOffset + sourceNearestCol * imagesStrides[2];
for (var d = 0; d < numChannels; d++) {
// Begin shader.
// Compute the fractional index of the source.
var newVal = xValues[colOffset + d];
output[outputOffset++] = newVal;
}
}
}
}
return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], images.dtype, output);
}
var resizeNearestNeighborConfig = {
kernelName: ResizeNearestNeighbor,
backendName: 'cpu',
kernelFunc: resizeNearestNeighbor$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function resizeNearestNeighborGrad(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var images = inputs.images,
dy = inputs.dy;
var alignCorners = attrs.alignCorners;
assertNotComplex([dy, images], 'resizeNearestNeighborGrad');
var imagesStrides = computeStrides(images.shape);
var dyStrides = computeStrides(dy.shape);
var _images$shape = images.shape,
batch = _images$shape[0],
xHeight = _images$shape[1],
xWidth = _images$shape[2],
depth = _images$shape[3];
var _dy$shape = dy.shape,
yHeight = _dy$shape[1],
yWidth = _dy$shape[2];
var output = new Float32Array(batch * xHeight * xWidth * depth);
  var dyValues = backend.data.get(dy.dataId).values;
  // In the backwards pass, we want to find the pixels that were generated for
  // each pixel in the input image in the forward pass.
var effectiveXSize = [alignCorners && yHeight > 1 ? xHeight - 1 : xHeight, alignCorners && yWidth > 1 ? xWidth - 1 : xWidth];
var effectiveYSize = [alignCorners && yHeight > 1 ? yHeight - 1 : yHeight, alignCorners && yWidth > 1 ? yWidth - 1 : yWidth];
var heightScale = effectiveXSize[0] / effectiveYSize[0];
var widthScale = effectiveXSize[1] / effectiveYSize[1];
var invHeightScale = 1 / heightScale;
var invWidthScale = 1 / widthScale; // This defines the size of the window of values around a particular
// index in dy that we want to search for contributions to dx.
var winHeight = Math.ceil(invHeightScale) * 2 + 2;
var winWidth = Math.ceil(invWidthScale) * 2 + 2; // Loop over the output space.
for (var b = 0; b < batch; b++) {
var batchOffset = b * imagesStrides[0];
for (var r = 0; r < xHeight; r++) {
var rowOffset = batchOffset + r * imagesStrides[1]; // Compute bounds for where in dy we will look
var startRLerp = Math.floor(r * invHeightScale);
var startDyR = Math.floor(startRLerp - winHeight / 2);
for (var c = 0; c < xWidth; c++) {
var colOffset = rowOffset + c * imagesStrides[2]; // Compute bounds for where in dy we will look
var startCLerp = Math.floor(c * invWidthScale);
var startDyC = Math.floor(startCLerp - winWidth / 2);
for (var d = 0; d < depth; d++) {
var accum = 0; // loop over dy
for (var dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
var dyR = dyRIndex + startDyR; // Guard against the window exceeding the bounds of dy
if (dyR < 0 || dyR >= yHeight) {
continue;
}
var dyROffset = batchOffset + dyR * dyStrides[1];
var sourceFracRow = dyR * heightScale;
var sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) : Math.floor(sourceFracRow));
if (r !== sourceNearestRow) {
continue;
}
for (var dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
var dyC = dyCIndex + startDyC; // Guard against the window exceeding the bounds of dy
if (dyC < 0 || dyC >= yWidth) {
continue;
}
var dyCOffset = dyROffset + dyC * dyStrides[2];
var sourceFracCol = dyC * widthScale;
var sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) : Math.floor(sourceFracCol));
if (c === sourceNearestCol) {
accum += dyValues[dyCOffset + d];
}
}
}
output[colOffset + d] = accum;
}
}
}
}
return backend.makeTensorInfo(images.shape, images.dtype, output);
}
var resizeNearestNeighborGradConfig$1 = {
kernelName: ResizeNearestNeighborGrad,
backendName: 'cpu',
kernelFunc: resizeNearestNeighborGrad
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function reverse$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var dims = attrs.dims;
assertNotComplex(x, 'reverse');
var xRank = x.shape.length;
var $dims = parseAxisParam(dims, x.shape);
if (xRank === 0) {
return identity$1({
inputs: {
x: x
},
backend: backend
});
}
var outBuf = new TensorBuffer(x.shape, x.dtype);
var xBuf = backend.bufferSync(x);
var _loop = function _loop(i) {
var outLoc = outBuf.indexToLoc(i);
var inLoc = outLoc.slice();
$dims.forEach(function (d) {
return inLoc[d] = x.shape[d] - 1 - inLoc[d];
});
outBuf.set.apply(outBuf, [xBuf.get.apply(xBuf, inLoc)].concat(outLoc));
};
for (var i = 0; i < outBuf.size; i++) {
_loop(i);
}
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
var reverseConfig = {
kernelName: Reverse,
backendName: 'cpu',
kernelFunc: reverse$1
};
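/**
 * A minimal usage sketch for the Reverse CPU kernel above, assuming the public
 * `tf.reverse` op. Each listed axis has its coordinates mirrored, as in the
 * index remapping in the kernel. Illustrative only; the helper is never
 * invoked.
 */
function exampleReverseUsage() {
  var x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
  tf.reverse(x, 1).print();      // [[3, 2, 1], [6, 5, 4]]
  tf.reverse(x, [0, 1]).print(); // [[6, 5, 4], [3, 2, 1]]
}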
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var rotateWithOffsetConfig = {
kernelName: RotateWithOffset,
backendName: 'cpu',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
attrs = _ref.attrs,
backend = _ref.backend;
var image = inputs.image;
var radians = attrs.radians,
fillValue = attrs.fillValue,
center = attrs.center;
var cpuBackend = backend;
var output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
var _image$shape = image.shape,
batch = _image$shape[0],
imageHeight = _image$shape[1],
imageWidth = _image$shape[2],
numChannels = _image$shape[3];
var _backend_util$getImag = getImageCenter(center, imageHeight, imageWidth),
centerX = _backend_util$getImag[0],
centerY = _backend_util$getImag[1];
var fullOpacityValue = 255;
var sinFactor = Math.sin(radians);
var cosFactor = Math.cos(radians);
var imageVals = cpuBackend.data.get(image.dataId).values;
for (var batchIdx = 0; batchIdx < batch; batchIdx++) {
var batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
for (var row = 0; row < imageHeight; row++) {
var rowOffset = row * (imageWidth * numChannels);
for (var col = 0; col < imageWidth; col++) {
var colOffset = col * numChannels;
for (var channel = 0; channel < numChannels; channel++) {
var coords = [batch, row, col, channel];
var x = coords[2];
var y = coords[1]; // coordX/coordY are the result of rotating and translating x/y.
var coordX = (x - centerX) * cosFactor - (y - centerY) * sinFactor;
var coordY = (x - centerX) * sinFactor + (y - centerY) * cosFactor;
coordX = Math.round(coordX + centerX);
coordY = Math.round(coordY + centerY);
var outputValue = fillValue;
if (typeof fillValue !== 'number') {
if (channel === 3) {
outputValue = fullOpacityValue;
} else {
outputValue = fillValue[channel];
}
} // If the coordinate position falls within the image boundaries...
if (coordX >= 0 && coordX < imageWidth && coordY >= 0 && coordY < imageHeight) {
// set the output to the image value at the coordinate position.
var rotatedRowOffset = coordY * (imageWidth * numChannels);
var rotatedColOffset = coordX * numChannels;
var imageIdx = batchOffset + rotatedRowOffset + rotatedColOffset + channel;
outputValue = imageVals[imageIdx];
}
var outIdx = batchOffset + rowOffset + colOffset + channel;
output[outIdx] = outputValue;
}
}
}
}
var dataId = cpuBackend.write(output, image.shape, image.dtype);
return {
dataId: dataId,
shape: image.shape,
dtype: image.dtype
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var round$2 = unaryKernelFunc(Round, function (xi) {
// The algorithm is based on banker's rounding.
var base = Math.floor(xi);
if (xi - base < 0.5) {
return Math.floor(xi);
} else if (xi - base > 0.5) {
return Math.ceil(xi);
} else {
if (base % 2.0 === 0.0) {
return base;
} else {
return base + 1.0;
}
}
});
var roundConfig = {
kernelName: Round,
backendName: 'cpu',
kernelFunc: round$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
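// Shared scatter implementation used by the ScatterNd and SparseToDense
// kernels below. The output is treated as an
// [outputSize / sliceSize, sliceSize] buffer pre-filled with defaultValue;
// each of the numUpdates indices is flattened through `strides`, validated,
// and its update slice is either accumulated into (sumDupeIndices) or written
// over the target row.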
function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
var flattenShape = [outputSize / sliceSize, sliceSize];
var indicesData = indices.values;
var updatesData = updates.values;
if (outputSize === 0) {
return buffer(shape, updates.dtype);
}
var outBuf = buffer(flattenShape, updates.dtype);
outBuf.values.fill(defaultValue);
for (var i = 0; i < numUpdates; i++) {
var index = [];
var flattenIndex = 0;
for (var j = 0; j < sliceRank; j++) {
var dim = indicesData[i * sliceRank + j];
index.push(dim);
flattenIndex += dim * strides[j];
}
if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) {
throw new Error("Invalid indices: " + index + " does not index into " + shape);
}
for (var k = 0; k < sliceSize; k++) {
if (sumDupeIndices) {
outBuf.values[flattenIndex * sliceSize + k] += updatesData[i * sliceSize + k];
} else {
outBuf.values[flattenIndex * sliceSize + k] = updates.rank === 0 ? updatesData[0] : updatesData[i * sliceSize + k];
}
}
}
return outBuf;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function scatterNd(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var indices = inputs.indices,
updates = inputs.updates;
var shape = attrs.shape;
var _backend_util$calcula = calculateShapes(updates, indices, shape),
sliceRank = _backend_util$calcula.sliceRank,
numUpdates = _backend_util$calcula.numUpdates,
sliceSize = _backend_util$calcula.sliceSize,
strides = _backend_util$calcula.strides,
outputSize = _backend_util$calcula.outputSize;
var sumDupeIndices = true;
var indicesBuf = backend.bufferSync(indices);
var updatesBuf = backend.bufferSync(updates);
var outBuf = scatterImpl(indicesBuf, updatesBuf, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, 0
/* defaultValue */
, sumDupeIndices);
return backend.makeTensorInfo(shape, outBuf.dtype, outBuf.values);
}
var scatterNdConfig = {
kernelName: ScatterNd,
backendName: 'cpu',
kernelFunc: scatterNd
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
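// CPU kernel for Select: picks elements from t where the condition value is 1
// and from e otherwise, upcasting the dtypes of t and e for the result.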
function select$1(args) {
var inputs = args.inputs,
backend = args.backend;
var condition = inputs.condition,
t = inputs.t,
e = inputs.e;
assertNotComplex([condition, t, e], 'select');
var conditionRank = condition.shape.length;
var values = backend.data.get(condition.dataId).values;
var tValues = backend.data.get(t.dataId).values;
var eValues = backend.data.get(e.dataId).values;
var resultDtype = upcastType(t.dtype, e.dtype);
var newValues = makeZerosTypedArray(sizeFromShape(t.shape), resultDtype);
var index = 0;
var offset = conditionRank === 0 || conditionRank > 1 || t.shape.length === 1 ? 1 : sizeFromShape(t.shape.slice(1));
for (var i = 0; i < values.length; i++) {
for (var j = 0; j < offset; j++) {
if (values[i] === 1) {
newValues[index++] = tValues[i];
} else {
newValues[index++] = eValues[i];
}
}
}
return backend.makeTensorInfo(t.shape, resultDtype, newValues);
}
var selectConfig = {
kernelName: Select,
backendName: 'cpu',
kernelFunc: select$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var scaleAlpha = SELU_SCALEALPHA;
var scale = SELU_SCALE;
var selu$1 = unaryKernelFunc(Selu, function (xi) {
if (xi >= 0) {
return scale * xi;
} else {
return scaleAlpha * (Math.exp(xi) - 1);
}
});
var seluConfig = {
kernelName: Selu,
backendName: 'cpu',
kernelFunc: selu$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var sign$2 = unaryKernelFunc(Sign, function (xi) {
if (xi < 0) {
return -1;
} else if (xi > 0) {
return 1;
} else {
return 0;
}
});
var signConfig = {
kernelName: Sign,
backendName: 'cpu',
kernelFunc: sign$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var sin$1 = unaryKernelFunc(Sin, function (xi) {
return Math.sin(xi);
});
var sinConfig = {
kernelName: Sin,
backendName: 'cpu',
kernelFunc: sin$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var sinh$1 = unaryKernelFunc(Sinh, function (xi) {
return Math.sinh(xi);
});
var sinhConfig = {
kernelName: Sinh,
backendName: 'cpu',
kernelFunc: sinh$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// epsilon is the difference between 1.0 and the next representable float.
// For a single precision 32 bit float this should be 2^-23, see:
// https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm
var epsilon$1 = 1.1920928955078125e-7;
var threshold$1 = Math.log(epsilon$1) + 2.0;
var softplus$1 = unaryKernelFunc(Softplus, function (xi) {
// Value above which exp(x) may overflow, but softplus(x) == x
// is within machine epsilon.
var tooLarge = xi > -threshold$1; // Value below which exp(x) may underflow, but softplus(x) == exp(x)
// is within machine epsilon.
var tooSmall = xi < threshold$1;
var expX = Math.exp(xi);
var result;
if (tooSmall) {
result = expX;
} else if (tooLarge) {
result = xi;
} else {
result = Math.log(1.0 + expX);
}
return result;
});
var softplusConfig = {
kernelName: Softplus,
backendName: 'cpu',
kernelFunc: softplus$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
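// CPU kernel for SpaceToBatchND, composed from existing kernels as
// pad -> reshape -> transpose -> reshape; the intermediate tensors are
// disposed before the final result is returned.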
function spaceToBatchND$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var blockShape = attrs.blockShape,
paddings = attrs.paddings;
assertNotComplex([x], 'spaceToBatchND');
var prod = sizeFromShape(blockShape);
var completePaddings = [[0, 0]];
completePaddings.push.apply(completePaddings, paddings);
for (var i = 1 + blockShape.length; i < x.shape.length; ++i) {
completePaddings.push([0, 0]);
}
var paddedX = padV2Config.kernelFunc({
inputs: {
x: x
},
backend: backend,
attrs: {
paddings: completePaddings,
constantValue: 0
}
});
var reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
var permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
var flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
var reshapeInputs = {
x: paddedX
};
var reshapeAttrs = {
shape: reshapedPaddedShape
};
var paddedXReshaped = reshape$2({
inputs: reshapeInputs,
backend: backend,
attrs: reshapeAttrs
});
var transposeInputs = {
x: paddedXReshaped
};
var transposeAttrs = {
perm: permutedReshapedPaddedPermutation
};
var paddedXT = transpose$1({
inputs: transposeInputs,
backend: backend,
attrs: transposeAttrs
});
var resultReshapeInputs = {
x: paddedXT
};
var resultReshapeAttrs = {
shape: flattenShape
};
var result = reshape$2({
inputs: resultReshapeInputs,
backend: backend,
attrs: resultReshapeAttrs
});
backend.disposeIntermediateTensorInfo(paddedX);
backend.disposeIntermediateTensorInfo(paddedXReshaped);
backend.disposeIntermediateTensorInfo(paddedXT);
return result;
}
var spaceToBatchNDConfig = {
kernelName: SpaceToBatchND,
backendName: 'cpu',
kernelFunc: spaceToBatchND$1
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
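// CPU kernel for SparseFillEmptyRows: validates the ranks of indices, values,
// denseShape and defaultValue, then delegates to sparseFillEmptyRowsImpl and
// returns the output indices, output values, an empty-row indicator vector
// and the reverse index map as four tensors.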
function sparseFillEmptyRows$1(args) {
var inputs = args.inputs,
backend = args.backend;
var indices = inputs.indices,
values = inputs.values,
denseShape = inputs.denseShape,
defaultValue = inputs.defaultValue;
if (denseShape.shape.length !== 1) {
throw new Error("Dense shape must be a vector, saw:\n " + denseShape.shape);
}
if (indices.shape.length !== 2) {
throw new Error("Indices must be a matrix, saw:\n " + indices.shape);
}
if (values.shape.length !== 1) {
throw new Error("Values must be a vector, saw:\n " + values.shape);
}
if (defaultValue.shape.length !== 0) {
throw new Error("Default value must be a scalar, saw:\n " + defaultValue.shape);
}
var $indices = backend.data.get(indices.dataId).values;
var $values = backend.data.get(values.dataId).values;
var $denseShape = backend.data.get(denseShape.dataId).values;
var $defaultValue = backend.data.get(defaultValue.dataId).values[0];
var _sparseFillEmptyRowsI = sparseFillEmptyRowsImpl($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue),
outputIndices = _sparseFillEmptyRowsI[0],
outputIndicesShape = _sparseFillEmptyRowsI[1],
outputValues = _sparseFillEmptyRowsI[2],
emptyRowIndicator = _sparseFillEmptyRowsI[3],
reverseIndexMap = _sparseFillEmptyRowsI[4];
return [backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices), backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues), backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map(function (value) {
return Number(value);
}))), backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap))];
}
var sparseFillEmptyRowsConfig = {
kernelName: SparseFillEmptyRows,
backendName: 'cpu',
kernelFunc: sparseFillEmptyRows$1
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseReshape$1(args) {
var inputs = args.inputs,
backend = args.backend;
var inputIndices = inputs.inputIndices,
inputShape = inputs.inputShape,
newShape = inputs.newShape;
if (inputIndices.shape.length !== 2) {
throw new Error("Input indices should be a matrix but received shape\n " + inputIndices.shape);
}
if (inputShape.shape.length !== 1) {
throw new Error("Input shape should be a vector but received shape\n " + inputShape.shape);
}
if (newShape.shape.length !== 1) {
throw new Error("Target shape should be a vector but received shape " + newShape.shape);
}
var $inputShape = Array.from(backend.data.get(inputShape.dataId).values);
var $inputIndices = backend.data.get(inputIndices.dataId).values;
var targetShape = Array.from(backend.data.get(newShape.dataId).values);
var _sparseReshapeImpl = sparseReshapeImpl($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape),
newIndices = _sparseReshapeImpl[0],
indicesShape = _sparseReshapeImpl[1],
outputShape = _sparseReshapeImpl[2];
return [backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices), backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape))];
}
var sparseReshapeConfig = {
kernelName: SparseReshape,
backendName: 'cpu',
kernelFunc: sparseReshape$1
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
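// CPU kernels for the sparse segment reductions. SparseSegmentMean below and
// SparseSegmentSum further down share sparseSegmentReductionImpl; the extra
// boolean passed by the mean kernel switches the reduction to a per-segment
// average.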
function sparseSegmentMean$1(args) {
var inputs = args.inputs,
backend = args.backend;
var data = inputs.data,
indices = inputs.indices,
segmentIds = inputs.segmentIds;
if (data.shape.length < 1) {
throw new Error("Data should be at least 1 dimensional but received scalar");
}
if (indices.shape.length !== 1) {
throw new Error("Indices should be a vector but received shape\n " + indices.shape);
}
if (segmentIds.shape.length !== 1) {
throw new Error("Segment ids should be a vector but received shape\n " + segmentIds.shape);
}
var $data = backend.data.get(data.dataId).values;
var $indices = backend.data.get(indices.dataId).values;
var $segmentIds = backend.data.get(segmentIds.dataId).values;
var _sparseSegmentReducti = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds, true),
outputData = _sparseSegmentReducti[0],
outputDataShape = _sparseSegmentReducti[1];
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
var sparseSegmentMeanConfig = {
kernelName: SparseSegmentMean,
backendName: 'cpu',
kernelFunc: sparseSegmentMean$1
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseSegmentSum$1(args) {
var inputs = args.inputs,
backend = args.backend;
var data = inputs.data,
indices = inputs.indices,
segmentIds = inputs.segmentIds;
if (data.shape.length < 1) {
throw new Error("Data should be at least 1 dimensional but received scalar");
}
if (indices.shape.length !== 1) {
throw new Error("Indices should be a vector but received shape\n " + indices.shape);
}
if (segmentIds.shape.length !== 1) {
throw new Error("Segment ids should be a vector but received shape\n " + segmentIds.shape);
}
var $data = backend.data.get(data.dataId).values;
var $indices = backend.data.get(indices.dataId).values;
var $segmentIds = backend.data.get(segmentIds.dataId).values;
var _sparseSegmentReducti = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds),
outputData = _sparseSegmentReducti[0],
outputDataShape = _sparseSegmentReducti[1];
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
var sparseSegmentSumConfig = {
kernelName: SparseSegmentSum,
backendName: 'cpu',
kernelFunc: sparseSegmentSum$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseToDense$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var sparseIndices = inputs.sparseIndices,
sparseValues = inputs.sparseValues,
defaultValue = inputs.defaultValue;
var outputShape = attrs.outputShape;
var _backend_util$calcula = calculateShapes(sparseValues, sparseIndices, outputShape),
sliceRank = _backend_util$calcula.sliceRank,
numUpdates = _backend_util$calcula.numUpdates,
sliceSize = _backend_util$calcula.sliceSize,
strides = _backend_util$calcula.strides,
outputSize = _backend_util$calcula.outputSize;
var sumDupeIndices = false;
var indicesBuf = backend.bufferSync(sparseIndices);
var updatesBuf = backend.bufferSync(sparseValues);
var $defaultValue = backend.data.get(defaultValue.dataId).values[0];
var outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
}
var sparseToDenseConfig = {
kernelName: SparseToDense,
backendName: 'cpu',
kernelFunc: sparseToDense$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
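// CPU kernel for SplitV: normalizes the axis, resolves the per-split sizes,
// and emits each piece as a Slice along that axis with an advancing begin
// offset.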
function splitV(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var numOrSizeSplits = attrs.numOrSizeSplits,
axis = attrs.axis;
var $axis = parseAxisParam(axis, x.shape)[0];
var splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
var begin = new Array(x.shape.length).fill(0);
var size = x.shape.slice();
return splitSizes.map(function (s) {
var sliceSize = [].concat(size);
sliceSize[$axis] = s;
var sliceT = slice$3({
inputs: {
x: x
},
backend: backend,
attrs: {
begin: begin,
size: sliceSize
}
});
begin[$axis] += s;
return sliceT;
});
}
var splitVConfig = {
kernelName: SplitV,
backendName: 'cpu',
kernelFunc: splitV
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var squareConfig = {
kernelName: Square,
backendName: 'cpu',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
backend = _ref.backend;
var x = inputs.x;
var cpuBackend = backend;
assertNotComplex(x, 'square');
var values = cpuBackend.data.get(x.dataId).values;
var newValues = new Float32Array(values.length);
for (var i = 0; i < values.length; ++i) {
var value = values[i];
newValues[i] = value * value;
}
var dataId = cpuBackend.write(newValues, x.shape, x.dtype);
return {
dataId: dataId,
shape: x.shape,
dtype: x.dtype
};
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var step$1 = unaryKernelFunc(Step, function (xi, attrs) {
var stepAttrs = attrs;
if (isNaN(xi)) {
return NaN;
} else {
return xi > 0 ? 1 : stepAttrs.alpha;
}
});
var stepConfig = {
kernelName: Step,
backendName: 'cpu',
kernelFunc: step$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
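// CPU kernel for StridedSlice. sliceInfo() resolves the begin/end/stride
// masks; a non-strided slice is handled with plain Slice + Reshape, an output
// shape containing a zero dimension short-circuits to an empty tensor, and
// the general case runs stridedSliceImpl over a buffer view of the input.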
function stridedSlice$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var begin = attrs.begin,
end = attrs.end,
strides = attrs.strides,
beginMask = attrs.beginMask,
endMask = attrs.endMask,
ellipsisMask = attrs.ellipsisMask,
newAxisMask = attrs.newAxisMask,
shrinkAxisMask = attrs.shrinkAxisMask;
assertNotComplex(x, 'stridedSlice');
var _slice_util$sliceInfo = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask),
nonStrided = _slice_util$sliceInfo.nonStrided,
$begin = _slice_util$sliceInfo.$begin,
$strides = _slice_util$sliceInfo.$strides,
size = _slice_util$sliceInfo.size,
newShape = _slice_util$sliceInfo.newShape,
outShape = _slice_util$sliceInfo.outShape;
var $x = reshape$2({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: newShape
}
});
var result;
if (nonStrided) {
var sliced = slice$3({
inputs: {
x: $x
},
backend: backend,
attrs: {
begin: $begin,
size: size
}
});
result = reshape$2({
inputs: {
x: sliced
},
backend: backend,
attrs: {
shape: outShape
}
});
backend.disposeIntermediateTensorInfo(sliced);
} else if (outShape.some(function (axis) {
return axis === 0;
})) {
result = backend.makeTensorInfo(outShape, x.dtype, []);
} else {
var xBuf = backend.bufferSync($x);
var outBuf = stridedSliceImpl(outShape, xBuf, $strides, $begin);
result = backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
var resultReshaped = reshape$2({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: outShape
}
});
backend.disposeIntermediateTensorInfo($x);
backend.disposeIntermediateTensorInfo(result);
return resultReshaped;
}
var stridedSliceConfig = {
kernelName: StridedSlice,
backendName: 'cpu',
kernelFunc: stridedSlice$1
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
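// CPU kernel for StringNGrams: reads the ragged string data and its row
// splits, delegates to stringNGramsImpl with the separator/padding attrs, and
// returns the n-grams together with the new row splits.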
function stringNGrams$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var separator = attrs.separator,
nGramWidths = attrs.nGramWidths,
leftPad = attrs.leftPad,
rightPad = attrs.rightPad,
padWidth = attrs.padWidth,
preserveShortSequences = attrs.preserveShortSequences;
var data = inputs.data,
dataSplits = inputs.dataSplits;
var $data = backend.data.get(data.dataId).values;
var $dataSplits = backend.data.get(dataSplits.dataId).values;
var _stringNGramsImpl = stringNGramsImpl($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences),
nGrams = _stringNGramsImpl[0],
nGramsSplits = _stringNGramsImpl[1];
return [backend.makeTensorInfo([nGrams.length], 'string', nGrams), backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits)];
}
var stringNGramsConfig = {
kernelName: StringNGrams,
backendName: 'cpu',
kernelFunc: stringNGrams$1
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function stringSplit$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var skipEmpty = attrs.skipEmpty;
var input = inputs.input,
delimiter = inputs.delimiter;
if (input.dtype !== 'string') {
throw new Error('Input must be of datatype string');
}
if (input.shape.length !== 1) {
throw new Error("Input must be a vector, got shape: " + input.shape);
}
if (delimiter.shape.length !== 0) {
throw new Error("Delimiter must be a scalar, got shape: " + delimiter.shape);
}
var $input = backend.data.get(input.dataId).values;
var $delimiter = backend.data.get(delimiter.dataId).values[0];
var _stringSplitImpl = stringSplitImpl($input, $delimiter, skipEmpty),
indices = _stringSplitImpl[0],
values = _stringSplitImpl[1],
shape = _stringSplitImpl[2];
var outputSize = values.length;
return [backend.makeTensorInfo([outputSize, 2], 'int32', indices), backend.makeTensorInfo([outputSize], 'string', values), backend.makeTensorInfo([2], 'int32', new Int32Array(shape))];
}
var stringSplitConfig = {
kernelName: StringSplit,
backendName: 'cpu',
kernelFunc: stringSplit$1
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function stringToHashBucketFast$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var numBuckets = attrs.numBuckets;
var input = inputs.input;
if (input.dtype !== 'string') {
throw new Error('Input must be of datatype string');
}
if (numBuckets <= 0) {
throw new Error("Number of buckets must be at least 1");
}
var $input = backend.data.get(input.dataId).values;
var output = stringToHashBucketFastImpl($input, numBuckets);
return backend.makeTensorInfo(input.shape, 'int32', output);
}
var stringToHashBucketFastConfig = {
kernelName: StringToHashBucketFast,
backendName: 'cpu',
kernelFunc: stringToHashBucketFast$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var tan$1 = unaryKernelFunc(Tan, function (xi) {
return Math.tan(xi);
});
var tanConfig = {
kernelName: Tan,
backendName: 'cpu',
kernelFunc: tan$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var tanh$2 = unaryKernelFunc(Tanh, function (xi) {
return Math.tanh(xi);
});
var tanhConfig = {
kernelName: Tanh,
backendName: 'cpu',
kernelFunc: tanh$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function tile$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var reps = attrs.reps;
assertNotComplex(x, 'tile');
var outBuf = tileImpl(backend.bufferSync(x), reps);
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
var tileConfig = {
kernelName: Tile,
backendName: 'cpu',
kernelFunc: tile$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function topK(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var k = attrs.k,
sorted = attrs.sorted;
assertNotComplex(x, 'topk');
var xVals = backend.data.get(x.dataId).values;
var _topKImpl = topKImpl(xVals, x.shape, x.dtype, k, sorted),
allTopKVals = _topKImpl[0],
allTopKIndices = _topKImpl[1];
return [backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values), backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)];
}
var topKConfig = {
kernelName: TopK,
backendName: 'cpu',
kernelFunc: topK
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
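// CPU kernel for Transform (projective image transforms). Each output pixel
// is mapped back into the source image using the 8-parameter transform for
// its batch, the coordinate is adjusted according to fillMode (see mapCoord
// below), and the value is sampled with nearest or bilinear interpolation,
// falling back to fillValue outside the image.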
function transform$1(args) {
var inputs = args.inputs,
attrs = args.attrs,
backend = args.backend;
var image = inputs.image,
transforms = inputs.transforms;
var interpolation = attrs.interpolation,
fillMode = attrs.fillMode,
fillValue = attrs.fillValue,
outputShape = attrs.outputShape;
var _image$shape = image.shape,
batch = _image$shape[0],
imageHeight = _image$shape[1],
imageWidth = _image$shape[2],
numChannels = _image$shape[3];
var _ref = outputShape != null ? outputShape : [imageHeight, imageWidth],
outHeight = _ref[0],
outWidth = _ref[1];
var outShape = [batch, outHeight, outWidth, numChannels];
var strides = computeStrides(image.shape);
var batchStride = strides[0];
var rowStride = strides[1];
var colStride = strides[2];
var outVals = getTypedArrayFromDType(image.dtype, sizeFromShape(outShape));
outVals.fill(fillValue);
var imageVals = backend.data.get(image.dataId).values;
var transformVals = backend.data.get(transforms.dataId).values; // Ref TF implementation:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/image/image_ops.h
for (var b = 0; b < batch; ++b) {
var _transform = transforms.shape[0] === 1 ? transformVals : transformVals.subarray(b * 8, b * 8 + 8);
for (var outY = 0; outY < outHeight; ++outY) {
for (var outX = 0; outX < outWidth; ++outX) {
for (var channel = 0; channel < numChannels; ++channel) {
var val = void 0;
var projection = _transform[6] * outX + _transform[7] * outY + 1;
if (projection === 0) {
            // Leave the fill value in place for points at infinity
            // (projection === 0), which fall outside the input image.
continue;
}
var inX = (_transform[0] * outX + _transform[1] * outY + _transform[2]) / projection;
var inY = (_transform[3] * outX + _transform[4] * outY + _transform[5]) / projection;
var x = mapCoord(inX, imageWidth, fillMode);
var y = mapCoord(inY, imageHeight, fillMode);
switch (interpolation) {
case 'nearest':
val = nearestInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, b, y, x, channel, fillValue);
break;
case 'bilinear':
val = bilinearInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, b, y, x, channel, fillValue);
break;
default:
throw new Error("Error in Transform: Expect 'nearest' or " + ("'bilinear', but got " + interpolation));
}
var ind = b * batchStride + outY * rowStride + outX * colStride + channel;
outVals[ind] = val;
}
}
}
  }
  // All batches processed: write the transformed values and return them with
  // the computed output shape.
  var dataId = backend.write(outVals, outShape, image.dtype);
  return {
    dataId: dataId,
    shape: outShape,
    dtype: image.dtype
  };
}
var transformConfig = {
kernelName: Transform,
backendName: 'cpu',
kernelFunc: transform$1
};
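// Maps an output coordinate back into the valid input range according to the
// Transform fillMode ('reflect', 'wrap', 'nearest' or 'constant').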
function mapCoord(outCoord, len, mode) {
switch (mode) {
case 'reflect':
return mapCoordReflect(outCoord, len);
case 'wrap':
return mapCoordWrap(outCoord, len);
case 'nearest':
return mapCoordNearest(outCoord, len);
case 'constant':
default:
return mapCoordConstant(outCoord, len);
}
}
function mapCoordReflect(outCoord, len) {
// Reflect [abcd] to [dcba|abcd|dcba].
var inCoord = outCoord;
if (inCoord < 0) {
if (len <= 1) {
inCoord = 0;
} else {
var sz2 = 2 * len;
if (inCoord < sz2) {
inCoord = sz2 * Math.trunc(-inCoord / sz2) + inCoord;
}
inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1;
}
} else if (inCoord > len - 1) {
if (len <= 1) {
inCoord = 0;
} else {
var _sz = 2 * len;
inCoord -= _sz * Math.trunc(inCoord / _sz);
if (inCoord >= len) {
inCoord = _sz - inCoord - 1;
}
}
} // clamp is necessary because when outCoord = 3.5 and len = 4,
// inCoord = 3.5 and will be rounded to 4 in nearest interpolation.
return clamp(0, inCoord, len - 1);
}
function mapCoordWrap(outCoord, len) {
// Wrap [abcd] to [abcd|abcd|abcd].
var inCoord = outCoord;
if (inCoord < 0) {
if (len <= 1) {
inCoord = 0;
} else {
var sz = len - 1;
inCoord += len * (Math.trunc(-inCoord / sz) + 1);
}
} else if (inCoord > len - 1) {
if (len <= 1) {
inCoord = 0;
} else {
var _sz2 = len - 1;
inCoord -= len * Math.trunc(inCoord / _sz2);
}
} // clamp is necessary because when outCoord = -0.5 and len = 4,
// inCoord = 3.5 and will be rounded to 4 in nearest interpolation.
return clamp(0, inCoord, len - 1);
}
function mapCoordConstant(outCoord, len) {
return outCoord;
}
function mapCoordNearest(outCoord, len) {
return clamp(0, outCoord, len - 1);
}
function readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
var ind = batch * batchStride + y * rowStride + x * colStride + channel;
if (0 <= y && y < imageHeight && 0 <= x && x < imageWidth) {
return imageVals[ind];
} else {
return fillValue;
}
}
function nearestInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
var $y = Math.round(y);
var $x = Math.round(x);
return readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, $y, $x, channel, fillValue);
}
function bilinearInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
var yFloor = Math.floor(y);
var xFloor = Math.floor(x);
var yCeil = yFloor + 1;
var xCeil = xFloor + 1; // f(x, yFloor) = (xCeil - x) / (xCeil - xFloor) * f(xFloor, yFloor)
// + (x - xFloor) / (xCeil - xFloor) * f(xCeil, yFloor)
var valueYFloor = (xCeil - x) * readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xFloor, channel, fillValue) + (x - xFloor) * readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xCeil, channel, fillValue); // f(x, yCeil) = (xCeil - x) / (xCeil - xFloor) * f(xFloor, yCeil)
// + (x - xFloor) / (xCeil - xFloor) * f(xCeil, yCeil)
var valueYCeil = (xCeil - x) * readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xFloor, channel, fillValue) + (x - xFloor) * readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xCeil, channel, fillValue); // f(x, y) = (yCeil - y) / (yCeil - yFloor) * f(x, yFloor)
// + (y - yFloor) / (yCeil - yFloor) * f(x, yCeil)
return (yCeil - y) * valueYFloor + (y - yFloor) * valueYCeil;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function unique$2(args) {
var inputs = args.inputs,
attrs = args.attrs,
backend = args.backend;
var axis = attrs.axis;
var x = inputs.x;
assertNotComplex(x, 'unique');
var values = backend.data.get(x.dataId).values;
var _uniqueImpl = uniqueImpl(values, axis, x.shape, x.dtype),
outputValues = _uniqueImpl.outputValues,
outputShape = _uniqueImpl.outputShape,
indices = _uniqueImpl.indices;
return [backend.makeTensorInfo(outputShape, x.dtype, outputValues), backend.makeTensorInfo([indices.length], 'int32', indices)];
}
var uniqueConfig = {
kernelName: Unique,
backendName: 'cpu',
kernelFunc: unique$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
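// CPU kernel for Unpack: slices the input into `num` pieces of size 1 along
// the (normalized) axis and reshapes each piece to drop that dimension.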
function unpack$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var value = inputs.value;
var axis = attrs.axis;
if (axis < 0) {
axis += value.shape.length;
}
var valueRank = value.shape.length;
var num = value.shape[axis];
var outShape = new Array(valueRank - 1);
var outIndex = 0;
for (var i = 0; i < valueRank; i++) {
if (i !== axis) {
outShape[outIndex++] = value.shape[i];
}
}
var begin = new Array(valueRank).fill(0);
var size = value.shape.slice();
size[axis] = 1;
var res = new Array(num);
for (var _i = 0; _i < res.length; _i++) {
begin[axis] = _i;
var tempRes = slice$3({
inputs: {
x: value
},
backend: backend,
attrs: {
begin: begin,
size: size
}
});
res[_i] = reshape$2({
inputs: {
x: tempRes
},
backend: backend,
attrs: {
shape: outShape
}
});
backend.disposeIntermediateTensorInfo(tempRes);
}
return res;
}
var unpackConfig = {
kernelName: Unpack,
backendName: 'cpu',
kernelFunc: unpack$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
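// CPU kernel for UnsortedSegmentSum, composed from existing kernels: the
// segment ids are expanded so they broadcast against x, then for every
// segment an equality mask is built, cast to float32, multiplied with x and
// summed along axis 0; the per-segment sums are packed along a new leading
// axis.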
function unsortedSegmentSum$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
segmentIds = inputs.segmentIds;
var numSegments = attrs.numSegments;
assertNotComplex(x, 'unsortedSegmentSum');
var xRank = x.shape.length;
var segmentIdsRank = segmentIds.shape.length;
var res = [];
var intermediates = []; // Reshape the segment id's so that they can be broadcast with
// x. The new shape should be [segmentIds.shape, 1, ..., 1]
var numIters = xRank - segmentIdsRank;
var $segmentIds = segmentIds;
for (var i = 0; i < numIters; ++i) {
var expanded = expandDims$2({
inputs: {
input: $segmentIds
},
backend: backend,
attrs: {
dim: i + 1
}
});
$segmentIds = expanded;
intermediates.push(expanded);
}
for (var _i = 0; _i < numSegments; ++_i) {
var scalarValue = createScalarValue(_i, 'int32');
var segmentId = backend.makeTensorInfo([], 'int32', scalarValue);
var mask = equal$1({
inputs: {
a: segmentId,
b: $segmentIds
},
backend: backend
});
var maskCasted = cast$2({
inputs: {
x: mask
},
backend: backend,
attrs: {
dtype: 'float32'
}
});
var mul = multiply$3({
inputs: {
a: maskCasted,
b: x
},
backend: backend
});
var sumTensorInfo = sum$3({
inputs: {
x: mul
},
backend: backend,
attrs: {
axis: 0,
keepDims: false
}
});
res.push(sumTensorInfo);
intermediates.push(segmentId);
intermediates.push(mask);
intermediates.push(maskCasted);
intermediates.push(mul);
intermediates.push(sumTensorInfo);
}
var result = pack$1({
inputs: res,
backend: backend,
attrs: {
axis: 0
}
});
intermediates.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return result;
}
var unsortedSegmentSumConfig = {
kernelName: UnsortedSegmentSum,
backendName: 'cpu',
kernelFunc: unsortedSegmentSum$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
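// Registry of every CPU kernel config defined above; registering each one
// makes the corresponding op available to the 'cpu' backend's kernel
// dispatch.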
var kernelConfigs = [_fusedMatMulConfig, absConfig, acosConfig, acoshConfig, addConfig, addNConfig, allConfig, anyConfig, argMaxConfig, argMinConfig, asinConfig, asinhConfig, atanConfig, atan2Config, atanhConfig, avgPoolConfig, avgPool3DConfig, avgPool3DGradConfig$1, avgPoolGradConfig$1, batchMatMulConfig, batchNormConfig, batchToSpaceNDConfig, bincountConfig, broadcastArgsConfig, castConfig, ceilConfig, clipConfig, complexConfig, complexAbsConfig, concatConfig, conv2DBackpropFilterConfig, conv2DBackpropInputConfig, conv2DConfig, conv3DBackpropFilterV2Config, conv3DBackpropInputV2Config, conv3DConfig, cosConfig, coshConfig, cropAndResizeConfig, cumsumConfig, denseBincountConfig, depthToSpaceConfig, depthwiseConv2dNativeConfig, depthwiseConv2dNativeBackpropFilterConfig, depthwiseConv2dNativeBackpropInputConfig, diagConfig, dilation2dConfig, dilation2dBackpropInputConfig, dilation2dBackpropFilterConfig, realDivConfig, einsumConfig, eluConfig, eluGradConfig$1, equalConfig, erfConfig, expConfig, expandDimsConfig, expm1Config, fftConfig, fillConfig, flipLeftRightConfig, floorConfig, floorDivConfig, fusedConv2DConfig, fusedDepthwiseConv2DConfig, gatherNdConfig, gatherV2Config, greaterConfig, greaterEqualConfig, identityConfig, ifftConfig, imagConfig, isFiniteConfig, isInfConfig, isNaNConfig, leakyReluConfig, lessConfig, lessEqualConfig, linSpaceConfig, logConfig, log1pConfig, logicalAndConfig, logicalNotConfig, logicalOrConfig, lRNConfig, lRNGradConfig, maximumConfig, maxPoolConfig, maxPool3DConfig, maxPool3DGradConfig$1, maxPoolGradConfig$1, maxPoolWithArgmaxConfig, maxConfig, meanConfig, minConfig, minimumConfig, mirrorPadConfig, modConfig, multinomialConfig, multiplyConfig, negConfig, nonMaxSuppressionV3Config, nonMaxSuppressionV4Config, nonMaxSuppressionV5Config, notEqualConfig, oneHotConfig, onesLikeConfig, packConfig, padV2Config, powConfig, preluConfig, prodConfig, rangeConfig, realConfig, reciprocalConfig, reluConfig, relu6Config, reshapeConfig, resizeBilinearConfig, resizeBilinearGradConfig$1, resizeNearestNeighborConfig, resizeNearestNeighborGradConfig$1, reverseConfig, rotateWithOffsetConfig, roundConfig, rsqrtConfig, scatterNdConfig, selectConfig, seluConfig, sigmoidConfig, signConfig, sinConfig, sinhConfig, sliceConfig, softmaxConfig, softplusConfig, spaceToBatchNDConfig, sparseFillEmptyRowsConfig, sparseReshapeConfig, sparseSegmentMeanConfig, sparseSegmentSumConfig, sparseToDenseConfig, splitVConfig, sqrtConfig, squareConfig, squaredDifferenceConfig, stepConfig, stridedSliceConfig, stringNGramsConfig, stringSplitConfig, stringToHashBucketFastConfig, subConfig, sumConfig, tanConfig, tanhConfig, tileConfig, topKConfig, transposeConfig, transformConfig, uniqueConfig, unpackConfig, unsortedSegmentSumConfig, zerosLikeConfig];
for (var _i$1 = 0, _kernelConfigs = kernelConfigs; _i$1 < _kernelConfigs.length; _i$1++) {
var kernelConfig = _kernelConfigs[_i$1];
registerKernel(kernelConfig);
}
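// A minimal sketch, kept as a comment so the bundle's behavior is unchanged,
// of how the same registry API could be used for a custom CPU kernel; the
// kernel name 'MyOp' and its body are hypothetical:
//
//   registerKernel({
//     kernelName: 'MyOp',
//     backendName: 'cpu',
//     kernelFunc: function (args) {
//       var x = args.inputs.x;
//       var values = args.backend.data.get(x.dataId).values;
//       // ...compute new values here...
//       return args.backend.makeTensorInfo(x.shape, x.dtype, values);
//     }
//   });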
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
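// WebGL context management for the WebGL backend: one rendering context per
// WebGL version is created lazily (from an OffscreenCanvas when available for
// WebGL2, otherwise from a regular canvas), cached in `contexts`, and
// configured for compute use before each return: depth/stencil/blend/dither
// disabled, scissor test and back-face culling enabled.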
var contexts = {};
var WEBGL_ATTRIBUTES = {
alpha: false,
antialias: false,
premultipliedAlpha: false,
preserveDrawingBuffer: false,
depth: false,
stencil: false,
failIfMajorPerformanceCaveat: true
};
function clearWebGLContext(webGLVersion) {
delete contexts[webGLVersion];
}
function setWebGLContext(webGLVersion, gl) {
contexts[webGLVersion] = gl;
}
function getWebGLContext(webGLVersion) {
if (!(webGLVersion in contexts)) {
var newCtx = getWebGLRenderingContext(webGLVersion);
if (newCtx !== null) {
contexts[webGLVersion] = newCtx;
} else {
console.log('Could not get context for WebGL version', webGLVersion);
return null;
}
}
var gl = contexts[webGLVersion];
if (gl.isContextLost()) {
delete contexts[webGLVersion];
return getWebGLContext(webGLVersion);
}
gl.disable(gl.DEPTH_TEST);
gl.disable(gl.STENCIL_TEST);
gl.disable(gl.BLEND);
gl.disable(gl.DITHER);
gl.disable(gl.POLYGON_OFFSET_FILL);
gl.disable(gl.SAMPLE_COVERAGE);
gl.enable(gl.SCISSOR_TEST);
gl.enable(gl.CULL_FACE);
gl.cullFace(gl.BACK);
return contexts[webGLVersion];
}
function createCanvas(webGLVersion) {
if (typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) {
return new OffscreenCanvas(300, 150);
} else if (typeof document !== 'undefined') {
return document.createElement('canvas');
} else {
throw new Error('Cannot create a canvas in this context');
}
}
function getWebGLRenderingContext(webGLVersion) {
if (webGLVersion !== 1 && webGLVersion !== 2) {
throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
}
var canvas = createCanvas(webGLVersion);
canvas.addEventListener('webglcontextlost', function (ev) {
ev.preventDefault();
delete contexts[webGLVersion];
}, false);
if (webGLVersion === 1) {
return canvas.getContext('webgl', WEBGL_ATTRIBUTES) || canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES);
}
return canvas.getContext('webgl2', WEBGL_ATTRIBUTES);
}
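// A minimal usage sketch (hypothetical, left as a comment): an application
// that already owns a WebGL2 context could hand it to the backend instead of
// letting getWebGLContext create one from its own canvas:
//
//   var gl = someCanvas.getContext('webgl2', WEBGL_ATTRIBUTES);
//   setWebGLContext(2, gl);
//   // getWebGLContext(2) now returns the supplied context.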
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PackingScheme;
(function (PackingScheme) {
/**
* All values in a single texel are densely packed without any constraints.
*
* This is how the shader encodes a tensor with shape = [2, 3, 4]
* (indices are [batch, row, col]).
*
* 000|001 010|011 020|021
* ------- ------- -------
* 002|003 012|013 022|023
*
* 100|101 110|111 120|121
* ------- ------- -------
* 102|103 112|113 122|123
*
*/
PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE";
/**
* Single texels contain only values from the same batch, and from adjacent
* rows and columns.
*
* This is how the shader encodes a tensor with shape = [2, 3, 5]
* (indices are [batch, row, col]).
*
* 000|001 002|003 004|xxx 020|021 022|023 024|xxx
* ------- ------- ------- ------- ------- -------
* 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
*
* 100|101 102|103 104|xxx 120|121 122|123 124|xxx
* ------- ------- ------- ------- ------- -------
* 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
*
*/
PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH";
})(PackingScheme || (PackingScheme = {}));
var TextureUsage;
(function (TextureUsage) {
TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER";
TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD";
TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS";
TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD";
})(TextureUsage || (TextureUsage = {}));
var PhysicalTextureType;
(function (PhysicalTextureType) {
PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16";
PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32";
PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE";
PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32";
PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16";
})(PhysicalTextureType || (PhysicalTextureType = {}));
function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
return [columns, rows];
}
function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
return matrixSize * channelsPerTexture;
}
function getColorMatrixTextureShapeWidthHeight(rows, columns) {
return [columns * 4, rows];
}
/**
* Get shape for densely packed RGBA texture.
*/
function getDenseTexShape(shape) {
var size = sizeFromShape(shape);
var texelsNeeded = Math.ceil(size / 4);
return sizeToSquarishShape(texelsNeeded);
}
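// Worked example (illustrative): getDenseTexShape([2, 3, 4]) covers
// sizeFromShape = 24 values, so Math.ceil(24 / 4) = 6 RGBA texels are needed;
// sizeToSquarishShape(6) then returns a roughly square texture shape such as
// [3, 2] (the exact rounding is delegated to sizeToSquarishShape).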
function getMatrixSizeFromUnpackedArraySize(unpackedSize, channelsPerTexture) {
if (unpackedSize % channelsPerTexture !== 0) {
throw new Error("unpackedSize (" + unpackedSize + ") must be a multiple of " + ("" + channelsPerTexture));
}
return unpackedSize / channelsPerTexture;
}
function decodeMatrixFromUnpackedColorRGBAArray(unpackedArray, matrix, channels) {
var requiredSize = unpackedArray.length * channels / 4;
if (matrix.length < requiredSize) {
throw new Error("matrix length (" + matrix.length + ") must be >= " + requiredSize);
}
var dst = 0;
for (var src = 0; src < unpackedArray.length; src += 4) {
for (var c = 0; c < channels; c++) {
matrix[dst++] = unpackedArray[src + c];
}
}
}
function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
return [Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2))];
}
function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
var _getPackedMatrixTextu = getPackedMatrixTextureShapeWidthHeight(rows, columns),
w = _getPackedMatrixTextu[0],
h = _getPackedMatrixTextu[1];
return w * h * 4;
}
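// Worked example (illustrative): for a 3x5 matrix,
//   getPackedMatrixTextureShapeWidthHeight(3, 5) -> [3, 2]  (ceil(5/2), ceil(3/2))
//   getPackedRGBAArraySizeFromMatrixShape(3, 5)  -> 3 * 2 * 4 = 24 values
// because each RGBA texel packs a 2x2 block of the matrix.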
function getTextureConfig( // tslint:disable-next-line:no-any
gl, textureHalfFloatExtension) {
// tslint:disable-next-line:no-any
var glany = gl;
var internalFormatFloat;
var internalFormatHalfFloat;
var internalFormatPackedHalfFloat;
var internalFormatPackedFloat;
var textureFormatFloat;
var downloadTextureFormat;
var downloadUnpackNumChannels;
var defaultNumChannels;
var textureTypeHalfFloat;
var textureTypeFloat;
if (env().getNumber('WEBGL_VERSION') === 2) {
internalFormatFloat = glany.R32F;
internalFormatHalfFloat = glany.R16F;
internalFormatPackedHalfFloat = glany.RGBA16F;
internalFormatPackedFloat = glany.RGBA32F;
textureFormatFloat = glany.RED;
downloadUnpackNumChannels = 4;
defaultNumChannels = 1;
textureTypeHalfFloat = glany.HALF_FLOAT;
textureTypeFloat = glany.FLOAT;
} else {
internalFormatFloat = gl.RGBA;
internalFormatHalfFloat = gl.RGBA;
internalFormatPackedHalfFloat = gl.RGBA;
internalFormatPackedFloat = glany.RGBA;
textureFormatFloat = gl.RGBA;
downloadUnpackNumChannels = 4;
defaultNumChannels = 4;
textureTypeHalfFloat = textureHalfFloatExtension != null ? textureHalfFloatExtension.HALF_FLOAT_OES : null;
textureTypeFloat = gl.FLOAT;
}
downloadTextureFormat = gl.RGBA;
return {
internalFormatFloat: internalFormatFloat,
internalFormatHalfFloat: internalFormatHalfFloat,
internalFormatPackedHalfFloat: internalFormatPackedHalfFloat,
internalFormatPackedFloat: internalFormatPackedFloat,
textureFormatFloat: textureFormatFloat,
downloadTextureFormat: downloadTextureFormat,
downloadUnpackNumChannels: downloadUnpackNumChannels,
defaultNumChannels: defaultNumChannels,
textureTypeHalfFloat: textureTypeHalfFloat,
textureTypeFloat: textureTypeFloat
};
}
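// Illustrative result (a sketch, read directly from the branches above): with a
// WebGL2 context, getTextureConfig(gl) returns single-channel float formats,
//   { internalFormatFloat: gl.R32F, textureFormatFloat: gl.RED,
//     textureTypeFloat: gl.FLOAT, defaultNumChannels: 1, ... }
// whereas under WebGL1 every format falls back to gl.RGBA and the half-float
// texture type comes from the supplied textureHalfFloatExtension (HALF_FLOAT_OES).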
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function callAndCheck(gl, func) {
var returnValue = func();
if (env().getBool('DEBUG')) {
checkWebGLError(gl);
}
return returnValue;
}
function checkWebGLError(gl) {
var error = gl.getError();
if (error !== gl.NO_ERROR) {
throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error));
}
} // https://en.wikipedia.org/wiki/Half-precision_floating-point_format
var MIN_FLOAT16 = 5.96e-8;
var MAX_FLOAT16 = 65504;
function canBeRepresented(num) {
if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16) {
return true;
}
return false;
}
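// Illustrative check: when float32 rendering is disabled (half-float textures),
// canBeRepresented(1e-9) is false because 1e-9 underflows MIN_FLOAT16 (~5.96e-8),
// while canBeRepresented(1000) is true; with WEBGL_RENDER_FLOAT32_ENABLED set,
// the check always passes.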
function getWebGLErrorMessage(gl, status) {
switch (status) {
case gl.NO_ERROR:
return 'NO_ERROR';
case gl.INVALID_ENUM:
return 'INVALID_ENUM';
case gl.INVALID_VALUE:
return 'INVALID_VALUE';
case gl.INVALID_OPERATION:
return 'INVALID_OPERATION';
case gl.INVALID_FRAMEBUFFER_OPERATION:
return 'INVALID_FRAMEBUFFER_OPERATION';
case gl.OUT_OF_MEMORY:
return 'OUT_OF_MEMORY';
case gl.CONTEXT_LOST_WEBGL:
return 'CONTEXT_LOST_WEBGL';
default:
return "Unknown error code " + status;
}
}
function getExtensionOrThrow(gl, extensionName) {
return throwIfNull(gl, function () {
return gl.getExtension(extensionName);
}, 'Extension "' + extensionName + '" not supported on this browser.');
}
function createVertexShader(gl, vertexShaderSource) {
var vertexShader = throwIfNull(gl, function () {
return gl.createShader(gl.VERTEX_SHADER);
}, 'Unable to create vertex WebGLShader.');
callAndCheck(gl, function () {
return gl.shaderSource(vertexShader, vertexShaderSource);
});
callAndCheck(gl, function () {
return gl.compileShader(vertexShader);
});
if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) {
console.log(gl.getShaderInfoLog(vertexShader));
throw new Error('Failed to compile vertex shader.');
}
return vertexShader;
}
function createFragmentShader(gl, fragmentShaderSource) {
var fragmentShader = throwIfNull(gl, function () {
return gl.createShader(gl.FRAGMENT_SHADER);
}, 'Unable to create fragment WebGLShader.');
callAndCheck(gl, function () {
return gl.shaderSource(fragmentShader, fragmentShaderSource);
});
callAndCheck(gl, function () {
return gl.compileShader(fragmentShader);
});
if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) {
logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader));
throw new Error('Failed to compile fragment shader.');
}
return fragmentShader;
}
var lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;
function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
var lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);
if (lineNumberRegexResult == null) {
console.log("Couldn't parse line number in error: " + shaderInfoLog);
console.log(shaderSource);
return;
}
var lineNumber = +lineNumberRegexResult[1];
var shaderLines = shaderSource.split('\n');
var pad = shaderLines.length.toString().length + 2;
var linesWithLineNumbers = shaderLines.map(function (line, lineNumber) {
return rightPad((lineNumber + 1).toString(), pad) + line;
});
var maxLineLength = 0;
for (var i = 0; i < linesWithLineNumbers.length; i++) {
maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength);
}
var beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);
var errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);
var afterErrorLines = linesWithLineNumbers.slice(lineNumber);
console.log(beforeErrorLines.join('\n'));
console.log(shaderInfoLog.split('\n')[0]);
console.log("%c " + rightPad(errorLine[0], maxLineLength), 'border:1px solid red; background-color:#e3d2d2; color:#a61717');
console.log(afterErrorLines.join('\n'));
}
function createProgram(gl) {
return throwIfNull(gl, function () {
return gl.createProgram();
}, 'Unable to create WebGLProgram.');
}
function linkProgram(gl, program) {
callAndCheck(gl, function () {
return gl.linkProgram(program);
});
if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) {
console.log(gl.getProgramInfoLog(program));
throw new Error('Failed to link vertex and fragment shaders.');
}
}
function validateProgram(gl, program) {
callAndCheck(gl, function () {
return gl.validateProgram(program);
});
if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) {
console.log(gl.getProgramInfoLog(program));
throw new Error('Shader program validation failed.');
}
}
function createStaticVertexBuffer(gl, data) {
var buffer = throwIfNull(gl, function () {
return gl.createBuffer();
}, 'Unable to create WebGLBuffer');
callAndCheck(gl, function () {
return gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
});
callAndCheck(gl, function () {
return gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW);
});
return buffer;
}
function createStaticIndexBuffer(gl, data) {
var buffer = throwIfNull(gl, function () {
return gl.createBuffer();
}, 'Unable to create WebGLBuffer');
callAndCheck(gl, function () {
return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer);
});
callAndCheck(gl, function () {
return gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW);
});
return buffer;
}
function getNumChannels() {
if (env().getNumber('WEBGL_VERSION') === 2) {
return 1;
}
return 4;
}
function createTexture(gl) {
return throwIfNull(gl, function () {
return gl.createTexture();
}, 'Unable to create WebGLTexture.');
}
function validateTextureSize(width, height) {
var maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
if (width <= 0 || height <= 0) {
var requested = "[" + width + "x" + height + "]";
throw new Error('Requested texture size ' + requested + ' is invalid.');
}
if (width > maxTextureSize || height > maxTextureSize) {
var _requested = "[" + width + "x" + height + "]";
var max = "[" + maxTextureSize + "x" + maxTextureSize + "]";
throw new Error('Requested texture size ' + _requested + ' greater than WebGL maximum on this browser / GPU ' + max + '.');
}
}
function createFramebuffer(gl) {
return throwIfNull(gl, function () {
return gl.createFramebuffer();
}, 'Unable to create WebGLFramebuffer.');
}
function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
var loc = gl.getAttribLocation(program, attribute);
if (loc === -1) {
// The GPU compiler decided to strip out this attribute because it's unused,
// thus no need to bind.
return false;
}
callAndCheck(gl, function () {
return gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
});
callAndCheck(gl, function () {
return gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes);
});
callAndCheck(gl, function () {
return gl.enableVertexAttribArray(loc);
});
return true;
}
function bindTextureUnit(gl, texture, textureUnit) {
validateTextureUnit(gl, textureUnit);
callAndCheck(gl, function () {
return gl.activeTexture(gl.TEXTURE0 + textureUnit);
});
callAndCheck(gl, function () {
return gl.bindTexture(gl.TEXTURE_2D, texture);
});
}
function unbindTextureUnit(gl, textureUnit) {
validateTextureUnit(gl, textureUnit);
callAndCheck(gl, function () {
return gl.activeTexture(gl.TEXTURE0 + textureUnit);
});
callAndCheck(gl, function () {
return gl.bindTexture(gl.TEXTURE_2D, null);
});
}
function getProgramUniformLocationOrThrow(gl, program, uniformName) {
return throwIfNull(gl, function () {
return gl.getUniformLocation(program, uniformName);
}, 'uniform "' + uniformName + '" not present in program.');
}
function getProgramUniformLocation(gl, program, uniformName) {
return gl.getUniformLocation(program, uniformName);
}
function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
callAndCheck(gl, function () {
return bindTextureUnit(gl, texture, textureUnit);
});
callAndCheck(gl, function () {
return gl.uniform1i(uniformSamplerLocation, textureUnit);
});
}
function bindCanvasToFramebuffer(gl) {
callAndCheck(gl, function () {
return gl.bindFramebuffer(gl.FRAMEBUFFER, null);
});
callAndCheck(gl, function () {
return gl.viewport(0, 0, gl.canvas.width, gl.canvas.height);
});
callAndCheck(gl, function () {
return gl.scissor(0, 0, gl.canvas.width, gl.canvas.height);
});
}
function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
callAndCheck(gl, function () {
return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
});
callAndCheck(gl, function () {
return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
});
}
function unbindColorTextureFromFramebuffer(gl, framebuffer) {
callAndCheck(gl, function () {
return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
});
callAndCheck(gl, function () {
return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0);
});
}
function validateFramebuffer(gl) {
var status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
if (status !== gl.FRAMEBUFFER_COMPLETE) {
throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
}
}
function getFramebufferErrorMessage(gl, status) {
switch (status) {
case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT:
return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT';
case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT:
return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT';
case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS:
return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS';
case gl.FRAMEBUFFER_UNSUPPORTED:
return 'FRAMEBUFFER_UNSUPPORTED';
default:
return "unknown error " + status;
}
}
function throwIfNull(gl, returnTOrNull, failureMessage) {
var tOrNull = callAndCheck(gl, function () {
return returnTOrNull();
});
if (tOrNull == null) {
throw new Error(failureMessage);
}
return tOrNull;
}
function validateTextureUnit(gl, textureUnit) {
var maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1;
var glTextureUnit = textureUnit + gl.TEXTURE0;
if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) {
var textureUnitRange = "[gl.TEXTURE0, gl.TEXTURE" + maxTextureUnit + "]";
throw new Error("textureUnit must be in " + textureUnitRange + ".");
}
}
function getBatchDim(shape, dimsToSkip) {
if (dimsToSkip === void 0) {
dimsToSkip = 2;
}
return sizeFromShape(shape.slice(0, shape.length - dimsToSkip));
}
function getRowsCols(shape) {
if (shape.length === 0) {
throw Error('Cannot get rows and columns of an empty shape array.');
}
return [shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]];
}
function getShapeAs3D(shape) {
var shapeAs3D = [1, 1, 1];
var isScalar = shape.length === 0 || shape.length === 1 && shape[0] === 1;
if (!isScalar) {
shapeAs3D = [getBatchDim(shape)].concat(getRowsCols(shape));
}
return shapeAs3D;
}
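// Worked example (illustrative): for shape [2, 3, 4, 5],
//   getBatchDim([2, 3, 4, 5])  -> 2 * 3 = 6   (all but the last two dims)
//   getRowsCols([2, 3, 4, 5])  -> [4, 5]
//   getShapeAs3D([2, 3, 4, 5]) -> [6, 4, 5]
// and a scalar shape [] collapses to [1, 1, 1].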
function getTextureShapeFromLogicalShape(logShape, isPacked) {
if (isPacked === void 0) {
isPacked = false;
}
var maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
if (isPacked) {
maxTexSize = maxTexSize * 2; // This logic ensures we accurately count the number of packed texels needed
// to accommodate the tensor. We can only pack values in the same texel if
// they are from adjacent pairs of rows/cols within the same batch. So if a
// tensor has 3 rows, we pretend it has 4 rows in order to account for the
// fact that the texels containing the third row are half empty.
logShape = logShape.map(function (d, i) {
return i >= logShape.length - 2 ? nearestLargerEven(logShape[i]) : logShape[i];
}); // Packed texture height is at least 2 (the channel height of a single
// texel).
if (logShape.length === 1) {
logShape = [2, logShape[0]];
}
  } // If the logical shape has rank 2, we don't squeeze, since we want to match the physical texture shape.
if (logShape.length !== 2) {
var squeezeResult = squeezeShape(logShape);
logShape = squeezeResult.newShape;
}
var size = sizeFromShape(logShape);
if (logShape.length <= 1 && size <= maxTexSize) {
return [1, size];
} else if (logShape.length === 2 && logShape[0] <= maxTexSize && logShape[1] <= maxTexSize) {
return logShape;
} else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize && logShape[2] <= maxTexSize) {
return [logShape[0] * logShape[1], logShape[2]];
} else if (logShape.length === 3 && logShape[0] <= maxTexSize && logShape[1] * logShape[2] <= maxTexSize) {
return [logShape[0], logShape[1] * logShape[2]];
} else if (logShape.length === 4 && logShape[0] * logShape[1] * logShape[2] <= maxTexSize && logShape[3] <= maxTexSize) {
return [logShape[0] * logShape[1] * logShape[2], logShape[3]];
} else if (logShape.length === 4 && logShape[0] <= maxTexSize && logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
return [logShape[0], logShape[1] * logShape[2] * logShape[3]];
} else {
if (isPacked) {
// For packed textures size equals the number of channels required to
// accommodate the texture data. However in order to squarify such that
// inner dimensions stay even, we rewrite size to equal the number of
// texels. Then in the return statement we rehydrate the squarified
// dimensions to channel units.
var batchDim = getBatchDim(logShape);
var rows = 2,
cols = 2;
if (logShape.length) {
var _getRowsCols = getRowsCols(logShape);
rows = _getRowsCols[0];
cols = _getRowsCols[1];
}
size = batchDim * (rows / 2) * (cols / 2);
return sizeToSquarishShape(size).map(function (d) {
return d * 2;
});
}
return sizeToSquarishShape(size);
}
}
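// Worked example (illustrative, assuming WEBGL_MAX_TEXTURE_SIZE is large): for
// an unpacked tensor of logical shape [2, 3, 5] the 3-D branch applies, so
// getTextureShapeFromLogicalShape([2, 3, 5]) -> [2 * 3, 5] = [6, 5]; with
// isPacked = true the last two dims are first rounded up to even ([2, 4, 6]),
// giving a physical texture of [8, 6].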
function isEven(n) {
return n % 2 === 0;
}
/**
* This determines whether reshaping a packed texture requires rearranging
* the data within the texture, assuming 2x2 packing.
*/
function isReshapeFree(shape1, shape2) {
shape1 = shape1.slice(-2);
shape2 = shape2.slice(-2);
if (arraysEqual(shape1, shape2)) {
return true;
}
if (!shape1.length || !shape2.length) {
// One of the shapes is a scalar.
return true;
}
if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 || shape2[1] === 0) {
return true;
}
if (shape1.length !== shape2.length) {
// One of the shapes is a vector.
var shape1Cols = shape1.slice(-1)[0];
var shape2Cols = shape2.slice(-1)[0];
if (shape1Cols === shape2Cols) {
return true;
}
if (isEven(shape1Cols) && isEven(shape2Cols) && (shape1[0] === 1 || shape2[0] === 1)) {
return true;
}
}
return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]);
} // We cache webgl params because the environment gets reset between
// unit tests and we don't want to constantly query the WebGLContext for
// MAX_TEXTURE_SIZE.
var MAX_TEXTURE_SIZE;
var MAX_TEXTURES_IN_SHADER;
function getWebGLMaxTextureSize(webGLVersion) {
if (MAX_TEXTURE_SIZE == null) {
var gl = getWebGLContext(webGLVersion);
MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
}
return MAX_TEXTURE_SIZE;
}
function resetMaxTextureSize() {
MAX_TEXTURE_SIZE = null;
}
function resetMaxTexturesInShader() {
MAX_TEXTURES_IN_SHADER = null;
}
function getMaxTexturesInShader(webGLVersion) {
if (MAX_TEXTURES_IN_SHADER == null) {
var gl = getWebGLContext(webGLVersion);
MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
  } // We cap at 16 to avoid a spurious runtime "memory exhausted" error.
return Math.min(16, MAX_TEXTURES_IN_SHADER);
}
function getWebGLDisjointQueryTimerVersion(webGLVersion) {
if (webGLVersion === 0) {
return 0;
}
var queryTimerVersion;
var gl = getWebGLContext(webGLVersion);
if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') && webGLVersion === 2) {
queryTimerVersion = 2;
} else if (hasExtension(gl, 'EXT_disjoint_timer_query')) {
queryTimerVersion = 1;
} else {
queryTimerVersion = 0;
}
return queryTimerVersion;
}
function hasExtension(gl, extensionName) {
var ext = gl.getExtension(extensionName);
return ext != null;
}
function isWebGLVersionEnabled(webGLVersion) {
try {
var gl = getWebGLContext(webGLVersion);
if (gl != null) {
return true;
}
} catch (e) {
console.log('Error when getting WebGL context: ', e);
return false;
}
return false;
}
function isCapableOfRenderingToFloatTexture(webGLVersion) {
if (webGLVersion === 0) {
return false;
}
var gl = getWebGLContext(webGLVersion);
if (webGLVersion === 1) {
if (!hasExtension(gl, 'OES_texture_float')) {
return false;
}
} else {
if (!hasExtension(gl, 'EXT_color_buffer_float')) {
return false;
}
}
var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
return isFrameBufferComplete;
}
/**
* Check if we can download values from a float/half-float texture.
*
 * Note that for performance reasons we use the ability to bind a texture to a
 * framebuffer as a proxy for the ability to download float values later using
 * readPixels. The texture params of this probe texture will not exactly match
 * those used by readPixels, but if we are unable to bind some kind of float
 * texture to the framebuffer, we definitely will not be able to read float
 * values from it.
*/
function isDownloadFloatTextureEnabled(webGLVersion) {
if (webGLVersion === 0) {
return false;
}
var gl = getWebGLContext(webGLVersion);
if (webGLVersion === 1) {
if (!hasExtension(gl, 'OES_texture_float')) {
return false;
}
if (!hasExtension(gl, 'WEBGL_color_buffer_float')) {
return false;
}
} else {
if (hasExtension(gl, 'EXT_color_buffer_float')) {
return createFloatTextureAndBindToFramebuffer(gl);
}
var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) {
var textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension);
}
return false;
}
var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
return isFrameBufferComplete;
}
function createFloatTextureAndBindToFramebuffer(gl) {
var texConfig = getTextureConfig(gl);
var texture = gl.createTexture();
gl.bindTexture(gl.TEXTURE_2D, texture);
var width = 1;
var height = 1;
gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null);
var frameBuffer = gl.createFramebuffer();
gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
gl.bindTexture(gl.TEXTURE_2D, null);
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
gl.deleteTexture(texture);
gl.deleteFramebuffer(frameBuffer);
return isFrameBufferComplete;
}
function createHalfFloatTextureAndBindToFramebuffer( // tslint:disable-next-line:no-any
gl, textureHalfFloatExtension) {
var texConfig = getTextureConfig(gl, textureHalfFloatExtension);
var texture = gl.createTexture();
gl.bindTexture(gl.TEXTURE_2D, texture);
var width = 1;
var height = 1;
gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null);
var frameBuffer = gl.createFramebuffer();
gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
gl.bindTexture(gl.TEXTURE_2D, null);
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
gl.deleteTexture(texture);
gl.deleteFramebuffer(frameBuffer);
return isFrameBufferComplete;
}
function isWebGLFenceEnabled(webGLVersion) {
if (webGLVersion !== 2) {
return false;
}
var gl = getWebGLContext(webGLVersion); // tslint:disable-next-line:no-any
var isEnabled = gl.fenceSync != null;
return isEnabled;
}
function assertNotComplex$1(tensor, opName) {
if (!Array.isArray(tensor)) {
tensor = [tensor];
}
tensor.forEach(function (t) {
if (t != null) {
assert(t.dtype !== 'complex64', function () {
return opName + " does not support complex64 tensors " + 'in the WebGL backend.';
});
}
});
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ENV$1 = env();
/**
* This file contains WebGL-specific flag registrations.
*/
/**
* True if WebGL is supported.
*/
ENV$1.registerFlag('HAS_WEBGL', function () {
return ENV$1.getNumber('WEBGL_VERSION') > 0;
});
/** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */
ENV$1.registerFlag('WEBGL_VERSION', function () {
if (isWebGLVersionEnabled(2)) {
return 2;
} else if (isWebGLVersionEnabled(1)) {
return 1;
}
return 0;
});
/** Whether to check for numerical representation problems. */
ENV$1.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', function () {
return false;
});
ENV$1.registerFlag('WEBGL_BUFFER_SUPPORTED', function () {
return ENV$1.get('WEBGL_VERSION') === 2;
});
/** Whether the WebGL backend will sometimes forward ops to the CPU. */
ENV$1.registerFlag('WEBGL_CPU_FORWARD', function () {
return true;
});
/** Whether the WebGL backend will always use f16 textures for rendering. */
ENV$1.registerFlag('WEBGL_FORCE_F16_TEXTURES', function () {
return false;
});
/** Whether to turn all packing related flags on. */
ENV$1.registerFlag('WEBGL_PACK', function () {
return ENV$1.getBool('HAS_WEBGL');
});
/** Whether we will pack the batchnormalization op. */
ENV$1.registerFlag('WEBGL_PACK_NORMALIZATION', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** Whether we will pack the clip op. */
ENV$1.registerFlag('WEBGL_PACK_CLIP', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** Whether we will pack the depthwise conv op. */
ENV$1.registerFlag('WEBGL_PACK_DEPTHWISECONV', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** Whether we will pack binary ops. */
ENV$1.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** Whether we will pack unary ops. */
ENV$1.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** Whether we will pack array ops. */
ENV$1.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** Whether we will pack image ops. */
ENV$1.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** Whether we will pack reduce ops. */
ENV$1.registerFlag('WEBGL_PACK_REDUCE', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** Whether packed WebGL kernels lazily unpack their outputs. */
ENV$1.registerFlag('WEBGL_LAZILY_UNPACK', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** Whether we will use the im2col algorithm to speed up convolutions. */
ENV$1.registerFlag('WEBGL_CONV_IM2COL', function () {
return ENV$1.getBool('WEBGL_PACK');
});
/** The maximum texture dimension. */
ENV$1.registerFlag('WEBGL_MAX_TEXTURE_SIZE', function () {
return getWebGLMaxTextureSize(ENV$1.getNumber('WEBGL_VERSION'));
});
/** The maximum number of textures allowed in a single shader program. */
ENV$1.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', function () {
return getMaxTexturesInShader(ENV$1.getNumber('WEBGL_VERSION'));
});
/**
* The disjoint_query_timer extension version.
* 0: disabled, 1: EXT_disjoint_timer_query, 2:
* EXT_disjoint_timer_query_webgl2.
* In Firefox with WebGL 2.0,
* EXT_disjoint_timer_query_webgl2 is not available, so we must use the
* WebGL 1.0 extension.
*/
ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', function () {
var webGLVersion = ENV$1.getNumber('WEBGL_VERSION');
if (webGLVersion === 0) {
return 0;
}
return getWebGLDisjointQueryTimerVersion(webGLVersion);
});
/**
* Whether the timer object from the disjoint_query_timer extension gives
* timing information that is reliable.
*/
ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', function () {
return ENV$1.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 && !isMobile();
});
/**
* Whether the device is physically capable of rendering to float32 textures.
*/
ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', function () {
return isCapableOfRenderingToFloatTexture(ENV$1.getNumber('WEBGL_VERSION'));
});
/**
* Whether rendering to float32 textures is enabled. If disabled, renders to
* float16 textures.
*/
ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', function () {
return ENV$1.getBool('WEBGL_FORCE_F16_TEXTURES') ? false : ENV$1.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
});
/**
 * Whether downloading float textures is enabled (16 or 32 bit). If disabled,
 * float32 values are IEEE 754-encoded into four uint8 channels when downloading.
*/
ENV$1.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', function () {
return isDownloadFloatTextureEnabled(ENV$1.getNumber('WEBGL_VERSION'));
});
/** Whether the fence API is available. */
ENV$1.registerFlag('WEBGL_FENCE_API_ENABLED', function () {
return isWebGLFenceEnabled(ENV$1.getNumber('WEBGL_VERSION'));
});
/**
 * Tensors with size <= this threshold will be uploaded as uniforms, not textures.
*/
ENV$1.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', function () {
  // Use uniform uploads only when 32-bit floats are supported. In 16-bit
  // environments there are problems with comparing a 16-bit texture value
  // with a 32-bit uniform value.
var useUniforms = ENV$1.getBool('WEBGL_RENDER_FLOAT32_ENABLED');
return useUniforms ? 4 : 0;
});
/**
* If the total number of bytes allocated on the GPU is greater than this
* number, we will aggressively delete textures upon disposal with
* gl.deleteMatrixTexture, rather than making them available for reuse.
*
* Default value -1 indicates that we will never aggressively delete textures.
*/
ENV$1.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', function () {
return -1;
}, function (threshold) {
if (threshold < 0 && threshold !== -1) {
throw new Error("WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never " + ("delete) or at least 0, but got " + threshold + "."));
}
});
/**
 * Trigger a manual GL command flush if this threshold of time has passed since
 * the previous kernel execution. This can be useful on Android devices, where
 * GL command flushes are delayed until the end of the JavaScript task. The
 * value is measured in milliseconds; typically you want it close to 1.
 *
 * The default is 1 for mobile Chrome and -1 in all other cases. -1 indicates
 * that we will not enforce a manual flush and will rely on the system's default
 * flush schedule.
*/
ENV$1.registerFlag('WEBGL_FLUSH_THRESHOLD', function () {
return isMobile() && ENV$1.getBool('IS_CHROME') ? 1 : -1;
}, function (threshold) {
if (threshold < 0 && threshold !== -1) {
throw new Error("WEBGL_FLUSH_THRESHOLD must be -1 (indicating never " + ("manual flush) or at least 0, but got " + threshold + "."));
}
});
/**
 * Threshold on input tensor size that determines whether the WebGL backend
 * will delegate computation to the CPU.
*
* Default value is 128.
*/
ENV$1.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', function () {
return 128;
});
/** Whether we will use shapes uniforms. */
ENV$1.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', function () {
return false;
});
/**
 * Threshold on the last dimension of the input tensor that determines whether
 * the WebGL backend will delegate the TopK op to the CPU. If the input is
 * smaller than this threshold, the CPU is used.
*
* Default value is 100000.
*/
ENV$1.registerFlag('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD', function () {
return 100000;
});
/**
 * Threshold on k that determines whether the WebGL backend will delegate the
 * TopK op to the CPU. If k is larger than this threshold, the CPU is used.
*
* Default value is 128.
*/
ENV$1.registerFlag('TOPK_K_CPU_HANDOFF_THRESHOLD', function () {
return 128;
});
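// Illustrative usage (a sketch): the registrations above only define defaults;
// at runtime a flag can be read or overridden through the shared environment,
//   env().set('WEBGL_PACK', false);                         // disable packed kernels
//   const useF32 = env().getBool('WEBGL_RENDER_FLOAT32_ENABLED');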
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function getGlslDifferences() {
var version;
var attribute;
var varyingVs;
var varyingFs;
var texture2D;
var output;
var defineOutput;
var defineSpecialNaN;
var defineSpecialInf;
var defineRound;
if (env().getNumber('WEBGL_VERSION') === 2) {
version = '#version 300 es';
attribute = 'in';
varyingVs = 'out';
varyingFs = 'in';
texture2D = 'texture';
output = 'outputColor';
defineOutput = 'out vec4 outputColor;'; // Use custom isnan definition to work across differences between
// implementations on various platforms. While this should happen in ANGLE
// we still see differences between android and windows (on chrome) when
// using isnan directly.
defineSpecialNaN = "\n bool isnan_custom(float val) {\n return (val > 0.0 || val < 0.0) ? false : val != 0.0;\n }\n\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan_custom(val.x),\n isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));\n }\n\n #define isnan(value) isnan_custom(value)\n "; // In webgl 2 we do not need to specify a custom isinf so there is no
// need for a special INFINITY constant.
defineSpecialInf = "";
defineRound = "\n #define round(value) newRound(value)\n int newRound(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 newRound(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n ";
} else {
version = '';
attribute = 'attribute';
varyingVs = 'varying';
varyingFs = 'varying';
texture2D = 'texture2D';
output = 'gl_FragColor';
defineOutput = ''; // WebGL1 has no built in isnan so we define one here.
defineSpecialNaN = "\n #define isnan(value) isnan_custom(value)\n bool isnan_custom(float val) {\n return (val > 0. || val < 1. || val == 0.) ? false : true;\n }\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));\n }\n ";
defineSpecialInf = "\n uniform float INFINITY;\n\n bool isinf(float val) {\n return abs(val) == INFINITY;\n }\n bvec4 isinf(vec4 val) {\n return equal(abs(val), vec4(INFINITY));\n }\n ";
defineRound = "\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 round(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n ";
}
return {
version: version,
attribute: attribute,
varyingVs: varyingVs,
varyingFs: varyingFs,
texture2D: texture2D,
output: output,
defineOutput: defineOutput,
defineSpecialNaN: defineSpecialNaN,
defineSpecialInf: defineSpecialInf,
defineRound: defineRound
};
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Produces GLSL code that derives logical coordinates from a flat
* index. The code performs integer division with each stride and decrements
* the index until the index equals the final dimension coordinate.
*/
function getLogicalCoordinatesFromFlatIndex(coords, shape, index) {
if (index === void 0) {
index = 'index';
}
var strides = computeStrides(shape);
return strides.map(function (stride, i) {
var line1 = "int " + coords[i] + " = " + index + " / " + stride;
var line2 = i === strides.length - 1 ? "int " + coords[i + 1] + " = " + index + " - " + coords[i] + " * " + stride : "index -= " + coords[i] + " * " + stride;
return line1 + "; " + line2 + ";";
}).join('');
}
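// Illustrative output: for coords ['r', 'c', 'd'] and shape [2, 3, 4]
// (strides [12, 4]) the helper above emits GLSL equivalent to
//   int r = index / 12; index -= r * 12;
//   int c = index / 4;  int d = index - c * 4;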
function getOutputLogicalCoordinatesFromFlatIndexByUniform(coords, shape, index) {
if (index === void 0) {
index = 'index';
}
var strides = computeStrides(shape);
return strides.map(function (_, i) {
var line1 = "int " + coords[i] + " = " + index + " / outShapeStrides[" + i + "]";
var line2 = i === strides.length - 1 ? "int " + coords[i + 1] + " = " + index + " - " + coords[i] + " * outShapeStrides[" + i + "]" : "index -= " + coords[i] + " * outShapeStrides[" + i + "]";
return line1 + "; " + line2 + ";";
}).join('');
} // Produces GLSL code that computes strides.
function symbolicallyComputeStrides(indicesArr, variableName) {
var numCoords = indicesArr.length;
var shape = indicesArr.map(function (d) {
return variableName + "[" + d + "]";
});
var strides = new Array(numCoords - 1);
strides[numCoords - 2] = shape[numCoords - 1];
for (var i = numCoords - 3; i >= 0; --i) {
strides[i] = "(" + strides[i + 1] + " * " + shape[i + 1] + ")";
}
return strides;
}
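// Illustrative output: symbolicallyComputeStrides([0, 1, 2], 'aShape') returns
// ['(aShape[2] * aShape[1])', 'aShape[2]'], i.e. stride expressions built from
// the shape uniform rather than compile-time constants.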
function getLogicalCoordinatesFromFlatIndexByUniform(coords, variableName, index) {
if (index === void 0) {
index = 'index';
}
var indicesArray = coords.map(function (_, i) {
return i;
});
var strides = symbolicallyComputeStrides(indicesArray, variableName);
return strides.map(function (_, i) {
var line1 = "int " + coords[i] + " = " + index + " / " + strides[i];
var line2 = i === strides.length - 1 ? "int " + coords[i + 1] + " = " + index + " - " + coords[i] + " * " + strides[i] : "index -= " + coords[i] + " * " + strides[i];
return line1 + "; " + line2 + ";";
}).join('');
}
function buildVec(x) {
if (x.length === 1) {
return "" + x[0];
}
return "vec" + x.length + "(" + x.join(',') + ")";
}
/**
* Produces GLSL code that computes the dot product of the input x and y
* vectors. Handles splitting inputs into increments of vec4s when necessary.
*/
function dotify(x, y) {
if (x.length !== y.length) {
throw new Error("Vectors to be dotted must be of the same length -" + ("got " + x.length + " and " + y.length));
}
var slices = [];
var nearestVec4 = Math.floor(x.length / 4);
var nearestVec4Remainder = x.length % 4;
for (var i = 0; i < nearestVec4; i++) {
var xSlice = x.slice(i * 4, i * 4 + 4);
var ySlice = y.slice(i * 4, i * 4 + 4);
slices.push(buildVec(xSlice) + ", " + buildVec(ySlice));
}
if (nearestVec4Remainder !== 0) {
var _xSlice = x.slice(nearestVec4 * 4);
var _ySlice = y.slice(nearestVec4 * 4);
if (_xSlice.length === 1) {
_xSlice = _xSlice.map(function (d) {
return "float(" + d + ")";
});
_ySlice = _ySlice.map(function (d) {
return "float(" + d + ")";
});
}
slices.push(buildVec(_xSlice) + ", " + buildVec(_ySlice));
}
return slices.map(function (d, i) {
return "dot(" + d + ")";
}).join('+');
}
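// Illustrative output: dotify(['a','b','c','d','e'], ['f','g','h','i','j'])
// produces "dot(vec4(a,b,c,d), vec4(f,g,h,i))+dot(float(e), float(j))",
// splitting the 5-element dot product into a vec4 chunk plus a scalar remainder.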
/**
* Produces GLSL that computes the flat index from 3D coordinates.
*/
function getFlatIndexFrom3D(shape) {
var strides = computeStrides(shape).map(function (d) {
return d.toString();
});
return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * " + strides[0] + " + coords.y * " + strides[1] + " + coords.z;\n }\n";
}
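// Illustrative output: for shape [2, 3, 4] (strides [12, 4]),
// getFlatIndexFrom3D([2, 3, 4]) emits GLSL equivalent to
//   int getFlatIndex(ivec3 coords) {
//     return coords.x * 12 + coords.y * 4 + coords.z;
//   }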
function getFlatIndexFrom3DOutput() {
return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;\n }\n";
}
var ENCODE_FLOAT_SNIPPET = "\n const float FLOAT_MAX = 1.70141184e38;\n const float FLOAT_MIN = 1.17549435e-38;\n\n lowp vec4 encode_float(highp float v) {\n if (isnan(v)) {\n return vec4(255, 255, 255, 255);\n }\n\n highp float av = abs(v);\n\n if(av < FLOAT_MIN) {\n return vec4(0.0, 0.0, 0.0, 0.0);\n } else if(v > FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;\n } else if(v < -FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;\n }\n\n highp vec4 c = vec4(0,0,0,0);\n\n highp float e = floor(log2(av));\n highp float m = exp2(fract(log2(av))) - 1.0;\n\n c[2] = floor(128.0 * m);\n m -= c[2] / 128.0;\n c[1] = floor(32768.0 * m);\n m -= c[1] / 32768.0;\n c[0] = floor(8388608.0 * m);\n\n highp float ebias = e + 127.0;\n c[3] = floor(ebias / 2.0);\n ebias -= c[3] * 2.0;\n c[2] += floor(ebias) * 128.0;\n\n c[3] += 128.0 * step(0.0, -v);\n\n return c / 255.0;\n }\n";
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var getBroadcastDims$1 = getBroadcastDims;
function makeShader(inputsInfo, outputShape, program) {
var prefixSnippets = [];
inputsInfo.forEach(function (x) {
    var size = sizeFromShape(x.shapeInfo.logicalShape); // Snippet used when the values are uploaded as a uniform.
if (x.shapeInfo.isUniform) {
prefixSnippets.push("uniform float " + x.name + (size > 1 ? "[" + size + "]" : '') + ";");
} else {
prefixSnippets.push("uniform sampler2D " + x.name + ";");
prefixSnippets.push("uniform int offset" + x.name + ";");
}
if (program.enableShapeUniforms) {
var _getUniformInfoFromSh = getUniformInfoFromShape(program.packedInputs, x.shapeInfo.logicalShape, x.shapeInfo.texShape),
uniformShape = _getUniformInfoFromSh.uniformShape;
switch (uniformShape.length) {
case 1:
prefixSnippets.push("uniform int " + x.name + "Shape;");
break;
case 2:
prefixSnippets.push("uniform ivec2 " + x.name + "Shape;");
break;
case 3:
prefixSnippets.push("uniform ivec3 " + x.name + "Shape;");
break;
case 4:
prefixSnippets.push("uniform ivec4 " + x.name + "Shape;");
break;
default:
break;
}
prefixSnippets.push("uniform ivec2 " + x.name + "TexShape;");
}
});
if (program.enableShapeUniforms) {
switch (outputShape.logicalShape.length) {
case 1:
prefixSnippets.push("uniform int outShape;");
break;
case 2:
prefixSnippets.push("uniform ivec2 outShape;");
prefixSnippets.push("uniform int outShapeStrides;");
break;
case 3:
prefixSnippets.push("uniform ivec3 outShape;");
prefixSnippets.push("uniform ivec2 outShapeStrides;");
break;
case 4:
prefixSnippets.push("uniform ivec4 outShape;");
prefixSnippets.push("uniform ivec3 outShapeStrides;");
break;
default:
break;
}
prefixSnippets.push("uniform ivec2 outTexShape;");
}
if (program.customUniforms) {
program.customUniforms.forEach(function (d) {
prefixSnippets.push("uniform " + d.type + " " + d.name + (d.arrayIndex ? "[" + d.arrayIndex + "]" : '') + ";");
});
}
var inputPrefixSnippet = prefixSnippets.join('\n');
var inputSamplingSnippet = inputsInfo.map(function (x) {
return getInputSamplingSnippet(x, outputShape, program.packedInputs, program.enableShapeUniforms);
}).join('\n');
var outTexShape = outputShape.texShape;
var glsl = getGlslDifferences();
var floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl);
var outputSamplingSnippet;
var floatTextureSetOutputSnippet;
var shaderPrefix = getShaderPrefix(glsl);
if (outputShape.isPacked) {
outputSamplingSnippet = getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl);
} else {
outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl);
}
if (program.packedInputs) {
shaderPrefix += SHADER_PACKED_PREFIX;
}
var source = [shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet, inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, program.userCode].join('\n');
return source;
}
function getSamplerFromInInfo(inInfo, enableShapeUniforms) {
if (enableShapeUniforms === void 0) {
enableShapeUniforms = false;
}
var shape = inInfo.shapeInfo.logicalShape;
switch (shape.length) {
case 0:
return getSamplerScalar(inInfo, enableShapeUniforms);
case 1:
return getSampler1D(inInfo, enableShapeUniforms);
case 2:
return getSampler2D(inInfo, enableShapeUniforms);
case 3:
return getSampler3D(inInfo, enableShapeUniforms);
case 4:
return getSampler4D(inInfo, enableShapeUniforms);
case 5:
return getSampler5D(inInfo);
case 6:
return getSampler6D(inInfo);
default:
throw new Error(shape.length + "-D input sampling" + " is not yet supported");
}
}
function getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) {
var shape = inInfo.shapeInfo.logicalShape;
switch (shape.length) {
case 0:
return getPackedSamplerScalar(inInfo);
case 1:
return getPackedSampler1D(inInfo, enableShapeUniforms);
case 2:
return getPackedSampler2D(inInfo, enableShapeUniforms);
case 3:
return getPackedSampler3D(inInfo, enableShapeUniforms);
default:
return getPackedSamplerND(inInfo, enableShapeUniforms);
}
}
function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures, enableShapeUniforms) {
if (usesPackedTextures === void 0) {
usesPackedTextures = false;
}
var res = '';
if (usesPackedTextures) {
res += getPackedSamplerFromInInfo(inInfo, enableShapeUniforms);
} else {
res += getSamplerFromInInfo(inInfo, enableShapeUniforms);
}
var inShape = inInfo.shapeInfo.logicalShape;
var outShape = outShapeInfo.logicalShape;
if (inShape.length <= outShape.length) {
if (usesPackedTextures) {
res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo);
} else {
res += getSamplerAtOutputCoords(inInfo, outShapeInfo);
}
}
return res;
}
function getPackedOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
switch (outShape.length) {
case 0:
return getOutputScalarCoords();
case 1:
return getOutputPacked1DCoords(outShape, outTexShape, enableShapeUniforms);
case 2:
return getOutputPacked2DCoords(outShape, outTexShape, enableShapeUniforms);
case 3:
return getOutputPacked3DCoords(outShape, outTexShape, enableShapeUniforms);
default:
return getOutputPackedNDCoords(outShape, outTexShape, enableShapeUniforms);
}
}
function getOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
switch (outShape.length) {
case 0:
return getOutputScalarCoords();
case 1:
return getOutput1DCoords(outShape, outTexShape, enableShapeUniforms);
case 2:
return getOutput2DCoords(outShape, outTexShape, enableShapeUniforms);
case 3:
return getOutput3DCoords(outShape, outTexShape, enableShapeUniforms);
case 4:
return getOutput4DCoords(outShape, outTexShape, enableShapeUniforms);
case 5:
return getOutput5DCoords(outShape, outTexShape);
case 6:
return getOutput6DCoords(outShape, outTexShape);
default:
throw new Error(outShape.length + "-D output sampling is not yet supported");
}
}
function getFloatTextureSampleSnippet(glsl) {
return "\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n return " + glsl.texture2D + "(textureSampler, uv).r;\n }\n ";
}
function getFloatTextureSetRSnippet(glsl) {
return "\n void setOutput(float val) {\n " + glsl.output + " = vec4(val, 0, 0, 0);\n }\n ";
}
function getFloatTextureSetRGBASnippet(glsl) {
return "\n void setOutput(vec4 val) {\n " + glsl.output + " = val;\n }\n ";
}
function getShaderPrefix(glsl) {
var SHADER_PREFIX = glsl.version + "\n precision highp float;\n precision highp int;\n precision highp sampler2D;\n " + glsl.varyingFs + " vec2 resultUV;\n " + glsl.defineOutput + "\n const vec2 halfCR = vec2(0.5, 0.5);\n\n struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n uniform float NAN;\n " + glsl.defineSpecialNaN + "\n " + glsl.defineSpecialInf + "\n " + glsl.defineRound + "\n\n int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n int idiv(int a, int b, float sign) {\n int res = a / b;\n int mod = imod(a, b);\n if (sign < 0. && mod != 0) {\n res -= 1;\n }\n return res;\n }\n\n //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n " + SAMPLE_1D_SNIPPET + "\n " + SAMPLE_2D_SNIPPET + "\n " + SAMPLE_3D_SNIPPET + "\n ";
return SHADER_PREFIX;
}
var SAMPLE_1D_SNIPPET = "\nvec2 uvFromFlat(int texNumR, int texNumC, int index) {\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\nvec2 packedUVfrom1D(int texNumR, int texNumC, int index) {\n int texelIndex = index / 2;\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
var SAMPLE_2D_SNIPPET = "\nvec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,\n int texNumC, int row, int col) {\n int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
var SAMPLE_3D_SNIPPET = "\nvec2 packedUVfrom3D(int texNumR, int texNumC,\n int texelsInBatch, int texelsInLogicalRow, int b,\n int row, int col) {\n int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n";
var SHADER_PACKED_PREFIX = "\n float getChannel(vec4 frag, vec2 innerDims) {\n vec2 modCoord = mod(innerDims, 2.);\n return modCoord.x == 0. ?\n (modCoord.y == 0. ? frag.r : frag.g) :\n (modCoord.y == 0. ? frag.b : frag.a);\n }\n float getChannel(vec4 frag, int dim) {\n float modCoord = mod(float(dim), 2.);\n return modCoord == 0. ? frag.r : frag.g;\n }\n";
function getOutputScalarCoords() {
return "\n int getOutputCoords() {\n return 0;\n }\n ";
}
function getOutputPacked1DCoords(shape, texShape, enableShapeUniforms) {
var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
if (packedTexShape[0] === 1) {
if (enableShapeUniforms) {
return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * ceil(float(outTexShape[1]) / 2.0));\n }\n ";
}
return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * " + packedTexShape[1] + ".0);\n }\n ";
}
if (packedTexShape[1] === 1) {
if (enableShapeUniforms) {
return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * ceil(float(outTexShape[0]) / 2.0));\n }\n ";
}
return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * " + packedTexShape[0] + ".0);\n }\n ";
}
if (enableShapeUniforms) {
return "\n int getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n return 2 * (resTexRC.x * packedTexShape[1] + resTexRC.y);\n }\n ";
}
return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n return 2 * (resTexRC.x * " + packedTexShape[1] + " + resTexRC.y);\n }\n ";
}
function getOutput1DCoords(shape, texShape, enableShapeUniforms) {
if (texShape[0] === 1) {
if (enableShapeUniforms) {
return "\n int getOutputCoords() {\n return int(resultUV.x * float(outTexShape[1]));\n }\n ";
}
return "\n int getOutputCoords() {\n return int(resultUV.x * " + texShape[1] + ".0);\n }\n ";
}
if (texShape[1] === 1) {
if (enableShapeUniforms) {
return "\n int getOutputCoords() {\n return int(resultUV.y * float(outTexShape[0]));\n }\n ";
}
return "\n int getOutputCoords() {\n return int(resultUV.y * " + texShape[0] + ".0);\n }\n ";
}
if (enableShapeUniforms) {
return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n return resTexRC.x * outTexShape[1] + resTexRC.y;\n }\n ";
}
return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n return resTexRC.x * " + texShape[1] + " + resTexRC.y;\n }\n ";
}
function getOutputPacked3DCoords(shape, texShape, enableShapeUniforms) {
if (enableShapeUniforms) {
return "\n ivec3 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n int texelsInLogicalRow = int(ceil(float(outShape[2]) / 2.0));\n int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n int index = resTexRC.x * packedTexShape[1] + resTexRC.y;\n\n int b = index / texelsInBatch;\n index -= b * texelsInBatch;\n\n int r = 2 * (index / texelsInLogicalRow);\n int c = imod(index, texelsInLogicalRow) * 2;\n\n return ivec3(b, r, c);\n }\n ";
}
var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
var texelsInLogicalRow = Math.ceil(shape[2] / 2);
var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n\n int b = index / " + texelsInBatch + ";\n index -= b * " + texelsInBatch + ";\n\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec3(b, r, c);\n }\n ";
}
function getOutput3DCoords(shape, texShape, enableShapeUniforms) {
if (enableShapeUniforms) {
var _coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], shape);
return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n " + _coordsFromIndexSnippet + "\n return ivec3(r, c, d);\n }\n";
}
var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n " + coordsFromIndexSnippet + "\n return ivec3(r, c, d);\n }\n ";
}
function getOutputPackedNDCoords(shape, texShape, enableShapeUniforms) {
if (enableShapeUniforms) {
// TODO: support 5d and 6d
return "\n ivec4 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n int index = resTexRC.x * packedTexShape[1] + resTexRC.y;\n\n int texelsInLogicalRow = int(ceil(float(outShape[3]) / 2.0));\n int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[2]) / 2.0));\n int texelsInBatchN = texelsInBatch * outShape[1];\n\n int b2 = index / texelsInBatchN;\n index -= b2 * texelsInBatchN;\n\n int b = index / texelsInBatch;\n index -= b * texelsInBatch;\n\n int r = 2 * (index / texelsInLogicalRow);\n int c = imod(index, texelsInLogicalRow) * 2;\n\n return ivec4(b2, b, r, c);\n }\n ";
}
var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
var texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
var texelsInBatchN = texelsInBatch;
var batches = "";
var coords = 'b, r, c';
for (var b = 2; b < shape.length - 1; b++) {
texelsInBatchN *= shape[shape.length - b - 1];
batches = "\n int b" + b + " = index / " + texelsInBatchN + ";\n index -= b" + b + " * " + texelsInBatchN + ";\n " + batches;
coords = "b" + b + ", " + coords;
}
return "\n ivec" + shape.length + " getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n\n " + batches + "\n\n int b = index / " + texelsInBatch + ";\n index -= b * " + texelsInBatch + ";\n\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec" + shape.length + "(" + coords + ");\n }\n ";
}
function getOutput4DCoords(shape, texShape, enableShapeUniforms) {
if (enableShapeUniforms) {
var _coordsFromIndexSnippet2 = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd', 'd2'], shape);
return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n " + _coordsFromIndexSnippet2 + "\n return ivec4(r, c, d, d2);\n }\n ";
}
var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n " + coordsFromIndexSnippet + "\n return ivec4(r, c, d, d2);\n }\n ";
}
function getOutput5DCoords(shape, texShape) {
var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
return "\n ivec5 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(" + texShape[0] + ",\n " + texShape[1] + "));\n\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n\n " + coordsFromIndexSnippet + "\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);\n return outShape;\n }\n ";
}
function getOutput6DCoords(shape, texShape) {
var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
return "\n ivec6 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n\n " + coordsFromIndexSnippet + "\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);\n return result;\n }\n ";
}
function getOutputPacked2DCoords(shape, texShape, enableShapeUniforms) {
var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
if (arraysEqual(shape, texShape)) {
if (enableShapeUniforms) {
return "\n ivec2 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n return 2 * ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1]));\n }\n ";
}
return "\n ivec2 getOutputCoords() {\n return 2 * ivec2(resultUV.yx * vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n }\n ";
} // texels needed to accommodate a logical row
var texelsInLogicalRow = Math.ceil(shape[1] / 2);
/**
* getOutputCoords
*
* resTexRC: The rows and columns of the texels. If you move over one
* texel to the right in the packed texture, you are moving over one column
* (not two).
*
* index: The texel index
*/
if (enableShapeUniforms) {
return "\n ivec2 getOutputCoords() {\n ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));\n int texelsInLogicalRow = int(ceil(float(outShape[1]) / 2.0));\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(packedTexShape[0], packedTexShape[1]));\n\n int index = resTexRC.x * packedTexShape[1] + resTexRC.y;\n int r = 2 * (index / texelsInLogicalRow);\n int c = imod(index, texelsInLogicalRow) * 2;\n\n return ivec2(r, c);\n }\n ";
}
return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec2(r, c);\n }\n ";
}
function getOutput2DCoords(shape, texShape, enableShapeUniforms) {
if (arraysEqual(shape, texShape)) {
if (enableShapeUniforms) {
return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1]));\n }\n ";
}
return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(" + texShape[0] + ", " + texShape[1] + "));\n }\n ";
}
if (shape[1] === 1) {
if (enableShapeUniforms) {
return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n return ivec2(index, 0);\n }\n ";
}
return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n return ivec2(index, 0);\n }\n ";
}
if (shape[0] === 1) {
if (enableShapeUniforms) {
return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n return ivec2(0, index);\n }\n ";
}
return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n return ivec2(0, index);\n }\n ";
}
if (enableShapeUniforms) {
return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(outTexShape[0], outTexShape[1]));\n int index = resTexRC.x * outTexShape[1] + resTexRC.y;\n int r = index / outShape[1];\n int c = index - r * outShape[1];\n return ivec2(r, c);\n }\n ";
}
return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n int r = index / " + shape[1] + ";\n int c = index - r * " + shape[1] + ";\n return ivec2(r, c);\n }\n ";
}
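// Name of the uniform that carries a sliced input's flat offset into its
// underlying texture (uploaded in runProgram when texData.slice is present).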
function getFlatOffsetUniformName(texName) {
return "offset" + texName;
}
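// The getSampler*/getPackedSampler* helpers below each emit a GLSL getter
// (e.g. `float getA(int row, int col)`) that maps logical coordinates of one
// input tensor onto a texture fetch, handling uniform inputs, squeezed shapes,
// flat offsets and the packed 2x2 texel layout where applicable.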
function getPackedSamplerScalar(inputInfo) {
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var glsl = getGlslDifferences();
return "\n vec4 " + funcName + "() {\n return " + glsl.texture2D + "(" + texName + ", halfCR);\n }\n ";
}
function getSamplerScalar(inputInfo, enableShapeUniforms) {
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
if (inputInfo.shapeInfo.isUniform) {
return "float " + funcName + "() {return " + texName + ";}";
}
var _inputInfo$shapeInfo$ = inputInfo.shapeInfo.texShape,
texNumR = _inputInfo$shapeInfo$[0],
texNumC = _inputInfo$shapeInfo$[1];
if (texNumR === 1 && texNumC === 1) {
return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", halfCR);\n }\n ";
}
var offset = getFlatOffsetUniformName(texName);
if (enableShapeUniforms) {
return "\n float " + funcName + "() {\n vec2 uv = uvFromFlat(" + texName + "TexShape[0], " + texName + "TexShape[1], " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
var _inputInfo$shapeInfo$2 = inputInfo.shapeInfo.texShape,
tNumR = _inputInfo$shapeInfo$2[0],
tNumC = _inputInfo$shapeInfo$2[1];
return "\n float " + funcName + "() {\n vec2 uv = uvFromFlat(" + tNumR + ", " + tNumC + ", " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
function getPackedSampler1D(inputInfo, enableShapeUniforms) {
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var texShape = inputInfo.shapeInfo.texShape;
var glsl = getGlslDifferences();
if (enableShapeUniforms) {
return "\n vec4 " + funcName + "(int index) {\n ivec2 packedTexShape = ivec2(ceil(float(" + texName + "TexShape[0]) / 2.0), ceil(float(" + texName + "TexShape[1]) / 2.0));\n vec2 uv = packedUVfrom1D(\n packedTexShape[0], packedTexShape[1], index);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
return "\n vec4 " + funcName + "(int index) {\n vec2 uv = packedUVfrom1D(\n " + packedTexShape[0] + ", " + packedTexShape[1] + ", index);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
function getSampler1D(inputInfo, enableShapeUniforms) {
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
if (inputInfo.shapeInfo.isUniform) {
// Uniform arrays will be less than 65505 (no risk of float16 overflow).
return "\n float " + funcName + "(int index) {\n " + getUniformSampler(inputInfo) + "\n }\n ";
}
var texShape = inputInfo.shapeInfo.texShape;
var tNumR = texShape[0];
var tNumC = texShape[1];
if (tNumC === 1 && tNumR === 1) {
return "\n float " + funcName + "(int index) {\n return sampleTexture(" + texName + ", halfCR);\n }\n ";
}
var offset = getFlatOffsetUniformName(texName);
if (tNumC === 1) {
if (enableShapeUniforms) {
return "\n float " + funcName + "(int index) {\n vec2 uv = vec2(0.5, (float(index + " + offset + ") + 0.5) / float(" + texName + "TexShape[0]));\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int index) {\n vec2 uv = vec2(0.5, (float(index + " + offset + ") + 0.5) / " + tNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (tNumR === 1) {
if (enableShapeUniforms) {
return "\n float " + funcName + "(int index) {\n vec2 uv = vec2((float(index + " + offset + ") + 0.5) / float(" + texName + "TexShape[1]), 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int index) {\n vec2 uv = vec2((float(index + " + offset + ") + 0.5) / " + tNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (enableShapeUniforms) {
return "\n float " + funcName + "(int index) {\n vec2 uv = uvFromFlat(" + texName + "TexShape[0], " + texName + "TexShape[1], index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int index) {\n vec2 uv = uvFromFlat(" + tNumR + ", " + tNumC + ", index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
function getPackedSampler2D(inputInfo, enableShapeUniforms) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
var glsl = getGlslDifferences();
if (texShape != null && arraysEqual(shape, texShape)) {
if (enableShapeUniforms) {
return "\n vec4 " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texName + "TexShape[1], " + texName + "TexShape[0]);\n\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
return "\n vec4 " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
if (enableShapeUniforms) {
return "\n vec4 " + funcName + "(int row, int col) {\n ivec2 packedTexShape = ivec2(ceil(float(" + texName + "TexShape[0]) / 2.0), ceil(float(" + texName + "TexShape[1]) / 2.0));\n int valuesPerRow = int(ceil(float(" + texName + "Shape[1]) / 2.0));\n vec2 uv = packedUVfrom2D(valuesPerRow, packedTexShape[0], packedTexShape[1], row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
var valuesPerRow = Math.ceil(shape[1] / 2);
return "\n vec4 " + funcName + "(int row, int col) {\n vec2 uv = packedUVfrom2D(" + valuesPerRow + ", " + packedTexShape[0] + ", " + packedTexShape[1] + ", row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
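// Unpacked 2D sampler. Fast paths: logical shape equal to the texture shape,
// squeezable shapes (delegates to a lower-rank sampler), uniform inputs, and
// single-row/single-column textures; otherwise it falls back to uvFromFlat.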
function getSampler2D(inputInfo, enableShapeUniforms) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var texShape = inputInfo.shapeInfo.texShape;
if (texShape != null && arraysEqual(shape, texShape)) {
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texName + "TexShape[1], " + texName + "TexShape[0]);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
var _texNumR = texShape[0];
var _texNumC = texShape[1];
return "\n float " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + _texNumC + ".0, " + _texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
var _util$squeezeShape = squeezeShape(shape),
newShape = _util$squeezeShape.newShape,
keptDims = _util$squeezeShape.keptDims;
var squeezedShape = newShape;
if (squeezedShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
var params = ['row', 'col'];
return "\n " + getSamplerFromInInfo(newInputInfo, enableShapeUniforms) + "\n float " + funcName + "(int row, int col) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
if (inputInfo.shapeInfo.isUniform) {
// Uniform arrays will be less than 65505 (no risk of float16 overflow).
return "\n float " + funcName + "(int row, int col) {\n int index = round(dot(vec2(row, col), vec2(" + shape[1] + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
}
var texNumR = texShape[0];
var texNumC = texShape[1];
var offset = getFlatOffsetUniformName(texName);
if (texNumC === 1) {
// index is used directly as physical (no risk of float16 overflow).
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + texName + "Shape[1], 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / float(" + texName + "TexShape[0]));\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + shape[1] + ", 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (texNumR === 1) {
// index is used directly as physical (no risk of float16 overflow).
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + texName + "Shape[1], 1, 1));\n vec2 uv = vec2((index + 0.5) / float(" + texName + "TexShape[1]), 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + shape[1] + ", 1, 1));\n vec2 uv = vec2((index + 0.5) / " + texNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + texName + "Shape[1] + col + " + offset + ";\n vec2 uv = uvFromFlat(" + texName + "TexShape[0], " + texName + "TexShape[1], index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + shape[1] + " + col + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n";
}
function getPackedSampler3D(inputInfo, enableShapeUniforms) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var texShape = inputInfo.shapeInfo.texShape;
var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
if (shape[0] === 1) {
var squeezedShape = shape.slice(1);
var keptDims = [1, 2];
var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
var params = ['b', 'row', 'col'];
return "\n " + getPackedSamplerFromInInfo(newInputInfo, enableShapeUniforms) + "\n vec4 " + funcName + "(int b, int row, int col) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
var glsl = getGlslDifferences();
if (enableShapeUniforms) {
return "\n vec4 " + funcName + "(int b, int row, int col) {\n ivec2 packedTexShape = ivec2(ceil(float(" + texName + "TexShape[0]) / 2.0), ceil(float(" + texName + "TexShape[1]) / 2.0));\n int valuesPerRow = int(ceil(float(" + texName + "Shape[2]) / 2.0));\n int texelsInBatch = valuesPerRow * int(ceil(float(" + texName + "Shape[1]) / 2.0));\n vec2 uv = packedUVfrom3D(\n packedTexShape[0], packedTexShape[1], texelsInBatch, valuesPerRow, b, row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
var texNumR = packedTexShape[0];
var texNumC = packedTexShape[1];
var valuesPerRow = Math.ceil(shape[2] / 2);
var texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);
return "\n vec4 " + funcName + "(int b, int row, int col) {\n vec2 uv = packedUVfrom3D(\n " + texNumR + ", " + texNumC + ", " + texelsInBatch + ", " + valuesPerRow + ", b, row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
function getSampler3D(inputInfo, enableShapeUniforms) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var stride0 = shape[1] * shape[2];
var stride1 = shape[2];
var _util$squeezeShape2 = squeezeShape(shape),
newShape = _util$squeezeShape2.newShape,
keptDims = _util$squeezeShape2.keptDims;
var squeezedShape = newShape;
if (squeezedShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
var params = ['row', 'col', 'depth'];
return "\n " + getSamplerFromInInfo(newInputInfo, enableShapeUniforms) + "\n float " + funcName + "(int row, int col, int depth) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
if (inputInfo.shapeInfo.isUniform) {
// Uniform arrays will be less than 65505 (no risk of float16 overflow).
return "\n float " + funcName + "(int row, int col, int depth) {\n int index = round(dot(vec3(row, col, depth),\n vec3(" + stride0 + ", " + stride1 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
}
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
var flatOffset = inputInfo.shapeInfo.flatOffset;
if (texNumC === stride0 && flatOffset == null) {
// texC is used directly as physical (no risk of float16 overflow).
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col, int depth) {\n int stride1 = " + texName + "Shape[2];\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(stride1, 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texName + "TexShape[1], " + texName + "TexShape[0]);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col, int depth) {\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(" + stride1 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (texNumC === stride1 && flatOffset == null) {
// texR is used directly as physical (no risk of float16 overflow).
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(" + texName + "Shape[1], 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texName + "TexShape[1], " + texName + "TexShape[0]);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(" + shape[1] + ", 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
var offset = getFlatOffsetUniformName(texName);
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int stride0 = " + texName + "Shape[1] * " + texName + "Shape[2];\n int stride1 = " + texName + "Shape[2];\n int index = row * " + stride0 + " + col * " + stride1 + " + depth + " + offset + ";\n vec2 uv = uvFromFlat(" + texName + "TexShape[0], " + texName + "TexShape[1], index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
function getPackedSamplerND(inputInfo, enableShapeUniforms) {
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var glsl = getGlslDifferences();
if (enableShapeUniforms) {
// TODO: support 5d and 6d
return "\n vec4 " + funcName + "(int b2, int b, int row, int col) {\n int valuesPerRow = int(ceil(float(" + texName + "Shape[3]) / 2.0));\n int texelsInBatch = valuesPerRow * int(ceil(float(" + texName + "Shape[2]) / 2.0));\n int index = b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2);\n texelsInBatch *= " + texName + "Shape[1];\n index = b2 * texelsInBatch + index;\n ivec2 packedTexShape = ivec2(ceil(float(" + texName + "TexShape[0]) / 2.0), ceil(float(" + texName + "TexShape[1]) / 2.0));\n int texR = index / packedTexShape[1];\n int texC = index - texR * packedTexShape[1];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(packedTexShape[1], packedTexShape[0]); return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
var shape = inputInfo.shapeInfo.logicalShape;
var rank = shape.length;
var texShape = inputInfo.shapeInfo.texShape;
var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
var texNumR = packedTexShape[0];
var texNumC = packedTexShape[1];
var valuesPerRow = Math.ceil(shape[rank - 1] / 2);
var texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);
var params = "int b, int row, int col";
var index = "b * " + texelsInBatch + " + (row / 2) * " + valuesPerRow + " + (col / 2)";
for (var b = 2; b < rank - 1; b++) {
params = "int b" + b + ", " + params;
texelsInBatch *= shape[rank - b - 1];
index = "b" + b + " * " + texelsInBatch + " + " + index;
}
return "\n vec4 " + funcName + "(" + params + ") {\n int index = " + index + ";\n int texR = index / " + texNumC + ";\n int texC = index - texR * " + texNumC + ";\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ", " + texNumR + ");\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n ";
}
function getSampler4D(inputInfo, enableShapeUniforms) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var stride2 = shape[3];
var stride1 = shape[2] * stride2;
var stride0 = shape[1] * stride1;
var _util$squeezeShape3 = squeezeShape(shape),
newShape = _util$squeezeShape3.newShape,
keptDims = _util$squeezeShape3.keptDims;
if (newShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, newShape);
var params = ['row', 'col', 'depth', 'depth2'];
return "\n " + getSamplerFromInInfo(newInputInfo, enableShapeUniforms) + "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
if (inputInfo.shapeInfo.isUniform) {
// Uniform arrays will be less than 65505 (no risk of float16 overflow).
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n int index = round(dot(vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
}
var flatOffset = inputInfo.shapeInfo.flatOffset;
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
var stride2Str = "int stride2 = " + texName + "Shape[3];";
var stride1Str = "int stride1 = " + texName + "Shape[2] * stride2;";
var stride0Str = "int stride0 = " + texName + "Shape[1] * stride1;";
if (texNumC === stride0 && flatOffset == null) {
// texC is used directly as physical (no risk of float16 overflow).
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n " + stride2Str + "\n " + stride1Str + "\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(stride1, stride2, 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texName + "TexShape[1], " + texName + "TexShape[0]);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(" + stride1 + ", " + stride2 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (texNumC === stride2 && flatOffset == null) {
// texR is used directly as physical (no risk of float16 overflow).
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(" + texName + "Shape[1] * " + texName + "Shape[2], " + texName + "Shape[2], 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texName + "TexShape[1], " + texName + "TexShape[0]);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(" + shape[1] * shape[2] + ", " + shape[2] + ", 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
var offset = getFlatOffsetUniformName(texName);
if (enableShapeUniforms) {
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n " + stride2Str + "\n " + stride1Str + "\n " + stride0Str + "\n int index = row * stride0 + col * stride1 +\n depth * stride2 + depth2;\n vec2 uv = uvFromFlat(" + texName + "TexShape[0], " + texName + "TexShape[1], index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " +\n depth * " + stride2 + " + depth2;\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
function getSampler5D(inputInfo) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var stride3 = shape[4];
var stride2 = shape[3] * stride3;
var stride1 = shape[2] * stride2;
var stride0 = shape[1] * stride1;
var _util$squeezeShape4 = squeezeShape(shape),
newShape = _util$squeezeShape4.newShape,
keptDims = _util$squeezeShape4.keptDims;
if (newShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, newShape);
var params = ['row', 'col', 'depth', 'depth2', 'depth3'];
return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
if (inputInfo.shapeInfo.isUniform) {
// Uniform arrays will be less than 65505 (no risk of float16 overflow).
return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float index = dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n depth3;\n " + getUniformSampler(inputInfo) + "\n }\n ";
}
var flatOffset = inputInfo.shapeInfo.flatOffset;
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
if (texNumC === stride0 && flatOffset == null) {
// texC is used directly as physical (no risk of float16 overflow).
return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (texNumC === stride3 && flatOffset == null) {
// texR is used directly as physical (no risk of float16 overflow).
return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float texR = dot(\n vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] + ",\n " + shape[2] * shape[3] + ", " + shape[3] + ", 1));\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
var offset = getFlatOffsetUniformName(texName);
return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
function getSampler6D(inputInfo) {
var shape = inputInfo.shapeInfo.logicalShape;
var texName = inputInfo.name;
var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
var _util$squeezeShape5 = squeezeShape(shape),
newShape = _util$squeezeShape5.newShape,
keptDims = _util$squeezeShape5.keptDims;
if (newShape.length < shape.length) {
var newInputInfo = squeezeInputInfo(inputInfo, newShape);
var params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
}
var stride4 = shape[5];
var stride3 = shape[4] * stride4;
var stride2 = shape[3] * stride3;
var stride1 = shape[2] * stride2;
var stride0 = shape[1] * stride1;
if (inputInfo.shapeInfo.isUniform) {
// Uniform arrays will be less than 65505 (no risk of float16 overflow).
return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n dot(\n vec2(depth3, depth4),\n vec2(" + stride4 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
}
var flatOffset = inputInfo.shapeInfo.flatOffset;
var texShape = inputInfo.shapeInfo.texShape;
var texNumR = texShape[0];
var texNumC = texShape[1];
if (texNumC === stride0 && flatOffset == null) {
// texC is used directly as physical (no risk of float16 overflow).
return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", " + stride4 + ")) +\n float(depth4);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
if (texNumC === stride4 && flatOffset == null) {
// texR is used directly as physical (no risk of float16 overflow).
return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n float texR = dot(vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] * shape[4] + ",\n " + shape[2] * shape[3] * shape[4] + ",\n " + shape[3] * shape[4] + ",\n " + shape[4] + ")) + float(depth3);\n int texC = depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
var offset = getFlatOffsetUniformName(texName);
return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 * " + stride4 + " + depth4 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
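// For inputs uploaded as uniform arrays: returns the element at `index` via a
// loop over constant indices, which presumably sidesteps GLSL ES 1.0 limits on
// dynamically indexing uniform arrays in fragment shaders.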
function getUniformSampler(inputInfo) {
var texName = inputInfo.name;
var inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
if (inSize < 2) {
return "return " + texName + ";";
}
return "\n for (int i = 0; i < " + inSize + "; i++) {\n if (i == index) {\n return " + texName + "[i];\n }\n }\n ";
}
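// Emits `get<X>AtOutCoords()` for packed inputs: samples the input at the
// output's coordinates, zeroing broadcast dimensions, and rearranges the rgba
// channels so broadcasting within a 2x2 texel yields the expected values.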
function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
var texName = inputInfo.name;
var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
var funcName = 'get' + texFuncSnippet + 'AtOutCoords';
var inRank = inputInfo.shapeInfo.logicalShape.length;
var outRank = outShapeInfo.logicalShape.length;
var broadcastDims = getBroadcastDims$1(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
var type = getCoordsDataType(outRank);
var rankDiff = outRank - inRank;
var coordsSnippet;
var fields = ['x', 'y', 'z', 'w', 'u', 'v'];
if (inRank === 0) {
coordsSnippet = '';
} else if (outRank < 2 && broadcastDims.length >= 1) {
coordsSnippet = 'coords = 0;';
} else {
coordsSnippet = broadcastDims.map(function (d) {
return "coords." + fields[d + rankDiff] + " = 0;";
}).join('\n');
}
var unpackedCoordsSnippet = '';
if (outRank < 2 && inRank > 0) {
unpackedCoordsSnippet = 'coords';
} else {
unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape.map(function (s, i) {
return "coords." + fields[i + rankDiff];
}).join(', ');
}
var output = "return outputValue;";
var inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
var isInputScalar = inSize === 1;
var outSize = sizeFromShape(outShapeInfo.logicalShape);
var isOutputScalar = outSize === 1;
if (inRank === 1 && !isInputScalar && !isOutputScalar) {
output = "\n return vec4(outputValue.xy, outputValue.xy);\n ";
} else if (isInputScalar && !isOutputScalar) {
if (outRank === 1) {
output = "\n return vec4(outputValue.x, outputValue.x, 0., 0.);\n ";
} else {
output = "\n return vec4(outputValue.x);\n ";
}
} else if (broadcastDims.length) {
var rows = inRank - 2;
var cols = inRank - 1;
if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) {
output = "return vec4(outputValue.x);";
} else if (broadcastDims.indexOf(rows) > -1) {
output = "return vec4(outputValue.x, outputValue.y, " + "outputValue.x, outputValue.y);";
} else if (broadcastDims.indexOf(cols) > -1) {
output = "return vec4(outputValue.xx, outputValue.zz);";
}
}
return "\n vec4 " + funcName + "() {\n " + type + " coords = getOutputCoords();\n " + coordsSnippet + "\n vec4 outputValue = get" + texFuncSnippet + "(" + unpackedCoordsSnippet + ");\n " + output + "\n }\n ";
}
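// Unpacked counterpart of the above; when the input and output share rank,
// texture shape and offset it samples directly at resultUV.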
function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
var texName = inputInfo.name;
var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
var funcName = 'get' + texFuncSnippet + 'AtOutCoords';
var outTexShape = outShapeInfo.texShape;
var inTexShape = inputInfo.shapeInfo.texShape;
var inRank = inputInfo.shapeInfo.logicalShape.length;
var outRank = outShapeInfo.logicalShape.length;
if (!inputInfo.shapeInfo.isUniform && inRank === outRank && inputInfo.shapeInfo.flatOffset == null && arraysEqual(inTexShape, outTexShape)) {
return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", resultUV);\n }\n ";
}
var type = getCoordsDataType(outRank);
var broadcastDims = getBroadcastDims$1(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
var rankDiff = outRank - inRank;
var coordsSnippet;
var fields = ['x', 'y', 'z', 'w', 'u', 'v'];
if (inRank === 0) {
coordsSnippet = '';
} else if (outRank < 2 && broadcastDims.length >= 1) {
coordsSnippet = 'coords = 0;';
} else {
coordsSnippet = broadcastDims.map(function (d) {
return "coords." + fields[d + rankDiff] + " = 0;";
}).join('\n');
}
var unpackedCoordsSnippet = '';
if (outRank < 2 && inRank > 0) {
unpackedCoordsSnippet = 'coords';
} else {
unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape.map(function (s, i) {
return "coords." + fields[i + rankDiff];
}).join(', ');
}
return "\n float " + funcName + "() {\n " + type + " coords = getOutputCoords();\n " + coordsSnippet + "\n return get" + texFuncSnippet + "(" + unpackedCoordsSnippet + ");\n }\n ";
}
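// Maps an output rank to the GLSL type used for coordinates in the generated
// code (ivec5/ivec6 are non-standard helper types assumed to be declared in
// the shader prefix emitted elsewhere in this file).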
function getCoordsDataType(rank) {
if (rank <= 1) {
return 'int';
} else if (rank === 2) {
return 'ivec2';
} else if (rank === 3) {
return 'ivec3';
} else if (rank === 4) {
return 'ivec4';
} else if (rank === 5) {
return 'ivec5';
} else if (rank === 6) {
return 'ivec6';
} else {
throw Error("GPU for rank " + rank + " is not yet supported");
}
}
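// Decides which shape is actually bound to the shape uniforms: the squeezed
// shape when that lets a lower-rank sampler be reused, otherwise the original
// logical shape.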
function getUniformInfoFromShape(isPacked, shape, texShape) {
var _util$squeezeShape6 = squeezeShape(shape),
newShape = _util$squeezeShape6.newShape,
keptDims = _util$squeezeShape6.keptDims;
var rank = shape.length;
var useSqueezePackedShape = isPacked && rank === 3 && shape[0] === 1;
var squeezeShape$1 = useSqueezePackedShape ? shape.slice(1) : newShape;
var useSqueezeShape = !isPacked && rank > 1 && !arraysEqual(shape, texShape) && newShape.length < rank || useSqueezePackedShape;
var uniformShape = useSqueezeShape ? squeezeShape$1 : shape;
return {
useSqueezeShape: useSqueezeShape,
uniformShape: uniformShape,
keptDims: keptDims
};
}
/** Returns a new input info (a copy) that has a squeezed logical shape. */
function squeezeInputInfo(inInfo, squeezedShape) {
// Deep copy.
var newInputInfo = JSON.parse(JSON.stringify(inInfo));
newInputInfo.shapeInfo.logicalShape = squeezedShape;
return newInputInfo;
}
function getSqueezedParams(params, keptDims) {
return keptDims.map(function (d) {
return params[d];
}).join(', ');
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
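// Builds the fragment shader source for `program`, compiles it, and caches all
// uniform locations (NAN/INFINITY, per-input values and offsets, optional
// shape/texShape uniforms, output shape/strides/texShape, custom uniforms)
// into the returned binary record.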
function compileProgram(gpgpu, program, inputs, output) {
var inputInfos = inputs.map(function (input, i) {
var shapeInfo = {
logicalShape: input.shape,
texShape: input.isUniform ? null : input.texData.texShape,
isUniform: input.isUniform,
isPacked: input.isUniform ? false : input.texData.isPacked,
flatOffset: null
};
if (input.texData != null && input.texData.slice != null && input.texData.slice.flatOffset > 0) {
shapeInfo.flatOffset = input.texData.slice.flatOffset;
}
return {
name: program.variableNames[i],
shapeInfo: shapeInfo
};
});
var inShapeInfos = inputInfos.map(function (x) {
return x.shapeInfo;
});
var outShapeInfo = {
logicalShape: output.shape,
texShape: output.texData.texShape,
isUniform: false,
isPacked: output.texData.isPacked,
flatOffset: null
};
var source = makeShader(inputInfos, outShapeInfo, program);
var webGLProgram = gpgpu.createProgram(source); // Add special uniforms (NAN, INFINITY)
var infLoc = null;
var nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
if (env().getNumber('WEBGL_VERSION') === 1) {
infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);
} // Add user-defined uniforms
var shouldThrow = false;
var uniformLocations = {};
var inShapesLocations = {};
var inTexShapesLocations = {};
for (var i = 0; i < program.variableNames.length; i++) {
var varName = program.variableNames[i];
uniformLocations[varName] = gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow);
uniformLocations["offset" + varName] = gpgpu.getUniformLocation(webGLProgram, "offset" + varName, shouldThrow);
if (program.enableShapeUniforms) {
inShapesLocations[varName + "Shape"] = gpgpu.getUniformLocation(webGLProgram, varName + "Shape", shouldThrow);
inTexShapesLocations[varName + "TexShape"] = gpgpu.getUniformLocation(webGLProgram, varName + "TexShape", shouldThrow);
}
}
var outShapeLocation;
var outTexShapeLocation;
var outShapeStridesLocation;
if (program.enableShapeUniforms) {
outShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow);
outShapeStridesLocation = gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', shouldThrow);
outTexShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow);
}
var customUniformLocations = [];
if (program.customUniforms) {
program.customUniforms.forEach(function (d, i) {
customUniformLocations[i] = gpgpu.getUniformLocation(webGLProgram, d.name, shouldThrow);
});
}
return {
program: program,
source: source,
webGLProgram: webGLProgram,
uniformLocations: uniformLocations,
customUniformLocations: customUniformLocations,
inShapeInfos: inShapeInfos,
outShapeInfo: outShapeInfo,
infLoc: infLoc,
nanLoc: nanLoc,
inShapesLocations: inShapesLocations,
inTexShapesLocations: inTexShapesLocations,
outShapeLocation: outShapeLocation,
outShapeStridesLocation: outShapeStridesLocation,
outTexShapeLocation: outTexShapeLocation
};
}
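// Asserts that the tensors passed at run time match the logical and texture
// shapes the binary was compiled against.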
function validateBinaryAndProgram(shapeInfos, inputs) {
if (shapeInfos.length !== inputs.length) {
throw Error("Binary was compiled with " + shapeInfos.length + " inputs, but " + ("was executed with " + inputs.length + " inputs"));
}
shapeInfos.forEach(function (s, i) {
var shapeA = s.logicalShape;
var input = inputs[i];
var shapeB = input.shape;
if (!arraysEqual(shapeA, shapeB)) {
throw Error("Binary was compiled with different shapes than " + ("the current args. Shapes " + shapeA + " and " + shapeB + " must match"));
} // The input is uploaded as uniform.
if (s.isUniform && input.isUniform) {
return;
}
var texShapeA = s.texShape;
var texShapeB = input.isUniform ? null : input.texData.texShape;
if (!arraysEqual(texShapeA, texShapeB)) {
throw Error("Binary was compiled with different texture shapes than the" + (" current args. Shape " + texShapeA + " and " + texShapeB + " must match"));
}
});
}
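// Binds the output texture, activates the compiled program and uploads every
// uniform (NaN/Infinity constants, input textures or uniform values, slice
// offsets, shape uniforms and custom uniforms) before issuing the draw call.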
function runProgram(gpgpu, binary, inputs, output, customUniformValues) {
if (!binary.program.enableShapeUniforms) {
validateBinaryAndProgram(binary.inShapeInfos, inputs);
validateBinaryAndProgram([binary.outShapeInfo], [output]);
}
var outTex = output.texData.texture;
var outTexShape = output.texData.texShape;
if (output.texData.isPacked) {
gpgpu.setOutputPackedMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
} else {
gpgpu.setOutputMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
}
gpgpu.setProgram(binary.webGLProgram); // Set special uniforms (NAN, INFINITY)
if (env().getNumber('WEBGL_VERSION') === 1) {
if (binary.infLoc !== null) {
gpgpu.gl.uniform1f(binary.infLoc, Infinity);
}
}
if (binary.nanLoc !== null) {
gpgpu.gl.uniform1f(binary.nanLoc, NaN);
} // Set user-defined inputs
inputs.forEach(function (input, i) {
var varName = binary.program.variableNames[i];
var varLoc = binary.uniformLocations[varName];
var varOffsetLoc = binary.uniformLocations["offset" + varName];
var varShapeLoc = binary.inShapesLocations[varName + "Shape"];
var varTexShapeLoc = binary.inTexShapesLocations[varName + "TexShape"];
if (varShapeLoc) {
var _shader_compiler$getU = getUniformInfoFromShape(binary.program.packedInputs, input.shape, input.texData.texShape),
uniformShape = _shader_compiler$getU.uniformShape;
switch (uniformShape.length) {
case 1:
gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape));
break;
case 2:
gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape));
break;
case 3:
gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape));
break;
case 4:
gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape));
break;
default:
break;
}
}
if (varTexShapeLoc) {
gpgpu.gl.uniform2i(varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]);
}
if (varLoc == null) {
// The compiler inferred that this variable is not used in this shader.
return;
}
if (input.isUniform) {
// Upload the values of the tensor as uniform.
if (sizeFromShape(input.shape) < 2) {
gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
} else {
var vals = input.uniformValues;
if (!(vals instanceof Float32Array)) {
vals = new Float32Array(vals);
}
gpgpu.gl.uniform1fv(varLoc, vals);
}
return;
} // If the input was sliced, upload the flat offset index.
if (input.texData.slice != null && varOffsetLoc != null) {
gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
}
gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, i);
});
var outShapeLoc = binary.outShapeLocation;
if (outShapeLoc) {
switch (output.shape.length) {
case 1:
gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape));
break;
case 2:
gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape));
break;
case 3:
gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape));
break;
case 4:
gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape));
break;
default:
break;
}
}
if (binary.outShapeStridesLocation) {
var strides = computeStrides(output.shape);
switch (output.shape.length) {
case 2:
gpgpu.gl.uniform1iv(binary.outShapeStridesLocation, new Int32Array(strides));
break;
case 3:
gpgpu.gl.uniform2iv(binary.outShapeStridesLocation, new Int32Array(strides));
break;
case 4:
gpgpu.gl.uniform3iv(binary.outShapeStridesLocation, new Int32Array(strides));
break;
default:
break;
}
}
if (binary.outTexShapeLocation) {
gpgpu.gl.uniform2i(binary.outTexShapeLocation, output.texData.texShape[0], output.texData.texShape[1]);
}
if (binary.program.customUniforms && customUniformValues) {
binary.program.customUniforms.forEach(function (d, i) {
var customLoc = binary.customUniformLocations[i];
var customValue = customUniformValues[i];
if (d.type === 'float') {
gpgpu.gl.uniform1fv(customLoc, customValue);
} else if (d.type === 'vec2') {
gpgpu.gl.uniform2fv(customLoc, customValue);
} else if (d.type === 'vec3') {
gpgpu.gl.uniform3fv(customLoc, customValue);
} else if (d.type === 'vec4') {
gpgpu.gl.uniform4fv(customLoc, customValue);
} else if (d.type === 'int') {
gpgpu.gl.uniform1iv(customLoc, customValue);
} else if (d.type === 'ivec2') {
gpgpu.gl.uniform2iv(customLoc, customValue);
} else if (d.type === 'ivec3') {
gpgpu.gl.uniform3iv(customLoc, customValue);
} else if (d.type === 'ivec4') {
gpgpu.gl.uniform4iv(customLoc, customValue);
} else {
throw Error("uniform type " + d.type + " is not supported yet.");
}
});
}
gpgpu.executeProgram();
}
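// Builds the cache key for a compiled shader from every piece of input/output
// metadata the generated source depends on, so structurally identical programs
// can share one binary.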
function makeShaderKey(program, inputs, output) {
var keyInputs = '';
inputs.concat(output).forEach(function (x) {
var hasOffset = x.texData != null && x.texData.slice != null && x.texData.slice.flatOffset > 0; // TODO: Remove the condition of !x.isUniform.
if (program.enableShapeUniforms && !x.isUniform) {
var xTexShape = x.texData.texShape;
var _shader_compiler$getU2 = getUniformInfoFromShape(program.packedInputs, x.shape, xTexShape),
useSqueezeShape = _shader_compiler$getU2.useSqueezeShape,
uniformShape = _shader_compiler$getU2.uniformShape,
keptDims = _shader_compiler$getU2.keptDims;
var rank1 = '',
rank2 = '',
rank34 = '';
if (uniformShape.length === 1 && program.packedInputs) {
var packedTexShape = [Math.ceil(xTexShape[0] / 2), Math.ceil(xTexShape[1] / 2)];
rank1 = (packedTexShape[0] > 1) + "_" + (packedTexShape[1] > 1);
} else if (uniformShape.length === 2 && !program.packedInputs) {
rank2 = (uniformShape[0] > 1) + "_" + (uniformShape[1] > 1);
} else if (uniformShape.length > 2 && !program.packedInputs) {
var strides = computeStrides(uniformShape);
rank34 = (strides[0] === xTexShape[1]) + "_" + (strides[strides.length - 1] === xTexShape[1]);
}
var xRank = x.shape.length;
var isLogicalShapTexShapeEqual = uniformShape.length === 2 && arraysEqual(x.shape, xTexShape);
var isScalar = sizeFromShape(x.shape) === 1;
var broadcastDims = getBroadcastDims(x.shape, output.shape);
var isInOutTexShapeEqual = !program.packedInputs && xRank === output.shape.length && arraysEqual(xTexShape, output.texData.texShape);
var isTexShapeGreaterThanOne = program.packedInputs || uniformShape.length > 2 ? '' : (xTexShape[0] > 1) + "_" + (xTexShape[1] > 1); // These key components are needed because shader_compiler embeds
// them in the shader.
// |xRank| is used to determine the coords length. See
// get[Packed]SamplerAtOutputCoords.
// |isInOutTexShapeEqual| is used to determine whether going to an
// optimization path in getSamplerAtOutputCoords.
// |useSqueezeShape| is extracted from squeezeInputInfo of
// getSampler[2|3|4]D/getPackedSampler3D.
// |isScalar| is extracted from isInputScalar/isOutputScalar in
// getPackedSamplerAtOutputCoords.
// |broadcastDims| is extracted from get[Packed]SamplerAtOutputCoords.
// |isLogicalShapTexShapeEqual| is used in
// getOutput[Packed]2DCoords/get[Packed]Sampler2D.
// |rank1| is used in getOutputPacked1DCoords.
// |rank2| is used in getOutput2DCoords.
// |rank34| is used in getSampler3D/getSampler4D.
// |isTexShapeGreaterThanOne| is used in
// getSampler[Scalar|1D|2D]/getOutput1DCoords.
keyInputs += xRank + "_" + isInOutTexShapeEqual + "_" + (useSqueezeShape ? keptDims : '') + "_" + uniformShape.length + "_" + isScalar + "_" + broadcastDims + "_" + isLogicalShapTexShapeEqual + "_" + rank1 + "_" + rank2 + "_" + rank34 + "_" + isTexShapeGreaterThanOne + "_" + hasOffset;
} else {
var texShape = x.isUniform ? 'uniform' : x.texData.texShape;
keyInputs += x.shape + "_" + texShape + "_" + hasOffset;
}
});
var keyUserCode = program.userCode;
var key = program.constructor.name; // Fast string concat. See https://jsperf.com/string-concatenation/14.
key += '_' + keyInputs + '_' + keyUserCode + ("" + env().getNumber('WEBGL_VERSION'));
return key;
}
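// Shape uniforms are only used when the flag is set and the rank is small
// enough for the current codegen to support.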
function useShapeUniforms(rank) {
// TODO: Remove the limitation of rank <= 4.
return env().getBool('WEBGL_USE_SHAPES_UNIFORMS') && rank <= 4;
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
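// DecodeMatrix[Packed]Program: rewrites an input texture (unpacked or packed)
// into a densely packed output texture (PackingScheme.DENSE), presumably as a
// step before reading values back from the GPU.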
var DecodeMatrixProgram = function DecodeMatrixProgram(outputShape) {
this.variableNames = ['A'];
this.packedInputs = false;
this.packedOutput = true;
this.outPackingScheme = PackingScheme.DENSE;
this.customUniforms = [{
name: 'texShape',
type: 'ivec2'
}];
var glsl = getGlslDifferences();
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + (this.enableShapeUniforms ? getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)) + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));\n int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getA(rc.x, rc.y, rc.z);\n }\n\n " + glsl.output + " = result;\n }\n ";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DecodeMatrixPackedProgram = function DecodeMatrixPackedProgram(outputShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.outPackingScheme = PackingScheme.DENSE;
this.customUniforms = [{
name: 'texShape',
type: 'ivec2'
}];
var glsl = getGlslDifferences();
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + (this.enableShapeUniforms ? getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)) + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));\n int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));\n }\n\n " + glsl.output + " = result;\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
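// EncodeFloat[Packed]Program: encodes each float output value into four bytes
// (ENCODE_FLOAT_SNIPPET) so results can be read back with readPixels on
// contexts that cannot download float textures directly.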
var EncodeFloatProgram = function EncodeFloatProgram(outputShape) {
this.variableNames = ['A'];
this.outTexUsage = TextureUsage.DOWNLOAD;
var glsl = getGlslDifferences();
this.outputShape = outputShape;
this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n void main() {\n float x = getAAtOutCoords();\n " + glsl.output + " = encode_float(x);\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EncodeFloatPackedProgram = function EncodeFloatPackedProgram(outputShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = false;
this.outTexUsage = TextureUsage.DOWNLOAD;
var glsl = getGlslDifferences();
this.outputShape = outputShape;
this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));\n " + glsl.output + " = encode_float(x);\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EncodeMatrixProgram = function EncodeMatrixProgram(outputShape, inputIsUnsignedByte) {
if (inputIsUnsignedByte === void 0) {
inputIsUnsignedByte = false;
}
this.variableNames = ['A'];
this.customUniforms = [{
name: 'texShape',
type: 'ivec2'
}];
var glsl = getGlslDifferences();
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
var output = "result";
if (inputIsUnsignedByte) {
output = "floor(result * 255. + 0.5)";
}
this.userCode = "\n " + (this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape)) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n int flatIndex = getFlatIndex(coords);\n int offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n int r = flatIndex / texShape[1];\n int c = imod(flatIndex, texShape[1]);\n vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);\n vec4 values = " + glsl.texture2D + "(A, uv);\n\n float result;\n\n if(offset == 0) {\n result = values[0];\n } else if(offset == 1) {\n result = values[1];\n } else if(offset == 2) {\n result = values[2];\n } else {\n result = values[3];\n }\n\n " + glsl.output + " = vec4(" + output + ", 0., 0., 0.);\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/*
This is how the shader encodes a tensor with shape = [2, 3, 5]
(indices are [batch, row, col]).
000|001 002|003 004|xxx 020|021 022|023 024|xxx
------- ------- ------- ------- ------- -------
010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
100|101 102|103 104|xxx 120|121 122|123 124|xxx
------- ------- ------- ------- ------- -------
110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
Single texels contain only values from the same batch, and from adjacent rows
and columns.
*/
var EncodeMatrixPackedProgram = function EncodeMatrixPackedProgram(outputShape, inputIsUnsignedByte) {
if (inputIsUnsignedByte === void 0) {
inputIsUnsignedByte = false;
}
this.variableNames = ['A'];
this.packedInputs = false;
this.packedOutput = true;
this.customUniforms = [{
name: 'texShape',
type: 'ivec2'
}];
var glsl = getGlslDifferences();
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
var mainLoop = '';
var output = 'result';
if (inputIsUnsignedByte) {
output = 'floor(result * 255. + 0.5)';
}
for (var row = 0; row <= 1; row++) {
for (var col = 0; col <= 1; col++) {
var channel = row * 2 + col;
mainLoop += "\n localCoords = coords;\n if(localCoords[2] + " + col + " < " + (this.enableShapeUniforms ? 'outShape[2]' : "" + outputShape[2]) + ") {\n localCoords[2] += " + col + ";\n if (localCoords[1] + " + row + " < " + (this.enableShapeUniforms ? 'outShape[1]' : "" + outputShape[1]) + ") {\n localCoords[1] += " + row + ";\n\n flatIndex = getFlatIndex(localCoords);\n offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n int r = flatIndex / texShape[1];\n int c = imod(flatIndex, texShape[1]);\n vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);\n values = " + glsl.texture2D + "(A, uv);\n\n if (offset == 0) {\n result[" + channel + "] = values[0];\n } else if (offset == 1) {\n result[" + channel + "] = values[1];\n } else if (offset == 2) {\n result[" + channel + "] = values[2];\n } else {\n result[" + channel + "] = values[3];\n }\n }\n }\n ";
}
}
this.userCode = "\n " + (this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape)) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n vec4 result = vec4(0.);\n int flatIndex, r, c, offset;\n ivec3 localCoords;\n vec2 uv;\n vec4 values;\n\n " + mainLoop + "\n\n " + glsl.output + " = " + output + ";\n }\n ";
};
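/*
 * Illustrative sketch (an assumption added for documentation; it is not part
 * of the library and nothing below calls it): given the 2x2 packing shown in
 * the diagram above, the RGBA channel holding logical position [row, col]
 * follows the same `row * 2 + col` ordering used in the loop that builds
 * `mainLoop`.
 */
function packedChannelForPositionSketch(row, col) {
  // r = (even row, even col), g = (even row, odd col),
  // b = (odd row, even col), a = (odd row, odd col).
  var channel = (row % 2) * 2 + (col % 2);
  return ['r', 'g', 'b', 'a'][channel];
}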
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function createVertexShader$1(gl) {
var glsl = getGlslDifferences();
var vertexShaderSource = glsl.version + "\n precision highp float;\n " + glsl.attribute + " vec3 clipSpacePos;\n " + glsl.attribute + " vec2 uv;\n " + glsl.varyingVs + " vec2 resultUV;\n\n void main() {\n gl_Position = vec4(clipSpacePos, 1);\n resultUV = uv;\n }";
return createVertexShader(gl, vertexShaderSource);
}
function createVertexBuffer(gl) {
// [x y z u v] * [upper-left, lower-left, upper-right, lower-right]
var vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]);
return createStaticVertexBuffer(gl, vertexArray);
}
function createIndexBuffer(gl) {
// OpenGL (and WebGL) use "CCW == front" winding
var triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]);
return createStaticIndexBuffer(gl, triangleVertexIndices);
}
function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) {
validateTextureSize(width, height);
var texture = createTexture(gl);
var tex2d = gl.TEXTURE_2D;
callAndCheck(gl, function () {
return gl.bindTexture(tex2d, texture);
});
callAndCheck(gl, function () {
return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
});
callAndCheck(gl, function () {
return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
});
callAndCheck(gl, function () {
return gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
});
callAndCheck(gl, function () {
return gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
});
callAndCheck(gl, function () {
return gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null);
});
callAndCheck(gl, function () {
return gl.bindTexture(gl.TEXTURE_2D, null);
});
return texture;
}
function getInternalFormatForFloat32MatrixTexture(textureConfig) {
return textureConfig.internalFormatFloat;
}
function createFloat32MatrixTexture(gl, rows, columns, textureConfig) {
var _tex_util$getUnpacked = getUnpackedMatrixTextureShapeWidthHeight(rows, columns),
width = _tex_util$getUnpacked[0],
height = _tex_util$getUnpacked[1];
return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat32MatrixTexture(textureConfig), textureConfig.textureFormatFloat, gl.FLOAT);
}
function getInternalFormatForFloat16MatrixTexture(textureConfig) {
return textureConfig.internalFormatHalfFloat;
}
function createFloat16MatrixTexture(gl, rows, columns, textureConfig) {
var _tex_util$getUnpacked2 = getUnpackedMatrixTextureShapeWidthHeight(rows, columns),
width = _tex_util$getUnpacked2[0],
height = _tex_util$getUnpacked2[1];
return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16MatrixTexture(textureConfig), textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
}
function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) {
return textureConfig.downloadTextureFormat;
}
function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) {
var _tex_util$getUnpacked3 = getUnpackedMatrixTextureShapeWidthHeight(rows, columns),
width = _tex_util$getUnpacked3[0],
height = _tex_util$getUnpacked3[1];
return createAndConfigureTexture(gl, width, height, getInternalFormatForUnsignedBytesMatrixTexture(textureConfig), gl.RGBA, gl.UNSIGNED_BYTE);
}
function getInternalFormatForPackedMatrixTexture(textureConfig) {
return textureConfig.internalFormatPackedFloat;
}
function createPackedMatrixTexture(gl, rows, columns, textureConfig) {
var _tex_util$getPackedMa = getPackedMatrixTextureShapeWidthHeight(rows, columns),
width = _tex_util$getPackedMa[0],
height = _tex_util$getPackedMa[1];
return createAndConfigureTexture(gl, width, height, getInternalFormatForPackedMatrixTexture(textureConfig), gl.RGBA, gl.FLOAT);
}
function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) {
return textureConfig.internalFormatPackedHalfFloat;
}
function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) {
var _tex_util$getPackedMa2 = getPackedMatrixTextureShapeWidthHeight(rows, columns),
width = _tex_util$getPackedMa2[0],
height = _tex_util$getPackedMa2[1];
return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16PackedMatrixTexture(textureConfig), gl.RGBA, textureConfig.textureTypeHalfFloat);
}
function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
var posOffset = 0; // x is the first buffer element
var uvOffset = 3 * 4; // uv comes after [x y z]
var stride = 3 * 4 + 2 * 4; // xyz + uv, each entry is 4-byte float.
callAndCheck(gl, function () {
return gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer);
});
var success = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset);
return success && bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, 2, stride, uvOffset);
}
function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
callAndCheck(gl, function () {
return gl.bindTexture(gl.TEXTURE_2D, texture);
});
var dataForUpload, texelDataType, internalFormat;
if (data instanceof Uint8Array) {
dataForUpload = new Uint8Array(width * height * 4);
texelDataType = gl.UNSIGNED_BYTE;
internalFormat = gl.RGBA;
} else {
dataForUpload = new Float32Array(width * height * 4);
texelDataType = gl.FLOAT;
internalFormat = textureConfig.internalFormatPackedFloat;
}
dataForUpload.set(data);
callAndCheck(gl, function () {
return gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload);
});
callAndCheck(gl, function () {
return gl.bindTexture(gl.TEXTURE_2D, null);
});
}
function uploadPixelDataToTexture(gl, texture, pixels) {
callAndCheck(gl, function () {
return gl.bindTexture(gl.TEXTURE_2D, texture);
});
if (pixels.data instanceof Uint8Array) {
callAndCheck(gl, function () {
return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data);
});
} else {
callAndCheck(gl, function () {
return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels);
});
}
callAndCheck(gl, function () {
return gl.bindTexture(gl.TEXTURE_2D, null);
});
}
function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
// Create and bind the buffer.
var buffer = gl2.createBuffer();
callAndCheck(gl2, function () {
return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
}); // Initialize the buffer to the size of the texture in bytes.
var bytesPerFloat = 4;
var valuesPerTexel = 4;
var bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns;
callAndCheck(gl2, function () {
return gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ);
}); // Enqueue a command on the GPU command queue to copy the texture into the
// buffer.
callAndCheck(gl2, function () {
return gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0);
});
callAndCheck(gl2, function () {
return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
});
return buffer;
}
function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
var gl2 = gl;
var downloadTarget = new Float32Array(size);
gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
return downloadTarget;
}
function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
var _tex_util$getUnpacked4 = getUnpackedMatrixTextureShapeWidthHeight(rows, columns),
w = _tex_util$getUnpacked4[0],
h = _tex_util$getUnpacked4[1];
var numChannels = 4;
var downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels));
callAndCheck(gl, function () {
return gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget);
}); // By wrapping the buffer in a Float32Array, we use native browser IEEE 754
// decoding of the 4 bytes that back each 32 bit float.
return new Float32Array(downloadTarget.buffer);
}
function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
var gl2 = gl;
var downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols));
gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
return downloadTarget;
}
function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
var packedRGBA = new Float32Array(physicalRows * physicalCols * 4);
callAndCheck(gl, function () {
return gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA);
});
return packedRGBA;
}
var GPGPUContext = /*#__PURE__*/function () {
function GPGPUContext(gl) {
this.outputTexture = null;
this.program = null;
this.disposed = false;
this.vertexAttrsAreBound = false;
this.itemsToPoll = [];
var glVersion = env().getNumber('WEBGL_VERSION');
if (gl != null) {
this.gl = gl;
setWebGLContext(glVersion, gl);
} else {
this.gl = getWebGLContext(glVersion);
} // WebGL 2.0 enables texture floats without an extension.
var COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
if (env().getNumber('WEBGL_VERSION') === 1) {
var TEXTURE_FLOAT = 'OES_texture_float';
var TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
this.textureFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
this.textureHalfFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
} else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
throw new Error('GL context does not support half float textures, yet the ' + 'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
}
this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
this.colorBufferHalfFloatExtension = getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
} else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
throw new Error('GL context does not support color renderable half floats, yet ' + 'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
}
} else {
COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
} else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
this.colorBufferHalfFloatExtension = this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
} else {
throw new Error('GL context does not support color renderable floats');
}
}
this.vertexBuffer = createVertexBuffer(this.gl);
this.indexBuffer = createIndexBuffer(this.gl);
this.framebuffer = createFramebuffer(this.gl);
this.textureConfig = getTextureConfig(this.gl, this.textureHalfFloatExtension);
}
var _proto = GPGPUContext.prototype;
_proto.dispose = function dispose() {
var _this = this;
if (this.disposed) {
return;
}
if (this.program != null) {
console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' + ' This is probably a resource leak, delete the program with ' + 'GPGPUContext.deleteProgram before disposing.');
}
if (this.outputTexture != null) {
console.warn('Disposing a GPGPUContext that still has a bound output matrix ' + 'texture. This is probably a resource leak, delete the output ' + 'matrix texture with GPGPUContext.deleteMatrixTexture before ' + 'disposing.');
}
var gl = this.gl;
callAndCheck(gl, function () {
return gl.finish();
});
callAndCheck(gl, function () {
return gl.bindFramebuffer(gl.FRAMEBUFFER, null);
});
callAndCheck(gl, function () {
return gl.deleteFramebuffer(_this.framebuffer);
});
callAndCheck(gl, function () {
return gl.bindBuffer(gl.ARRAY_BUFFER, null);
});
callAndCheck(gl, function () {
return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null);
});
callAndCheck(gl, function () {
return gl.deleteBuffer(_this.indexBuffer);
});
this.disposed = true;
};
_proto.createFloat32MatrixTexture = function createFloat32MatrixTexture$1(rows, columns) {
this.throwIfDisposed();
return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
};
_proto.createFloat16MatrixTexture = function createFloat16MatrixTexture$1(rows, columns) {
this.throwIfDisposed();
return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
};
_proto.createUnsignedBytesMatrixTexture = function createUnsignedBytesMatrixTexture$1(rows, columns) {
this.throwIfDisposed();
return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
};
_proto.uploadPixelDataToTexture = function uploadPixelDataToTexture$1(texture, pixels) {
this.throwIfDisposed();
uploadPixelDataToTexture(this.gl, texture, pixels);
};
_proto.uploadDenseMatrixToTexture = function uploadDenseMatrixToTexture$1(texture, width, height, data) {
this.throwIfDisposed();
uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
};
_proto.createFloat16PackedMatrixTexture = function createFloat16PackedMatrixTexture$1(rows, columns) {
this.throwIfDisposed();
return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
};
_proto.createPackedMatrixTexture = function createPackedMatrixTexture$1(rows, columns) {
this.throwIfDisposed();
return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
};
_proto.deleteMatrixTexture = function deleteMatrixTexture(texture) {
var _this2 = this;
this.throwIfDisposed();
if (this.outputTexture === texture) {
unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
this.outputTexture = null;
}
callAndCheck(this.gl, function () {
return _this2.gl.deleteTexture(texture);
});
};
_proto.downloadByteEncodedFloatMatrixFromOutputTexture = function downloadByteEncodedFloatMatrixFromOutputTexture$1(texture, rows, columns) {
var _this3 = this;
return this.downloadMatrixDriver(texture, function () {
return downloadByteEncodedFloatMatrixFromOutputTexture(_this3.gl, rows, columns, _this3.textureConfig);
});
};
_proto.downloadPackedMatrixFromBuffer = function downloadPackedMatrixFromBuffer$1(buffer, batch, rows, columns, physicalRows, physicalCols) {
return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig);
};
_proto.downloadFloat32MatrixFromBuffer = function downloadFloat32MatrixFromBuffer$1(buffer, size) {
return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
};
_proto.createBufferFromTexture = function createBufferFromTexture(texture, rows, columns) {
this.bindTextureToFrameBuffer(texture);
var result = createBufferFromOutputTexture(this.gl, rows, columns, this.textureConfig);
this.unbindTextureToFrameBuffer();
return result;
};
_proto.createAndWaitForFence = function createAndWaitForFence() {
var fenceContext = this.createFence(this.gl);
return this.pollFence(fenceContext);
};
_proto.createFence = function createFence(gl) {
var _this4 = this;
var query;
var isFencePassed;
if (env().getBool('WEBGL_FENCE_API_ENABLED')) {
var gl2 = gl;
var sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);
gl.flush();
isFencePassed = function isFencePassed() {
var status = gl2.clientWaitSync(sync, 0, 0);
return status === gl2.ALREADY_SIGNALED || status === gl2.CONDITION_SATISFIED;
};
query = sync;
} else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
query = this.beginQuery();
this.endQuery();
isFencePassed = function isFencePassed() {
return _this4.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
};
} else {
// If we have no way to fence, return true immediately. This will fire in
// WebGL 1.0 when there is no disjoint query timer. In this case, because
// the fence passes immediately, we'll immediately ask for a download of
// the texture, which will cause the UI thread to hang.
isFencePassed = function isFencePassed() {
return true;
};
}
return {
query: query,
isFencePassed: isFencePassed
};
};
_proto.downloadMatrixFromPackedTexture = function downloadMatrixFromPackedTexture(texture, physicalRows, physicalCols) {
var _this5 = this;
return this.downloadMatrixDriver(texture, function () {
return downloadMatrixFromPackedOutputTexture(_this5.gl, physicalRows, physicalCols);
});
};
_proto.createProgram = function createProgram$1(fragmentShaderSource) {
var _this6 = this;
this.throwIfDisposed();
var gl = this.gl;
var fragmentShader = createFragmentShader(gl, fragmentShaderSource);
if (this.vertexShader == null) {
this.vertexShader = createVertexShader$1(gl);
}
var program = createProgram(gl);
callAndCheck(gl, function () {
return gl.attachShader(program, _this6.vertexShader);
});
callAndCheck(gl, function () {
return gl.attachShader(program, fragmentShader);
});
linkProgram(gl, program);
if (this.debug) {
validateProgram(gl, program);
}
if (!this.vertexAttrsAreBound) {
this.setProgram(program);
this.vertexAttrsAreBound = bindVertexProgramAttributeStreams(gl, this.program, this.vertexBuffer);
}
return program;
};
_proto.deleteProgram = function deleteProgram(program) {
var _this7 = this;
this.throwIfDisposed();
if (program === this.program) {
this.program = null;
}
if (program != null) {
callAndCheck(this.gl, function () {
return _this7.gl.deleteProgram(program);
});
}
};
_proto.setProgram = function setProgram(program) {
var _this8 = this;
this.throwIfDisposed();
this.program = program;
if (this.program != null && this.debug) {
validateProgram(this.gl, this.program);
}
callAndCheck(this.gl, function () {
return _this8.gl.useProgram(program);
});
};
_proto.getUniformLocation = function getUniformLocation(program, uniformName, shouldThrow) {
if (shouldThrow === void 0) {
shouldThrow = true;
}
this.throwIfDisposed();
if (shouldThrow) {
return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
} else {
return getProgramUniformLocation(this.gl, program, uniformName);
}
};
_proto.getAttributeLocation = function getAttributeLocation(program, attribute) {
var _this9 = this;
this.throwIfDisposed();
return callAndCheck(this.gl, function () {
return _this9.gl.getAttribLocation(program, attribute);
});
};
_proto.getUniformLocationNoThrow = function getUniformLocationNoThrow(program, uniformName) {
this.throwIfDisposed();
return this.gl.getUniformLocation(program, uniformName);
};
_proto.setInputMatrixTexture = function setInputMatrixTexture(inputMatrixTexture, uniformLocation, textureUnit) {
this.throwIfDisposed();
this.throwIfNoProgram();
bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
};
_proto.setOutputMatrixTexture = function setOutputMatrixTexture(outputMatrixTexture, rows, columns) {
this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
};
_proto.setOutputPackedMatrixTexture = function setOutputPackedMatrixTexture(outputPackedMatrixTexture, rows, columns) {
this.throwIfDisposed();
var _tex_util$getPackedMa = getPackedMatrixTextureShapeWidthHeight(rows, columns),
width = _tex_util$getPackedMa[0],
height = _tex_util$getPackedMa[1];
this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
};
_proto.setOutputMatrixWriteRegion = function setOutputMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
};
_proto.setOutputPackedMatrixWriteRegion = function setOutputPackedMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
};
_proto.debugValidate = function debugValidate() {
if (this.program != null) {
validateProgram(this.gl, this.program);
}
validateFramebuffer(this.gl);
};
_proto.executeProgram = function executeProgram() {
this.throwIfDisposed();
this.throwIfNoProgram();
var gl = this.gl;
if (this.debug) {
this.debugValidate();
}
callAndCheck(gl, function () {
return gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0);
});
};
_proto.blockUntilAllProgramsCompleted = function blockUntilAllProgramsCompleted() {
var _this10 = this;
this.throwIfDisposed();
callAndCheck(this.gl, function () {
return _this10.gl.finish();
});
};
_proto.getQueryTimerExtension = function getQueryTimerExtension() {
if (this.disjointQueryTimerExtension == null) {
this.disjointQueryTimerExtension = getExtensionOrThrow(this.gl, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? 'EXT_disjoint_timer_query_webgl2' : 'EXT_disjoint_timer_query');
}
return this.disjointQueryTimerExtension;
};
_proto.getQueryTimerExtensionWebGL2 = function getQueryTimerExtensionWebGL2() {
return this.getQueryTimerExtension();
};
_proto.getQueryTimerExtensionWebGL1 = function getQueryTimerExtensionWebGL1() {
return this.getQueryTimerExtension();
};
_proto.beginQuery = function beginQuery() {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
var gl2 = this.gl;
var _ext = this.getQueryTimerExtensionWebGL2();
var _query = gl2.createQuery();
gl2.beginQuery(_ext.TIME_ELAPSED_EXT, _query);
return _query;
}
var ext = this.getQueryTimerExtensionWebGL1();
var query = ext.createQueryEXT();
ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
return query;
};
_proto.endQuery = function endQuery() {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
var gl2 = this.gl;
var _ext2 = this.getQueryTimerExtensionWebGL2();
gl2.endQuery(_ext2.TIME_ELAPSED_EXT);
return;
}
var ext = this.getQueryTimerExtensionWebGL1();
ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
};
_proto.waitForQueryAndGetTime = /*#__PURE__*/function () {
var _waitForQueryAndGetTime = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(query) {
var _this11 = this;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
_context.next = 2;
return repeatedTry(function () {
return _this11.disposed || // while testing, contexts are created / disposed
// in rapid succession, so without this check we
// may poll for the query timer indefinitely
_this11.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
});
case 2:
return _context.abrupt("return", this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')));
case 3:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function waitForQueryAndGetTime(_x) {
return _waitForQueryAndGetTime.apply(this, arguments);
}
return waitForQueryAndGetTime;
}();
_proto.getQueryTime = function getQueryTime(query, queryTimerVersion) {
if (queryTimerVersion === 0) {
return null;
}
if (queryTimerVersion === 2) {
var gl2 = this.gl;
var timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT); // Return milliseconds.
return timeElapsedNanos / 1000000;
} else {
var ext = this.getQueryTimerExtensionWebGL1();
var _timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT); // Return milliseconds.
return _timeElapsedNanos / 1000000;
}
};
_proto.isQueryAvailable = function isQueryAvailable(query, queryTimerVersion) {
if (queryTimerVersion === 0) {
return true;
}
if (queryTimerVersion === 2) {
var gl2 = this.gl;
var ext = this.getQueryTimerExtensionWebGL2();
var available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
if (this.disjoint == null) {
this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
}
return available && !this.disjoint;
} else {
var _ext3 = this.getQueryTimerExtensionWebGL1();
var _available = _ext3.getQueryObjectEXT(query, _ext3.QUERY_RESULT_AVAILABLE_EXT);
if (this.disjoint == null) {
this.disjoint = this.gl.getParameter(_ext3.GPU_DISJOINT_EXT);
}
return _available && !this.disjoint;
}
};
_proto.pollFence = function pollFence(fenceContext) {
var _this12 = this;
return new Promise(function (resolve) {
_this12.addItemToPoll(function () {
return fenceContext.isFencePassed();
}, function () {
return resolve();
});
});
};
_proto.pollItems = function pollItems() {
// Find the last query that has finished.
var index = linearSearchLastTrue(this.itemsToPoll.map(function (x) {
return x.isDoneFn;
}));
for (var i = 0; i <= index; ++i) {
var resolveFn = this.itemsToPoll[i].resolveFn;
resolveFn();
}
this.itemsToPoll = this.itemsToPoll.slice(index + 1);
};
_proto.addItemToPoll = function addItemToPoll(isDoneFn, resolveFn) {
var _this13 = this;
this.itemsToPoll.push({
isDoneFn: isDoneFn,
resolveFn: resolveFn
});
if (this.itemsToPoll.length > 1) {
// We already have a running loop that polls.
return;
} // Start a new loop that polls.
repeatedTry(function () {
_this13.pollItems(); // End the loop if no more items to poll.
return _this13.itemsToPoll.length === 0;
});
};
_proto.bindTextureToFrameBuffer = function bindTextureToFrameBuffer(texture) {
this.throwIfDisposed();
bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
if (this.debug) {
validateFramebuffer(this.gl);
}
};
_proto.unbindTextureToFrameBuffer = function unbindTextureToFrameBuffer() {
if (this.outputTexture != null) {
bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
if (this.debug) {
validateFramebuffer(this.gl);
}
} else {
unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
}
};
_proto.downloadMatrixDriver = function downloadMatrixDriver(texture, downloadAndDecode) {
this.bindTextureToFrameBuffer(texture);
var result = downloadAndDecode();
this.unbindTextureToFrameBuffer();
return result;
};
_proto.setOutputMatrixTextureDriver = function setOutputMatrixTextureDriver(outputMatrixTextureMaybePacked, width, height) {
this.throwIfDisposed();
var gl = this.gl;
bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
if (this.debug) {
validateFramebuffer(gl);
}
this.outputTexture = outputMatrixTextureMaybePacked;
callAndCheck(gl, function () {
return gl.viewport(0, 0, width, height);
});
callAndCheck(gl, function () {
return gl.scissor(0, 0, width, height);
});
};
_proto.setOutputMatrixWriteRegionDriver = function setOutputMatrixWriteRegionDriver(x, y, width, height) {
var _this14 = this;
this.throwIfDisposed();
callAndCheck(this.gl, function () {
return _this14.gl.scissor(x, y, width, height);
});
};
_proto.throwIfDisposed = function throwIfDisposed() {
if (this.disposed) {
throw new Error('Attempted to use disposed GPGPUContext.');
}
};
_proto.throwIfNoProgram = function throwIfNoProgram() {
if (this.program == null) {
throw new Error('No GPU program is currently set.');
}
};
_createClass(GPGPUContext, [{
key: "debug",
get: function get() {
return env().getBool('DEBUG');
}
}]);
return GPGPUContext;
}();
/**
* Finds the index of the last true element using linear search.
* Note: We can't do binary search because Chrome expects us to explicitly
* test all fences before download:
* https://github.com/tensorflow/tfjs/issues/1145
*/
function linearSearchLastTrue(arr) {
var i = 0;
for (; i < arr.length; ++i) {
var isDone = arr[i]();
if (!isDone) {
break;
}
}
return i - 1;
}
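/*
 * Illustrative usage (an assumption for documentation; the library never
 * calls it this way directly): if the done-callbacks report
 * [true, true, false], the last finished item is at index 1, so pollItems()
 * resolves items 0 and 1 and keeps the rest; if even the first callback
 * reports false, the result is -1 and nothing is resolved yet.
 */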
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var addImplCPU = addImpl,
bincountImplCPU = bincountImpl,
bincountReduceImplCPU = bincountReduceImpl,
ceilImplCPU = ceilImpl,
concatImplCPU = concatImpl,
equalImplCPU = equalImpl,
expImplCPU = expImpl,
expm1ImplCPU = expm1Impl,
floorImplCPU = floorImpl,
gatherNdImplCPU = gatherNdImpl,
gatherV2ImplCPU = gatherV2Impl,
greaterImplCPU = greaterImpl,
greaterEqualImplCPU = greaterEqualImpl,
lessImplCPU = lessImpl,
lessEqualImplCPU = lessEqualImpl,
linSpaceImplCPU = linSpaceImpl,
logImplCPU = logImpl,
maxImplCPU = maxImpl,
maximumImplCPU = maximumImpl,
minimumImplCPU = minimumImpl,
multiplyImplCPU = multiplyImpl,
negImplCPU = negImpl,
notEqualImplCPU = notEqualImpl,
prodImplCPU = prodImpl,
rangeImplCPU = rangeImpl,
rsqrtImplCPU = rsqrtImpl,
sigmoidImplCPU = sigmoidImpl,
simpleAbsImplCPU = simpleAbsImpl,
sliceImplCPU = sliceImpl,
sparseFillEmptyRowsImplCPU = sparseFillEmptyRowsImpl,
sparseReshapeImplCPU = sparseReshapeImpl,
sparseSegmentReductionImplCPU = sparseSegmentReductionImpl,
sqrtImplCPU = sqrtImpl,
stridedSliceImplCPU = stridedSliceImpl,
stringNGramsImplCPU = stringNGramsImpl,
stringSplitImplCPU = stringSplitImpl,
stringToHashBucketFastImplCPU = stringToHashBucketFastImpl,
subImplCPU = subImpl,
tileImplCPU = tileImpl,
topKImplCPU = topKImpl,
transposeImplCPU = transposeImpl,
uniqueImplCPU = uniqueImpl;
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function getVecChannels(name, rank) {
return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(function (d) {
return name + "." + d;
});
}
function getChannels(name, rank) {
if (rank === 1) {
return [name];
}
return getVecChannels(name, rank);
}
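/*
 * Illustrative examples (for documentation only): getChannels('rc', 3)
 * yields ['rc.x', 'rc.y', 'rc.z'], while getChannels('rc', 1) returns
 * ['rc'] because a rank-1 coordinate is a plain int rather than a vector.
 */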
function getSourceCoords(rank, dims) {
if (rank === 1) {
return 'rc';
}
var coords = '';
for (var i = 0; i < rank; i++) {
coords += dims[i];
if (i < rank - 1) {
coords += ',';
}
}
return coords;
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PackProgram = function PackProgram(outputShape) {
this.variableNames = ['A'];
this.packedInputs = false;
this.packedOutput = true; // Only handles 3D input / output tensors.
this.outputShape = outputShape;
var rank = outputShape.length;
if (rank === 0) {
this.userCode = "\n void main() {\n setOutput(vec4(getA(), 0., 0., 0.));\n }\n ";
} else {
var channels = getChannels('rc', rank);
var dtype = getCoordsDataType(rank);
var outOfBoundsCondition = getOutOfBoundsCondition(rank, outputShape, channels);
var setup = getSetup(rank, outputShape[outputShape.length - 1], outputShape[outputShape.length - 2], channels);
var output = getOutput(outputShape, channels);
this.userCode = "\n void main() {\n " + dtype + " rc = getOutputCoords();\n\n if(" + outOfBoundsCondition + ") {\n setOutput(vec4(0));\n } else {\n " + setup + "\n\n setOutput(vec4(" + output + "));\n }\n }\n ";
}
};
function getSourceCoordsArr(rank, dims) {
var coords = [];
for (var row = 0; row <= 1; row++) {
for (var col = 0; col <= 1; col++) {
var coord = (row === 0 ? 'r' : 'rp1') + ", " + (col === 0 ? 'c' : 'cp1');
for (var d = 2; d < rank; d++) {
coord = dims[dims.length - 1 - d] + "," + coord;
}
coords.push(coord);
}
}
return coords;
}
function getOutOfBoundsCondition(rank, shape, dims) {
if (rank === 1) {
return "rc > " + shape[0];
}
var cond = '';
for (var i = rank - 2; i < rank; i++) {
cond += dims[i] + " >= " + shape[i];
if (i < rank - 1) {
cond += '||';
}
}
return cond;
}
function getSetup(rank, cols, rows, dims) {
if (rank === 1) {
return '';
}
var innerDims = dims.slice(-2);
return "\n int r = " + innerDims[0] + ";\n int c = " + innerDims[1] + ";\n int rp1 = r + 1;\n int cp1 = c + 1;\n\n bool cEdge = cp1 >= " + cols + ";\n bool rEdge = rp1 >= " + rows + ";\n ";
}
function getOutput(shape, dims) {
var rank = shape.length;
var sourceCoords = getSourceCoordsArr(rank, dims);
if (rank === 1) {
return "getA(rc),\n rc + 1 >= " + shape[0] + " ? 0. : getA(rc + 1),\n 0, 0";
}
return "getA(" + sourceCoords[0] + "),\n cEdge ? 0. : getA(" + sourceCoords[1] + "),\n rEdge ? 0. : getA(" + sourceCoords[2] + "),\n rEdge || cEdge ? 0. : getA(" + sourceCoords[3] + ")";
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ReshapePackedProgram = function ReshapePackedProgram(outputShape, inputShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [{
name: 'inputShape',
type: 'ivec3'
}];
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
var mainLoop = "";
for (var i = 0; i < 4; i++) {
var thisRC = "thisRC = rc;";
if (i % 2 === 1) {
thisRC += "thisRC.z += 1;";
}
if (i > 1) {
thisRC += "thisRC.y += 1;";
}
mainLoop += "\n " + thisRC + "\n " + (i > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : '') + "\n int flatIndex = getFlatIndex(thisRC);\n\n ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);\n vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));\n\n result[" + i + "] =\n getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);\n " + (i > 0 ? '}' : '') + "\n ";
}
this.userCode = "\n " + getReshapedInputCoords(inputShape, this.enableShapeUniforms) + "\n " + (this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape)) + "\n\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0.);\n\n ivec3 thisRC;\n int rows = " + (this.enableShapeUniforms ? 'outShape[1]' : outputShape[1]) + ";\n int cols = " + (this.enableShapeUniforms ? 'outShape[2]' : outputShape[2]) + ";\n\n " + mainLoop + "\n\n setOutput(result);\n }\n ";
};
function getReshapedInputCoords(shape, enableShapeUniforms) {
var coordsFromIndexSnippet = enableShapeUniforms ? getLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], 'inputShape') : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
return "\n ivec3 inputCoordsFromReshapedOutCoords(int index) {\n " + coordsFromIndexSnippet + "\n return ivec3(r, c, d);\n }\n ";
}
var TextureManager = /*#__PURE__*/function () {
function TextureManager(gpgpu) {
this.gpgpu = gpgpu;
this.numUsedTextures = 0;
this.numFreeTextures = 0;
this._numBytesAllocated = 0;
this._numBytesFree = 0; // How many of the allocated bytes
// are available for reuse.
this.freeTextures = {};
this.logEnabled = false;
this.usedTextures = {};
}
var _proto = TextureManager.prototype;
_proto.acquireTexture = function acquireTexture(shapeRC, usage, isPacked) {
var physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
var shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
if (!(shapeKey in this.freeTextures)) {
this.freeTextures[shapeKey] = [];
}
if (!(shapeKey in this.usedTextures)) {
this.usedTextures[shapeKey] = [];
}
var texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
if (this.freeTextures[shapeKey].length > 0) {
this.numFreeTextures--;
this.numUsedTextures++;
this._numBytesFree -= texBytes;
this.log();
var _newTexture = this.freeTextures[shapeKey].shift();
this.usedTextures[shapeKey].push(_newTexture);
return _newTexture;
}
var newTexture;
if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
} else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
newTexture = this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
} else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
newTexture = this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
} else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
newTexture = this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
} else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
newTexture = this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
}
this.usedTextures[shapeKey].push(newTexture);
this.numUsedTextures++;
this._numBytesAllocated += texBytes;
this.log();
return newTexture;
};
_proto.releaseTexture = function releaseTexture(texture, shape, logicalTexType, isPacked) {
if (this.freeTextures == null) {
// Already disposed.
return;
}
var physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
var shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
if (!(shapeKey in this.freeTextures)) {
this.freeTextures[shapeKey] = [];
}
var texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
var deleteTexThreshold = env().get('WEBGL_DELETE_TEXTURE_THRESHOLD');
if (deleteTexThreshold !== -1 && this._numBytesAllocated > deleteTexThreshold) {
this.gpgpu.deleteMatrixTexture(texture);
this._numBytesAllocated -= texBytes;
} else {
this.freeTextures[shapeKey].push(texture);
this.numFreeTextures++;
this._numBytesFree += texBytes;
}
this.numUsedTextures--;
var texList = this.usedTextures[shapeKey];
var texIndex = texList.indexOf(texture);
if (texIndex < 0) {
throw new Error('Cannot release a texture that was never provided by this ' + 'texture manager');
}
texList.splice(texIndex, 1);
this.log();
};
_proto.log = function log() {
if (!this.logEnabled) {
return;
}
var total = this.numFreeTextures + this.numUsedTextures;
console.log('Free/Used', this.numFreeTextures + " / " + this.numUsedTextures, "(" + total + ")");
var freeRatio = this._numBytesFree / this._numBytesAllocated;
console.log("Bytes allocated: " + this._numBytesAllocated);
console.log("Bytes unused: " + this._numBytesFree + " (" + Math.round(100 * freeRatio) + "%)");
};
_proto.getNumUsedTextures = function getNumUsedTextures() {
return this.numUsedTextures;
};
_proto.getNumFreeTextures = function getNumFreeTextures() {
return this.numFreeTextures;
};
_proto.dispose = function dispose() {
var _this = this;
if (this.freeTextures == null) {
// Already disposed.
return;
}
for (var texShape in this.freeTextures) {
this.freeTextures[texShape].forEach(function (tex) {
_this.gpgpu.deleteMatrixTexture(tex);
});
}
for (var _texShape in this.usedTextures) {
this.usedTextures[_texShape].forEach(function (tex) {
_this.gpgpu.deleteMatrixTexture(tex);
});
}
this.freeTextures = null;
this.usedTextures = null;
this.numUsedTextures = 0;
this.numFreeTextures = 0;
this._numBytesAllocated = 0;
this._numBytesFree = 0;
};
_createClass(TextureManager, [{
key: "numBytesAllocated",
get: function get() {
return this._numBytesAllocated;
}
}, {
key: "numBytesFree",
get: function get() {
return this._numBytesFree;
}
}]);
return TextureManager;
}();
function numBytesForInternalFormat(gl, internalFormat) {
// tslint:disable-next-line:no-any
var glany = gl;
if (internalFormat === glany.R32F) {
return 4;
} else if (internalFormat === glany.R16F) {
return 2;
} else if (internalFormat === glany.RGBA32F) {
return 16;
} else if (internalFormat === gl.RGBA) {
return 16;
} else if (internalFormat === glany.RGBA16F) {
return 8;
}
throw new Error("Unknown internal format " + internalFormat);
}
function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
// It is not possible to infer packed status from the texture type because
// depending on the textureConfig, different texture types may resolve to the
// same internal format (e.g. in WebGL1, the internal format for
// UNPACKED_FLOAT16 textures is gl.RGBA). Therefore we pass in `isPacked`
// explicitly.
var internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
var numElements;
if (isPacked) {
var _getPackedMatrixTextu = getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]),
packedWidth = _getPackedMatrixTextu[0],
packedHeight = _getPackedMatrixTextu[1];
numElements = packedWidth * packedHeight;
} else {
var _getUnpackedMatrixTex = getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]),
width = _getUnpackedMatrixTex[0],
height = _getUnpackedMatrixTex[1];
numElements = width * height;
}
var bytesPerElement = numBytesForInternalFormat(gl, internalFormat);
return numElements * bytesPerElement;
}
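/*
 * Worked example (illustrative, assuming the packed texture rounds each
 * matrix dimension up to half its size, as in the 2x2 packing above): a
 * packed float32 texture for a [5, 7] matrix needs ceil(5/2) * ceil(7/2) = 12
 * texels, and each RGBA float32 texel is 16 bytes, so computeBytes reports
 * 12 * 16 = 192 bytes.
 */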
function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
switch (physicalTexType) {
case PhysicalTextureType.PACKED_2X2_FLOAT32:
return getInternalFormatForPackedMatrixTexture(textureConfig);
case PhysicalTextureType.PACKED_2X2_FLOAT16:
return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
case PhysicalTextureType.UNPACKED_FLOAT32:
return getInternalFormatForFloat32MatrixTexture(textureConfig);
case PhysicalTextureType.UNPACKED_FLOAT16:
return getInternalFormatForFloat16MatrixTexture(textureConfig);
case PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE:
return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
default:
throw new Error("Unknown physical texture type " + physicalTexType);
}
}
function getPhysicalTextureForRendering(isPacked) {
if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {
if (isPacked) {
return PhysicalTextureType.PACKED_2X2_FLOAT32;
}
return PhysicalTextureType.UNPACKED_FLOAT32;
}
if (isPacked) {
return PhysicalTextureType.PACKED_2X2_FLOAT16;
}
return PhysicalTextureType.UNPACKED_FLOAT16;
}
function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
if (logicalTexType === TextureUsage.UPLOAD) {
return PhysicalTextureType.PACKED_2X2_FLOAT32;
} else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) {
return getPhysicalTextureForRendering(isPacked);
} else if (logicalTexType === TextureUsage.DOWNLOAD || logicalTexType === TextureUsage.PIXELS) {
return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
}
throw new Error("Unknown logical texture type " + logicalTexType);
}
function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
return shapeRowsCol[0] + "_" + shapeRowsCol[1] + "_" + physicalTexType + "_" + isPacked;
}
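/*
 * For example (illustrative): a 4x8 texture of a given physical type is keyed
 * as "4_8_<physicalTexType>_<isPacked>", so a free texture is only reused when
 * the shape, physical texture type, and packing all match exactly.
 */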
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var UnaryOpProgram = function UnaryOpProgram(aShape, opSnippet) {
this.variableNames = ['A'];
this.outputShape = aShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = "\n float unaryOperation(float x) {\n " + opSnippet + "\n }\n\n void main() {\n float x = getAAtOutCoords();\n float y = unaryOperation(x);\n\n setOutput(y);\n }\n ";
};
var CHECK_NAN_SNIPPET = "if (isnan(x)) return x;";
var LINEAR = "return x;";
var ABS = "return abs(x);";
function STEP(alpha) {
if (alpha === void 0) {
alpha = 0.0;
}
return CHECK_NAN_SNIPPET + ("\n return x > 0.0 ? 1.0 : float(" + alpha + ");\n ");
}
var ELU$1 = "return (x >= 0.0) ? x : (exp(x) - 1.0);";
var RELU = CHECK_NAN_SNIPPET + "\n return (x < 0.0) ? 0.0 : x;\n";
var RELU6 = CHECK_NAN_SNIPPET + "\n return (x < 0.0) ? 0.0 : min(6.0, x);\n";
var CLONE = 'return x;';
var SIGMOID = "return 1.0 / (1.0 + exp(-1.0 * x));";
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LINEAR$1 = "return x;";
var ELU$2 = "\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n";
var RELU$1 = "\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
var RELU6$1 = "\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
var SIGMOID$1 = "return 1.0 / (1.0 + exp(-1.0 * x));";
var UnaryOpPackedProgram = function UnaryOpPackedProgram(aShape, opSnippet) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = aShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = "\n vec4 unaryOperation(vec4 x) {\n " + opSnippet + "\n }\n\n void main() {\n vec4 x = getAAtOutCoords();\n vec4 y = unaryOperation(x);\n\n setOutput(y);\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var UnpackProgram = function UnpackProgram(outputShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = false;
this.outputShape = outputShape;
var rank = outputShape.length;
var channels = getChannels('rc', rank);
var dtype = getCoordsDataType(rank);
var sourceCoords = getSourceCoords(rank, channels);
var innerDims = channels.slice(-2);
var coords = rank <= 1 ? 'rc' : "vec2(" + innerDims.join(',') + ")";
this.userCode = "\n void main() {\n " + dtype + " rc = getOutputCoords();\n vec4 packedInput = getA(" + sourceCoords + ");\n\n setOutput(getChannel(packedInput, " + coords + "));\n }\n ";
};
var whereImpl$2 = whereImpl;
var EPSILON_FLOAT32$1 = 1e-7;
var EPSILON_FLOAT16$1 = 1e-4;
var binaryCaches = {};
function getBinaryCache(webGLVersion) {
if (webGLVersion in binaryCaches) {
return binaryCaches[webGLVersion];
}
binaryCaches[webGLVersion] = {};
return binaryCaches[webGLVersion];
} // Empirically determined constant used to determine the size threshold for handing
// off execution to the CPU.
var CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD'); // Empirically determined constant used to decide the number of MB on GPU
// before we warn about high memory use. The MB are this constant * screen area
// * dpi / 1024 / 1024.
var BEFORE_PAGING_CONSTANT = 600;
function numMBBeforeWarning() {
if (env().global.screen == null) {
return 1024; // 1 GB.
}
return env().global.screen.height * env().global.screen.width * window.devicePixelRatio * BEFORE_PAGING_CONSTANT / 1024 / 1024;
}
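/*
 * Worked example (illustrative): on a 1920x1080 screen with
 * window.devicePixelRatio === 1 this is 1080 * 1920 * 1 * 600 / 1024 / 1024,
 * i.e. roughly 1186 MB of GPU memory before the high memory warning fires.
 */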
var MathBackendWebGL = /*#__PURE__*/function (_KernelBackend) {
_inheritsLoose(MathBackendWebGL, _KernelBackend);
function MathBackendWebGL(gpgpu) {
var _this;
_this = _KernelBackend.call(this) || this; // Maps data ids that have a pending read operation to a list of subscribers.
_this.pendingRead = new WeakMap(); // List of data ids that are scheduled for disposal, but are waiting on a
// pending read operation.
_this.pendingDisposal = new WeakSet(); // Used to count the number of 'shallow' sliced tensors that point to the
// same data id.
_this.dataRefCount = new WeakMap();
_this.numBytesInGPU = 0; // Accumulated time spent (including blocking) in uploading data to webgl.
_this.uploadWaitMs = 0; // Accumulated time spent (including blocking in downloading data from webgl.
_this.downloadWaitMs = 0; // Accumulated time spent (including blocking) in downloading data from webgl.
_this.lastGlFlushTime = 0;
_this.warnedAboutMemory = false;
_this.pendingDeletes = 0;
_this.disposed = false;
if (!env().getBool('HAS_WEBGL')) {
throw new Error('WebGL is not supported on this device');
}
if (gpgpu == null) {
var gl = getWebGLContext(env().getNumber('WEBGL_VERSION'));
_this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION'));
_this.gpgpu = new GPGPUContext(gl);
_this.canvas = gl.canvas;
_this.gpgpuCreatedLocally = true;
} else {
_this.gpgpu = gpgpu;
_this.binaryCache = {};
_this.gpgpuCreatedLocally = false;
_this.canvas = gpgpu.gl.canvas;
}
_this.textureManager = new TextureManager(_this.gpgpu);
_this.numMBBeforeWarning = numMBBeforeWarning();
_this.texData = new DataStorage(_assertThisInitialized(_this), engine());
return _this;
}
var _proto = MathBackendWebGL.prototype;
_proto.nextDataId = function nextDataId() {
return MathBackendWebGL.nextDataId++;
};
_proto.numDataIds = function numDataIds() {
return this.texData.numDataIds() - this.pendingDeletes;
};
_proto.write = function write(values, shape, dtype) {
if (env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') || env().getBool('DEBUG')) {
this.checkNumericalProblems(values);
}
if (dtype === 'complex64' && values != null) {
throw new Error("Cannot write to a complex64 dtype. " + "Please use tf.complex(real, imag).");
}
var dataId = {
id: this.nextDataId()
};
this.texData.set(dataId, {
shape: shape,
dtype: dtype,
values: values,
usage: TextureUsage.UPLOAD,
refCount: 1
});
return dataId;
}
/** Return refCount of a `TensorData`. */
;
_proto.refCount = function refCount(dataId) {
if (this.texData.has(dataId)) {
var tensorData = this.texData.get(dataId);
return tensorData.refCount;
}
return 0;
}
/** Increase refCount of a `TextureData`. */
;
_proto.incRef = function incRef(dataId) {
var texData = this.texData.get(dataId);
texData.refCount++;
}
/** Decrease refCount of a `TextureData`. */
;
_proto.decRef = function decRef(dataId) {
if (this.texData.has(dataId)) {
var texData = this.texData.get(dataId);
texData.refCount--;
}
};
_proto.move = function move(dataId, values, shape, dtype, refCount) {
if (env().getBool('DEBUG')) {
this.checkNumericalProblems(values);
}
if (dtype === 'complex64') {
throw new Error("Cannot write to a complex64 dtype. " + "Please use tf.complex(real, imag).");
}
this.texData.set(dataId, {
shape: shape,
dtype: dtype,
values: values,
usage: TextureUsage.UPLOAD,
refCount: refCount
});
};
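  // Ref-counting sketch (illustrative only, based on the methods above; not executed):
  //   var id = backend.write(new Float32Array([1, 2]), [2], 'float32'); // refCount: 1
  //   backend.incRef(id);      // refCount: 2
  //   backend.disposeData(id); // refCount: 1, data kept
  //   backend.disposeData(id); // refCount: 0, texture/values released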
_proto.disposeIntermediateTensorInfo = function disposeIntermediateTensorInfo(tensorInfo) {
this.disposeData(tensorInfo.dataId);
};
_proto.readSync = function readSync(dataId) {
var texData = this.texData.get(dataId);
var values = texData.values,
dtype = texData.dtype,
complexTensorInfos = texData.complexTensorInfos,
slice = texData.slice,
shape = texData.shape,
isPacked = texData.isPacked; // The presence of `slice` indicates this tensor is a shallow slice of a
// different tensor, and is using that original tensor's texture. Run
// `clone` in order to copy that texture and read from it.
if (slice != null) {
var program;
if (isPacked) {
program = new UnaryOpPackedProgram(shape, CLONE);
} else {
program = new UnaryOpProgram(shape, CLONE);
}
var res = this.runWebGLProgram(program, [{
dataId: dataId,
shape: shape,
dtype: dtype
}], dtype);
var data = this.readSync(res.dataId);
this.disposeIntermediateTensorInfo(res);
return data;
}
if (values != null) {
return this.convertAndCacheOnCPU(dataId);
}
if (dtype === 'string') {
return values;
}
var shouldTimeProgram = this.activeTimers != null;
var start;
if (shouldTimeProgram) {
start = now();
}
var result;
if (dtype === 'complex64') {
var realValues = this.readSync(complexTensorInfos.real.dataId);
var imagValues = this.readSync(complexTensorInfos.imag.dataId);
result = mergeRealAndImagArrays(realValues, imagValues);
} else {
result = this.getValuesFromTexture(dataId);
}
if (shouldTimeProgram) {
this.downloadWaitMs += now() - start;
}
return this.convertAndCacheOnCPU(dataId, result);
};
_proto.read = /*#__PURE__*/function () {
var _read = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee(dataId) {
var _subscribers, texData, values, shape, slice, dtype, complexTensorInfos, isPacked, program, res, data, buffer, tmpDownloadTarget, _this$gpgpu, tmpData, vals, ps, realValues, imagValues, size, gl, dTypeVals, subscribers;
return regeneratorRuntime.wrap(function _callee$(_context) {
while (1) {
switch (_context.prev = _context.next) {
case 0:
if (!this.pendingRead.has(dataId)) {
_context.next = 3;
break;
}
_subscribers = this.pendingRead.get(dataId);
return _context.abrupt("return", new Promise(function (resolve) {
return _subscribers.push(resolve);
}));
case 3:
texData = this.texData.get(dataId);
values = texData.values, shape = texData.shape, slice = texData.slice, dtype = texData.dtype, complexTensorInfos = texData.complexTensorInfos, isPacked = texData.isPacked; // The presence of `slice` indicates this tensor is a shallow slice of a
// different tensor, and is using that original tensor's texture. Run
// `clone` in order to copy that texture and read from it.
if (!(slice != null)) {
_context.next = 11;
break;
}
if (isPacked) {
program = new UnaryOpPackedProgram(shape, CLONE);
} else {
program = new UnaryOpProgram(shape, CLONE);
}
res = this.runWebGLProgram(program, [{
dataId: dataId,
shape: shape,
dtype: dtype
}], dtype);
data = this.read(res.dataId);
this.disposeIntermediateTensorInfo(res);
return _context.abrupt("return", data);
case 11:
if (!(values != null)) {
_context.next = 13;
break;
}
return _context.abrupt("return", this.convertAndCacheOnCPU(dataId));
case 13:
if (!(!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && env().getNumber('WEBGL_VERSION') === 2)) {
_context.next = 15;
break;
}
throw new Error("tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and " + "WEBGL_VERSION=2 not yet supported.");
case 15:
buffer = null;
if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) {
// Possibly copy the texture into a buffer before inserting a fence.
tmpDownloadTarget = this.decode(dataId);
tmpData = this.texData.get(tmpDownloadTarget.dataId);
buffer = (_this$gpgpu = this.gpgpu).createBufferFromTexture.apply(_this$gpgpu, [tmpData.texture].concat(getDenseTexShape(shape)));
}
this.pendingRead.set(dataId, []);
if (!(dtype !== 'complex64')) {
_context.next = 21;
break;
}
_context.next = 21;
return this.gpgpu.createAndWaitForFence();
case 21:
if (!(dtype === 'complex64')) {
_context.next = 30;
break;
}
_context.next = 24;
return Promise.all([this.read(complexTensorInfos.real.dataId), this.read(complexTensorInfos.imag.dataId)]);
case 24:
ps = _context.sent;
realValues = ps[0];
imagValues = ps[1];
vals = mergeRealAndImagArrays(realValues, imagValues);
_context.next = 31;
break;
case 30:
if (buffer == null) {
vals = this.getValuesFromTexture(dataId);
} else {
size = sizeFromShape(shape);
vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
}
case 31:
if (tmpDownloadTarget != null) {
this.disposeIntermediateTensorInfo(tmpDownloadTarget);
}
if (buffer != null) {
gl = this.gpgpu.gl;
callAndCheck(gl, function () {
return gl.deleteBuffer(buffer);
});
}
dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
subscribers = this.pendingRead.get(dataId);
this.pendingRead.delete(dataId); // Notify all pending reads.
subscribers.forEach(function (resolve) {
return resolve(dTypeVals);
});
if (this.pendingDisposal.has(dataId)) {
this.pendingDisposal.delete(dataId);
if (this.disposeData(dataId)) {
engine().removeDataId(dataId, this);
}
this.pendingDeletes--;
}
return _context.abrupt("return", dTypeVals);
case 39:
case "end":
return _context.stop();
}
}
}, _callee, this);
}));
function read(_x) {
return _read.apply(this, arguments);
}
return read;
}();
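  // Usage sketch (illustrative; assumes the 'webgl' backend is active):
  //   const t = tf.tensor([1, 2, 3]);
  //   const vals = await t.data(); // resolves through MathBackendWebGL.read()
  //   const sync = t.dataSync();   // routed through MathBackendWebGL.readSync()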
_proto.bufferSync = function bufferSync(t) {
var data = this.readSync(t.dataId);
var decodedData = data;
if (t.dtype === 'string') {
try {
// Decode the bytes into string.
decodedData = data.map(function (d) {
return decodeString(d);
});
} catch (_a) {
throw new Error('Failed to decode encoded string bytes into utf-8');
}
}
return buffer(t.shape, t.dtype, decodedData);
};
_proto.checkNumericalProblems = function checkNumericalProblems(values) {
if (values == null) {
return;
}
for (var i = 0; i < values.length; i++) {
var num = values[i];
if (!canBeRepresented(num)) {
if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
throw Error("The value " + num + " cannot be represented with your " + "current settings. Consider enabling float32 rendering: " + "'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'");
}
throw Error("The value " + num + " cannot be represented on this device.");
}
}
};
_proto.getValuesFromTexture = function getValuesFromTexture(dataId) {
var _this$texData$get = this.texData.get(dataId),
shape = _this$texData$get.shape,
dtype = _this$texData$get.dtype,
isPacked = _this$texData$get.isPacked;
var size = sizeFromShape(shape);
if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
var _this$gpgpu2;
var tmpTarget = this.decode(dataId);
var _tmpData = this.texData.get(tmpTarget.dataId);
var _vals = (_this$gpgpu2 = this.gpgpu).downloadMatrixFromPackedTexture.apply(_this$gpgpu2, [_tmpData.texture].concat(getDenseTexShape(shape))).subarray(0, size);
this.disposeIntermediateTensorInfo(tmpTarget);
return _vals;
}
var shouldUsePackedProgram = env().getBool('WEBGL_PACK') && isPacked === true;
var outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape;
var program = shouldUsePackedProgram ? new EncodeFloatPackedProgram(outputShape) : new EncodeFloatProgram(outputShape);
var output = this.runWebGLProgram(program, [{
shape: outputShape,
dtype: dtype,
dataId: dataId
}], 'float32');
var tmpData = this.texData.get(output.dataId);
var vals = this.gpgpu.downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture, tmpData.texShape[0], tmpData.texShape[1]).subarray(0, size);
this.disposeIntermediateTensorInfo(output);
return vals;
};
_proto.timerAvailable = function timerAvailable() {
return env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
};
_proto.time = /*#__PURE__*/function () {
var _time = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee2(f) {
var oldActiveTimers, newActiveTimers, outerMostTime, flattenedActiveTimerQueries, flattenedActiveTimerNames, res, kernelMs;
return regeneratorRuntime.wrap(function _callee2$(_context2) {
while (1) {
switch (_context2.prev = _context2.next) {
case 0:
oldActiveTimers = this.activeTimers;
newActiveTimers = [];
outerMostTime = false;
if (this.programTimersStack == null) {
this.programTimersStack = newActiveTimers;
outerMostTime = true;
} else {
this.activeTimers.push(newActiveTimers);
}
this.activeTimers = newActiveTimers;
          f(); // We need to split these up because util.flatten only accepts certain types.
flattenedActiveTimerQueries = flatten(this.activeTimers.map(function (d) {
return d.query;
})).filter(function (d) {
return d != null;
});
flattenedActiveTimerNames = flatten(this.activeTimers.map(function (d) {
return d.name;
})).filter(function (d) {
return d != null;
});
this.activeTimers = oldActiveTimers;
if (outerMostTime) {
this.programTimersStack = null;
}
res = {
uploadWaitMs: this.uploadWaitMs,
downloadWaitMs: this.downloadWaitMs,
kernelMs: null,
wallMs: null // will be filled by the engine
};
if (!(env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0)) {
_context2.next = 19;
break;
}
_context2.next = 14;
return Promise.all(flattenedActiveTimerQueries);
case 14:
kernelMs = _context2.sent;
res['kernelMs'] = sum(kernelMs);
res['getExtraProfileInfo'] = function () {
return kernelMs.map(function (d, i) {
return {
name: flattenedActiveTimerNames[i],
ms: d
};
}).map(function (d) {
return d.name + ": " + d.ms;
}).join(', ');
};
_context2.next = 20;
break;
case 19:
res['kernelMs'] = {
error: 'WebGL query timers are not supported in this environment.'
};
case 20:
this.uploadWaitMs = 0;
this.downloadWaitMs = 0;
return _context2.abrupt("return", res);
case 23:
case "end":
return _context2.stop();
}
}
}, _callee2, this);
}));
function time(_x2) {
return _time.apply(this, arguments);
}
return time;
}();
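  // Usage sketch (illustrative; `a` and `b` are assumed to be existing tensors):
  //   const info = await tf.time(() => tf.matMul(a, b));
  //   // info.kernelMs comes from the WebGL query timers when they are reliable,
  //   // otherwise it holds the error object constructed above.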
_proto.memory = function memory() {
return {
unreliable: false,
numBytesInGPU: this.numBytesInGPU,
numBytesInGPUAllocated: this.textureManager.numBytesAllocated,
numBytesInGPUFree: this.textureManager.numBytesFree
};
};
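  // Illustrative shape of the object returned above (values are made up):
  //   { unreliable: false, numBytesInGPU: 4194304,
  //     numBytesInGPUAllocated: 8388608, numBytesInGPUFree: 4194304 }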
_proto.startTimer = function startTimer() {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
return this.gpgpu.beginQuery();
}
return {
startMs: now(),
endMs: null
};
};
_proto.endTimer = function endTimer(query) {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
this.gpgpu.endQuery();
return query;
}
query.endMs = now();
return query;
};
_proto.getQueryTime = /*#__PURE__*/function () {
var _getQueryTime = _asyncToGenerator( /*#__PURE__*/regeneratorRuntime.mark(function _callee3(query) {
var timerQuery;
return regeneratorRuntime.wrap(function _callee3$(_context3) {
while (1) {
switch (_context3.prev = _context3.next) {
case 0:
if (!(env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0)) {
_context3.next = 2;
break;
}
return _context3.abrupt("return", this.gpgpu.waitForQueryAndGetTime(query));
case 2:
timerQuery = query;
return _context3.abrupt("return", timerQuery.endMs - timerQuery.startMs);
case 4:
case "end":
return _context3.stop();
}
}
}, _callee3, this);
}));
function getQueryTime(_x3) {
return _getQueryTime.apply(this, arguments);
}
return getQueryTime;
}()
  /**
   * Decrease the refCount of the dataId and dispose the memory if the
   * refCount reaches 0. If there is a pending read on the data, the disposal
   * is instead added to the pending delete queue. Returns true if the dataId
   * is removed from the backend or the backend does not contain the dataId;
   * returns false if the dataId is not removed. Memory may or may not be
   * released even when the dataId is removed, which also depends on
   * dataRefCount; see `releaseGPUData`.
   * @param dataId
   * @param force Optional. Remove the data regardless of refCount.
   */
;
_proto.disposeData = function disposeData(dataId, force) {
if (force === void 0) {
force = false;
}
if (this.pendingDisposal.has(dataId)) {
return false;
} // No-op if already disposed.
if (!this.texData.has(dataId)) {
return true;
    } // If the force flag is set, change refCount to 0; this ensures disposal
    // when the id is added to the pendingDisposal queue. Memory may or may not
    // be released, which also depends on dataRefCount; see `releaseGPUData`.
if (force) {
this.texData.get(dataId).refCount = 0;
} else {
this.texData.get(dataId).refCount--;
}
if (!force && this.texData.get(dataId).refCount > 0) {
return false;
}
if (this.pendingRead.has(dataId)) {
this.pendingDisposal.add(dataId);
this.pendingDeletes++;
return false;
}
this.releaseGPUData(dataId);
var _this$texData$get2 = this.texData.get(dataId),
complexTensorInfos = _this$texData$get2.complexTensorInfos;
if (complexTensorInfos != null) {
this.disposeData(complexTensorInfos.real.dataId, force);
this.disposeData(complexTensorInfos.imag.dataId, force);
}
this.texData.delete(dataId);
return true;
};
_proto.releaseGPUData = function releaseGPUData(dataId) {
var _this$texData$get3 = this.texData.get(dataId),
texture = _this$texData$get3.texture,
dtype = _this$texData$get3.dtype,
texShape = _this$texData$get3.texShape,
usage = _this$texData$get3.usage,
isPacked = _this$texData$get3.isPacked,
slice = _this$texData$get3.slice;
var key = slice && slice.origDataId || dataId;
var refCount = this.dataRefCount.get(key);
if (refCount > 1) {
this.dataRefCount.set(key, refCount - 1);
} else {
this.dataRefCount.delete(key);
if (texture != null) {
this.numBytesInGPU -= this.computeBytes(texShape, dtype);
this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
}
}
var texData = this.texData.get(dataId);
texData.texture = null;
texData.texShape = null;
texData.isPacked = false;
texData.slice = null;
};
_proto.getTexture = function getTexture(dataId) {
this.uploadToGPU(dataId);
return this.texData.get(dataId).texture;
}
/**
* Returns internal information for the specific data bucket. Used in unit
* tests.
*/
;
_proto.getDataInfo = function getDataInfo(dataId) {
return this.texData.get(dataId);
}
/*
Tests whether all the inputs to an op are small and on the CPU. This heuristic
determines when it would be faster to execute a kernel on the CPU. WebGL
kernels opt into running this check and forwarding when appropriate.
TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more
sustainable strategy for optimizing backend execution of ops.
*/
;
_proto.shouldExecuteOnCPU = function shouldExecuteOnCPU(inputs, sizeThreshold) {
var _this2 = this;
if (sizeThreshold === void 0) {
sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD;
}
return env().getBool('WEBGL_CPU_FORWARD') && inputs.every(function (input) {
return _this2.texData.get(input.dataId).texture == null && sizeFromShape(input.shape) < sizeThreshold;
});
};
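  // Heuristic sketch (illustrative): with WEBGL_CPU_FORWARD enabled, kernels that opt into this
  // check forward to their CPU implementation only when every input still lives in CPU memory
  // (texture == null) and has fewer than CPU_HANDOFF_SIZE_THRESHOLD elements, e.g.
  //   tf.env().set('WEBGL_CPU_FORWARD', true);
  //   tf.add(tf.tensor([1, 2]), tf.tensor([3, 4])); // small CPU-resident inputs may run on the CPU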
_proto.getGPGPUContext = function getGPGPUContext() {
return this.gpgpu;
};
_proto.where = function where(condition) {
warn('tf.where() in webgl locks the UI thread. ' + 'Call tf.whereAsync() instead');
var condVals = condition.dataSync();
return whereImpl$2(condition.shape, condVals);
};
_proto.packedUnaryOp = function packedUnaryOp(x, op, dtype) {
var program = new UnaryOpPackedProgram(x.shape, op);
var outInfo = this.compileAndRun(program, [x], dtype);
return engine().makeTensorFromDataId(outInfo.dataId, outInfo.shape, outInfo.dtype);
} // TODO(msoulanille) remove this once the backend has been modularized
// a copy is needed here to break a circular dependency.
// Also remove the op from unary_op.
;
_proto.abs = function abs(x) {
// TODO: handle cases when x is complex.
if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
var outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values);
return this.makeOutput(x.shape, x.dtype, outValues);
}
if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
return this.packedUnaryOp(x, ABS, x.dtype);
}
var program = new UnaryOpProgram(x.shape, ABS);
var outInfo = this.compileAndRun(program, [x]);
return engine().makeTensorFromDataId(outInfo.dataId, outInfo.shape, outInfo.dtype);
};
_proto.makeTensorInfo = function makeTensorInfo(shape, dtype, values) {
var dataId;
if (dtype === 'string' && values != null && values.length > 0 && isString(values[0])) {
var encodedValues = values.map(function (d) {
return encodeString(d);
});
dataId = this.write(encodedValues, shape, dtype);
} else {
dataId = this.write(values, shape, dtype);
}
this.texData.get(dataId).usage = null;
return {
dataId: dataId,
shape: shape,
dtype: dtype
};
};
_proto.makeOutput = function makeOutput(shape, dtype, values) {
var _this$makeTensorInfo = this.makeTensorInfo(shape, dtype, values),
dataId = _this$makeTensorInfo.dataId;
return engine().makeTensorFromDataId(dataId, shape, dtype, this);
};
_proto.unpackTensor = function unpackTensor(input) {
var program = new UnpackProgram(input.shape);
return this.runWebGLProgram(program, [input], input.dtype);
};
_proto.packTensor = function packTensor(input) {
var program = new PackProgram(input.shape);
var preventEagerUnpackingOutput = true;
return this.runWebGLProgram(program, [input], input.dtype, null
/* customUniformValues */
, preventEagerUnpackingOutput);
};
_proto.packedReshape = function packedReshape(input, afterShape) {
var input3DShape = [getBatchDim(input.shape)].concat(getRowsCols(input.shape));
var input3D = {
dtype: input.dtype,
shape: input3DShape,
dataId: input.dataId
};
var afterShapeAs3D = [getBatchDim(afterShape)].concat(getRowsCols(afterShape));
var program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
var preventEagerUnpackingOfOutput = true;
var customValues = [input3DShape];
var output = this.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
return {
dataId: output.dataId,
shape: afterShape,
dtype: output.dtype
};
};
_proto.decode = function decode(dataId) {
var texData = this.texData.get(dataId);
var isPacked = texData.isPacked,
shape = texData.shape,
dtype = texData.dtype;
var shapeAs3D = getShapeAs3D(shape);
var program;
var denseTexShape = getDenseTexShape(shapeAs3D);
if (isPacked) {
program = new DecodeMatrixPackedProgram(shapeAs3D);
} else {
program = new DecodeMatrixProgram(shapeAs3D);
}
var preventEagerUnpackingOfOutput = true;
var customValues = [denseTexShape];
var out = this.runWebGLProgram(program, [{
shape: shapeAs3D,
dtype: dtype,
dataId: dataId
}], dtype, customValues, preventEagerUnpackingOfOutput);
return {
dtype: dtype,
shape: shape,
dataId: out.dataId
};
};
_proto.runWebGLProgram = function runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput) {
var _this3 = this;
if (preventEagerUnpackingOfOutput === void 0) {
preventEagerUnpackingOfOutput = false;
}
var output = this.makeTensorInfo(program.outputShape, outputDtype);
var outData = this.texData.get(output.dataId);
if (program.packedOutput) {
outData.isPacked = true;
}
if (program.outPackingScheme === PackingScheme.DENSE) {
var texelShape = getDenseTexShape(program.outputShape); // For a densely packed output, we explicitly set texShape
// so it doesn't get assigned later according to our typical packing
// scheme wherein a single texel can only contain values from adjacent
// rows/cols.
outData.texShape = texelShape.map(function (d) {
return d * 2;
});
}
if (program.outTexUsage != null) {
outData.usage = program.outTexUsage;
}
if (sizeFromShape(output.shape) === 0) {
// Short-circuit the computation since the result is empty (has 0 in its
// shape).
outData.values = getTypedArrayFromDType(output.dtype, 0);
return output;
}
var dataToDispose = [];
var inputsData = inputs.map(function (input) {
if (input.dtype === 'complex64') {
throw new Error("GPGPUProgram does not support complex64 input. For complex64 " + "dtypes, please separate the program into real and imaginary " + "parts.");
}
var texData = _this3.texData.get(input.dataId);
if (texData.texture == null) {
if (!program.packedInputs && sizeFromShape(input.shape) <= env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {
// Upload small tensors that live on the CPU as uniforms, not as
// textures. Do this only when the environment supports 32bit floats
// due to problems when comparing 16bit floats with 32bit floats.
// TODO(https://github.com/tensorflow/tfjs/issues/821): Make it
// possible for packed shaders to sample from uniforms.
return {
shape: input.shape,
texData: null,
isUniform: true,
uniformValues: texData.values
};
} // This ensures that if a packed program's inputs have not yet been
// uploaded to the GPU, they get uploaded as packed right off the bat.
if (program.packedInputs) {
texData.isPacked = true;
texData.shape = input.shape;
}
} else if (!!texData.isPacked !== !!program.packedInputs) {
input = texData.isPacked ? _this3.unpackTensor(input) : _this3.packTensor(input);
dataToDispose.push(input);
texData = _this3.texData.get(input.dataId);
} else if (texData.isPacked && !isReshapeFree(texData.shape, input.shape)) {
// This is a special case where a texture exists for a tensor
// but the shapes are incompatible (due to packing constraints) because
// the tensor did not have a chance to go through the packed reshape
// shader. This only happens when we reshape the *same* tensor to form
// *distinct* inputs to an op, e.g. dotting a vector with itself. This
// case will disappear once packed uploading is the default.
var savedInput = input;
var targetShape = input.shape;
input.shape = texData.shape;
input = _this3.packedReshape(input, targetShape);
dataToDispose.push(input);
texData = _this3.texData.get(input.dataId);
savedInput.shape = targetShape;
}
_this3.uploadToGPU(input.dataId);
return {
shape: input.shape,
texData: texData,
isUniform: false
};
});
this.uploadToGPU(output.dataId);
var outputData = {
shape: output.shape,
texData: outData,
isUniform: false
};
var key = makeShaderKey(program, inputsData, outputData);
var binary = this.getAndSaveBinary(key, function () {
return compileProgram(_this3.gpgpu, program, inputsData, outputData);
});
var shouldTimeProgram = this.activeTimers != null;
var query;
if (shouldTimeProgram) {
query = this.startTimer();
}
runProgram(this.gpgpu, binary, inputsData, outputData, customUniformValues);
dataToDispose.forEach(function (info) {
return _this3.disposeIntermediateTensorInfo(info);
});
if (shouldTimeProgram) {
query = this.endTimer(query);
this.activeTimers.push({
name: program.constructor.name,
query: this.getQueryTime(query)
});
}
    var glFlushThreshold = env().get('WEBGL_FLUSH_THRESHOLD'); // A manual GL flush is requested when this threshold is positive.
if (glFlushThreshold > 0) {
var time = now();
if (time - this.lastGlFlushTime > glFlushThreshold) {
this.gpgpu.gl.flush();
this.lastGlFlushTime = time;
}
}
if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked && preventEagerUnpackingOfOutput === false) {
var unpacked = this.unpackTensor(output);
this.disposeIntermediateTensorInfo(output);
return unpacked;
}
return output;
};
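  // Summary of runWebGLProgram (descriptive comment, added for readability):
  // 1. Allocate a TensorInfo and texData entry for the output (short-circuiting empty shapes).
  // 2. For each input: pass it as a uniform if it is small and still on the CPU; otherwise
  //    pack/unpack or packed-reshape it so its layout matches what the program expects,
  //    then upload it to the GPU.
  // 3. Compile the shader (or reuse a cached binary keyed by makeShaderKey) and run it,
  //    optionally timing the kernel and flushing GL past WEBGL_FLUSH_THRESHOLD.
  // 4. Dispose intermediates and, unless prevented, eagerly unpack a packed output.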
_proto.compileAndRun = function compileAndRun(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput) {
if (preventEagerUnpackingOfOutput === void 0) {
preventEagerUnpackingOfOutput = false;
}
outputDtype = outputDtype || inputs[0].dtype;
var outInfo = this.runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput);
return outInfo;
};
_proto.getAndSaveBinary = function getAndSaveBinary(key, getBinary) {
if (!(key in this.binaryCache)) {
this.binaryCache[key] = getBinary();
}
return this.binaryCache[key];
};
_proto.getTextureManager = function getTextureManager() {
return this.textureManager;
};
_proto.dispose = function dispose() {
var _this4 = this;
if (this.disposed) {
return;
} // Avoid disposing the compiled webgl programs during unit testing because
// it slows down test execution.
if (!env().getBool('IS_TEST')) {
var allKeys = Object.keys(this.binaryCache);
allKeys.forEach(function (key) {
_this4.gpgpu.deleteProgram(_this4.binaryCache[key].webGLProgram);
delete _this4.binaryCache[key];
});
}
this.textureManager.dispose();
if (this.canvas != null && typeof HTMLCanvasElement !== 'undefined' && this.canvas instanceof HTMLCanvasElement) {
this.canvas.remove();
} else {
this.canvas = null;
}
if (this.gpgpuCreatedLocally) {
this.gpgpu.program = null;
this.gpgpu.dispose();
}
this.disposed = true;
};
_proto.floatPrecision = function floatPrecision() {
var _this5 = this;
if (this.floatPrecisionValue == null) {
this.floatPrecisionValue = tidy(function () {
if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) {
// Momentarily switching DEBUG flag to false so we don't throw an
// error trying to upload a small value.
var debugFlag = env().getBool('DEBUG');
env().set('DEBUG', false);
var underflowCheckValue = _this5.abs(scalar(1e-8)).dataSync()[0];
env().set('DEBUG', debugFlag);
if (underflowCheckValue > 0) {
return 32;
}
}
return 16;
});
}
return this.floatPrecisionValue;
}
/** Returns the smallest representable number. */
;
_proto.epsilon = function epsilon() {
return this.floatPrecision() === 32 ? EPSILON_FLOAT32$1 : EPSILON_FLOAT16$1;
};
_proto.uploadToGPU = function uploadToGPU(dataId) {
var texData = this.texData.get(dataId);
var shape = texData.shape,
dtype = texData.dtype,
values = texData.values,
texture = texData.texture,
usage = texData.usage,
isPacked = texData.isPacked;
if (texture != null) {
// Array is already on GPU. No-op.
return;
}
var shouldTimeProgram = this.activeTimers != null;
var start;
if (shouldTimeProgram) {
start = now();
}
var texShape = texData.texShape;
if (texShape == null) {
texShape = getTextureShapeFromLogicalShape(shape, isPacked);
texData.texShape = texShape;
}
if (values != null) {
var shapeAs3D = getShapeAs3D(shape);
var program;
var width = texShape[1],
height = texShape[0];
var isByteArray = values instanceof Uint8Array;
if (isPacked) {
var _tex_util$getPackedMa = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]);
width = _tex_util$getPackedMa[0];
height = _tex_util$getPackedMa[1];
program = new EncodeMatrixPackedProgram(shapeAs3D, isByteArray);
} else {
program = new EncodeMatrixProgram(shapeAs3D, isByteArray);
}
var tempDenseInputHandle = this.makeTensorInfo([height, width], dtype);
if (isByteArray) {
this.texData.get(tempDenseInputHandle.dataId).usage = TextureUsage.PIXELS;
} else {
this.texData.get(tempDenseInputHandle.dataId).usage = TextureUsage.UPLOAD;
}
this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values);
var customValues = [[height, width]]; // We want the output to remain packed regardless of the value of
// WEBGL_PACK.
var preventEagerUnpacking = true;
var encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, customValues, preventEagerUnpacking); // Have the original texture assume the identity of the encoded output.
var outputTexData = this.texData.get(encodedOutputTarget.dataId);
texData.texture = outputTexData.texture;
texData.texShape = outputTexData.texShape;
texData.isPacked = outputTexData.isPacked;
texData.usage = outputTexData.usage;
this.disposeIntermediateTensorInfo(tempDenseInputHandle);
this.texData.delete(encodedOutputTarget.dataId); // Once uploaded, don't store the values on cpu.
texData.values = null;
if (shouldTimeProgram) {
this.uploadWaitMs += now() - start;
}
} else {
var newTexture = this.acquireTexture(texShape, usage, dtype, isPacked);
texData.texture = newTexture;
}
};
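  // Summary of uploadToGPU (descriptive comment, added for readability): CPU-resident values are
  // first written into a dense staging texture via uploadDenseMatrixToTexture, then an
  // EncodeMatrix(Packed)Program rewrites them into the backend's logical texture layout; the
  // encoded output texture is adopted as this tensor's texture and the CPU copy is dropped.
  // Tensors without CPU values simply acquire an empty texture of the right shape.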
_proto.convertAndCacheOnCPU = function convertAndCacheOnCPU(dataId, float32Values) {
var texData = this.texData.get(dataId);
var dtype = texData.dtype;
this.releaseGPUData(dataId);
if (float32Values != null) {
texData.values = float32ToTypedArray(float32Values, dtype);
}
return texData.values;
};
_proto.acquireTexture = function acquireTexture(texShape, texType, dtype, isPacked) {
this.numBytesInGPU += this.computeBytes(texShape, dtype);
if (!this.warnedAboutMemory && this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) {
var mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2);
this.warnedAboutMemory = true;
console.warn("High memory usage in GPU: " + mb + " MB, " + "most likely due to a memory leak");
}
return this.textureManager.acquireTexture(texShape, texType, isPacked);
};
_proto.computeBytes = function computeBytes(shape, dtype) {
return shape[0] * shape[1] * bytesPerElement(dtype);
};
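  // Worked example (illustrative): a texShape of [128, 128] holding float32 data accounts for
  // 128 * 128 * 4 = 65536 bytes in the numBytesInGPU bookkeeping above.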
return MathBackendWebGL;
}(KernelBackend);
MathBackendWebGL.nextDataId = 0;
function float32ToTypedArray(a, dtype) {
if (dtype === 'float32' || dtype === 'complex64') {
return a;
} else if (dtype === 'int32' || dtype === 'bool') {
var result = dtype === 'int32' ? new Int32Array(a.length) : new Uint8Array(a.length);
for (var i = 0; i < result.length; ++i) {
result[i] = Math.round(a[i]);
}
return result;
} else {
throw new Error("Unknown dtype " + dtype);
}
}
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
var version$6 = '3.9.0';
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Enforce use of half precision textures if available on the platform.
*
* @doc {heading: 'Environment', namespace: 'webgl'}
*/
function forceHalfFloat() {
env().set('WEBGL_FORCE_F16_TEXTURES', true);
}
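// Usage sketch (illustrative): this is exported below as part of the `webgl` namespace, so an
// application can opt into half-precision textures with
//   tf.webgl.forceHalfFloat();
// ideally before the first tensor is uploaded, so that subsequently allocated textures pick it up.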
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
if (isBrowser()) {
registerBackend('webgl', function () {
return new MathBackendWebGL();
}, 2
/* priority */
);
} // Export webgl utilities
var webgl = {
forceHalfFloat: forceHalfFloat
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var CHECK_NAN_SNIPPET$1 = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n";
var SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
var BinaryOpProgram = function BinaryOpProgram(op, aShape, bShape) {
this.variableNames = ['A', 'B'];
this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = "\n float binaryOperation(float a, float b) {\n " + op + "\n }\n\n void main() {\n float a = getAAtOutCoords();\n float b = getBAtOutCoords();\n setOutput(binaryOperation(a, b));\n }\n ";
};
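// Illustrative expansion (not executed): with the SQUARED_DIFFERENCE snippet above, the
// generated fragment shader body is effectively
//   float binaryOperation(float a, float b) { return (a - b) * (a - b); }
//   void main() {
//     float a = getAAtOutCoords();
//     float b = getBAtOutCoords();
//     setOutput(binaryOperation(a, b));
//   }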
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var CHECK_NAN_SNIPPET$2 = "\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n";
var ELU_DER = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n";
var NOT_EQUAL = "\n return vec4(notEqual(a, b));\n";
var BinaryOpPackedProgram = function BinaryOpPackedProgram(op, aShape, bShape, checkOutOfBounds) {
if (checkOutOfBounds === void 0) {
checkOutOfBounds = false;
}
this.variableNames = ['A', 'B'];
this.supportsBroadcasting = true;
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
var rank = this.outputShape.length;
this.enableShapeUniforms = useShapeUniforms(rank);
var checkOutOfBoundsString = '';
if (checkOutOfBounds) {
if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
checkOutOfBoundsString = "\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n ";
} else {
var dtype = getCoordsDataType(rank);
checkOutOfBoundsString = "\n " + dtype + " coords = getOutputCoords();\n ";
if (rank === 1) {
if (this.enableShapeUniforms) {
checkOutOfBoundsString += "\n result.y = (coords + 1) >= outShape ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n ";
} else {
checkOutOfBoundsString += "\n result.y = (coords + 1) >= " + this.outputShape[0] + " ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n ";
}
} else {
var channels = getChannels('coords', rank);
if (this.enableShapeUniforms) {
checkOutOfBoundsString += "\n bool nextRowOutOfBounds =\n (" + channels[rank - 2] + " + 1) >= outShape[" + rank + " - 2];\n bool nextColOutOfBounds =\n (" + channels[rank - 1] + " + 1) >= outShape[" + rank + " - 1];\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n ";
} else {
checkOutOfBoundsString += "\n bool nextRowOutOfBounds =\n (" + channels[rank - 2] + " + 1) >= " + this.outputShape[rank - 2] + ";\n bool nextColOutOfBounds =\n (" + channels[rank - 1] + " + 1) >= " + this.outputShape[rank - 1] + ";\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n ";
}
}
}
}
this.userCode = "\n vec4 binaryOperation(vec4 a, vec4 b) {\n " + op + "\n }\n\n void main() {\n vec4 a = getAAtOutCoords();\n vec4 b = getBAtOutCoords();\n\n vec4 result = binaryOperation(a, b);\n " + checkOutOfBoundsString + "\n\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function identity$2(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
backend.incRef(x.dataId);
return {
dataId: x.dataId,
shape: x.shape,
dtype: x.dtype
};
}
var identityConfig$1 = {
kernelName: Identity,
backendName: 'webgl',
kernelFunc: identity$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* In WebGL data is stored in GPU textures which can't be efficiently copied, so
* complex tensors share data with their real and imaginary components. Complex
* tensors' reference to the components is tracked by refCount on the individual
* component. The refCounts are increased by the identity call.
*
* When a complex tensor is disposed, it will reduce the refCount on the
* components by calling disposeData on each.
*/
function complex$2(args) {
var inputs = args.inputs,
backend = args.backend;
var real = inputs.real,
imag = inputs.imag;
var complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
var complex = backend.texData.get(complexInfo.dataId);
var realTensorInfo = identity$2({
inputs: {
x: real
},
backend: backend
});
var imagTensorInfo = identity$2({
inputs: {
x: imag
},
backend: backend
});
complex.complexTensorInfos = {
real: realTensorInfo,
imag: imagTensorInfo
};
return complexInfo;
}
var complexConfig$1 = {
kernelName: Complex,
backendName: 'webgl',
kernelFunc: complex$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LEAKYRELU = "return (a < 0.) ? b * a : a;";
var LEAKYRELU_PACKED = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
function leakyRelu$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var alpha = attrs.alpha;
var $alpha = backend.makeTensorInfo([], 'float32', createScalarValue(alpha, 'float32'));
var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape) : new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape);
var result = backend.runWebGLProgram(program, [x, $alpha], x.dtype);
backend.disposeIntermediateTensorInfo($alpha);
return result;
}
var leakyReluConfig$1 = {
kernelName: LeakyRelu,
backendName: 'webgl',
kernelFunc: leakyRelu$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PRELU = "return (a < 0.) ? b * a : a;";
var PRELU_PACKED = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
function prelu$3(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x,
alpha = inputs.alpha;
var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape) : new BinaryOpProgram(PRELU, x.shape, alpha.shape);
return backend.runWebGLProgram(program, [x, alpha], x.dtype);
}
var preluConfig$1 = {
kernelName: Prelu,
backendName: 'webgl',
kernelFunc: prelu$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var CHECK_NAN_SNIPPET_UNARY = "if (isnan(x)) return x;";
var CHECK_NAN_SNIPPET_BINARY = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n";
var CHECK_NAN_SNIPPET_BINARY_PACKED = "\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n";
/**
* Template that creates a `KernelFunc` for unary ops.
* @param opSnippet Op snippet to create `UnaryOpProgram`.
* @param packedOpSnippet Op snippet to create `UnaryOpPackedProgram`.
* @param dtype Optional. If set, the result has this dtype. Otherwise, the
* result has the same dtype as the first input. This is mainly used in
* comparison kernels, such as Equal, Less, Greater, etc.
*/
function unaryKernelFunc$1(_ref) {
var opSnippet = _ref.opSnippet,
packedOpSnippet = _ref.packedOpSnippet,
cpuKernelImpl = _ref.cpuKernelImpl,
dtype = _ref.dtype;
return function (_ref2) {
var inputs = _ref2.inputs,
backend = _ref2.backend;
var x = inputs.x;
var webglBackend = backend;
var $dtype = dtype || x.dtype;
if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) {
var xData = webglBackend.texData.get(x.dataId);
var outValues = cpuKernelImpl(xData.values, $dtype);
return webglBackend.makeTensorInfo(x.shape, $dtype, outValues);
}
var shouldUsePackedProgram = env().getBool('WEBGL_PACK_UNARY_OPERATIONS') && packedOpSnippet != null;
var program;
if (shouldUsePackedProgram) {
program = new UnaryOpPackedProgram(x.shape, packedOpSnippet);
} else {
program = new UnaryOpProgram(x.shape, opSnippet);
}
return webglBackend.runWebGLProgram(program, [x], $dtype);
};
}
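// Usage sketch (illustrative; the snippet and kernel name are examples, not part of this bundle):
//   var squareKernelFunc = unaryKernelFunc$1({ opSnippet: 'return x * x;' });
//   // registered via a config such as { kernelName: Square, backendName: 'webgl', kernelFunc: squareKernelFunc }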
/**
* Template that creates a `KernelFunc` for binary ops.
* @param opSnippet Op snippet to create `BinaryOpProgram`.
* @param packedOpSnippet Op snippet to create `BinaryOpPackedProgram`.
* @param checkOutOfBoundsForPackedProgram Whether to set checkOutOfBounds=true
* when creating BinaryOpPackedProgram.
* @param dtype Optional. If set, the result has this dtype. Otherwise, the
* result has the same dtype as the first input. This is mainly used in
* comparison kernels, such as Equal, Less, Greater, etc.
*/
function binaryKernelFunc$1(_ref3) {
var opSnippet = _ref3.opSnippet,
packedOpSnippet = _ref3.packedOpSnippet,
_ref3$checkOutOfBound = _ref3.checkOutOfBounds,
checkOutOfBounds = _ref3$checkOutOfBound === void 0 ? false : _ref3$checkOutOfBound,
_ref3$supportsComplex = _ref3.supportsComplex,
supportsComplex = _ref3$supportsComplex === void 0 ? false : _ref3$supportsComplex,
cpuKernelImpl = _ref3.cpuKernelImpl,
dtype = _ref3.dtype;
return function (_ref4) {
var inputs = _ref4.inputs,
backend = _ref4.backend;
var a = inputs.a,
b = inputs.b;
var webglBackend = backend;
if (supportsComplex && a.dtype === 'complex64') {
var aData = webglBackend.texData.get(a.dataId);
var bData = webglBackend.texData.get(b.dataId);
var _map = [[aData.complexTensorInfos.real, bData.complexTensorInfos.real], [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]].map(function (complexParts) {
var aPart = complexParts[0],
bPart = complexParts[1];
var aHandle = {
dataId: aPart.dataId,
dtype: aPart.dtype,
shape: a.shape
};
var bHandle = {
dataId: bPart.dataId,
dtype: bPart.dtype,
shape: b.shape
};
var program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
return webglBackend.runWebGLProgram(program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype));
}),
real = _map[0],
imag = _map[1];
var complexOutput = complex$2({
inputs: {
real: real,
imag: imag
},
backend: webglBackend
});
webglBackend.disposeIntermediateTensorInfo(real);
webglBackend.disposeIntermediateTensorInfo(imag); // TODO(annxingyuan): Implement CPU forwarding for complex inputs.
return complexOutput;
}
var $dtype = dtype || upcastType(a.dtype, b.dtype);
if ((a.dtype === 'string' || b.dtype === 'string' || webglBackend.shouldExecuteOnCPU([a, b])) && cpuKernelImpl != null) {
var aVals = webglBackend.texData.get(a.dataId).values;
var bVals = webglBackend.texData.get(b.dataId).values;
var decodedAVals = a.dtype === 'string' ? // tslint:disable-next-line: no-any
fromUint8ToStringArray(aVals) : aVals;
var decodedBVals = a.dtype === 'string' ? // tslint:disable-next-line: no-any
fromUint8ToStringArray(bVals) : bVals;
var _cpuKernelImpl = cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype),
outValues = _cpuKernelImpl[0],
outShape = _cpuKernelImpl[1];
var out = webglBackend.makeTensorInfo(outShape, $dtype);
var outData = webglBackend.texData.get(out.dataId);
outData.values = outValues;
return out;
}
var shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') && packedOpSnippet != null;
var program;
if (shouldUsePackedProgram) {
program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds);
} else {
program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
}
return webglBackend.runWebGLProgram(program, [a, b], $dtype);
};
}
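// Usage sketch (illustrative; the snippet and kernel name are examples, not part of this bundle):
//   var addKernelFunc = binaryKernelFunc$1({ opSnippet: 'return a + b;', supportsComplex: true });
//   // registered via a config such as { kernelName: Add, backendName: 'webgl', kernelFunc: addKernelFunc }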
function mapActivationToShaderProgram(activation, packed) {
if (packed === void 0) {
packed = false;
}
if (activation === 'linear') {
if (packed) {
return LINEAR$1;
}
return LINEAR;
} else if (activation === 'relu') {
if (packed) {
return RELU$1;
}
return RELU;
} else if (activation === 'elu') {
if (packed) {
return ELU$2;
}
return ELU$1;
} else if (activation === 'relu6') {
if (packed) {
return RELU6$1;
}
return RELU6;
} else if (activation === 'prelu') {
if (packed) {
return PRELU_PACKED;
}
return PRELU;
} else if (activation === 'leakyrelu') {
if (packed) {
return LEAKYRELU_PACKED;
}
return LEAKYRELU;
} else if (activation === 'sigmoid') {
if (packed) {
return SIGMOID$1;
}
return SIGMOID;
}
throw new Error("Activation " + activation + " has not been implemented for the WebGL backend.");
}
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MatMulPackedProgram = function MatMulPackedProgram(aShape, bShape, outputShape, transposeA, transposeB, addBias, activation, hasPreluActivation, hasLeakyreluActivation) {
if (transposeA === void 0) {
transposeA = false;
}
if (transposeB === void 0) {
transposeB = false;
}
if (addBias === void 0) {
addBias = false;
}
if (activation === void 0) {
activation = null;
}
if (hasPreluActivation === void 0) {
hasPreluActivation = false;
}
if (hasLeakyreluActivation === void 0) {
hasLeakyreluActivation = false;
}
this.variableNames = ['matrixA', 'matrixB'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
var sharedDim = transposeA ? aShape[1] : aShape[2];
var sharedDimensionPacked = Math.ceil(sharedDim / 2);
var aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
var bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
var aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
var bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];
var activationSnippet = '',
applyActivationSnippet = '';
if (activation) {
if (hasPreluActivation) {
activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
} else if (hasLeakyreluActivation) {
activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n " + activation + "\n }";
} else {
activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }";
}
applyActivationSnippet = "result = activation(result);";
}
var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}
if (hasPreluActivation) {
this.variableNames.push('preluActivationWeights');
}
if (hasLeakyreluActivation) {
this.variableNames.push('leakyreluAlpha');
}
var batchASnippet = 'rc.x';
var batchBSnippet = 'rc.x';
if (aShape[0] < bShape[0]) {
batchASnippet = "int(min(float(rc.x), " + (aShape[0] - 1) + ".))";
} else if (bShape[0] < aShape[0]) {
batchBSnippet = "int(min(float(rc.x), " + (bShape[0] - 1) + ".))";
}
this.userCode = "\n " + activationSnippet + "\n // Don't use uniform for sharedDimensionPacked for performance.\n const float sharedDimension = " + sharedDimensionPacked + ".0;\n\n vec4 dot2x2ARowBCol(ivec3 rc) {\n vec4 result = vec4(0);\n for (int i = 0; i < " + sharedDimensionPacked + "; i++) {\n int batchA = " + batchASnippet + ";\n int batchB = " + batchBSnippet + ";\n vec4 a = getMatrixA(batchA, " + aSample + ");\n vec4 b = getMatrixB(batchB, " + bSample + ");\n\n // These swizzled products need to be separately added.\n // See: https://github.com/tensorflow/tfjs/issues/1735\n result += (" + aSwizzle[0] + " * " + bSwizzle[0] + ");\n result += (" + aSwizzle[1] + " * " + bSwizzle[1] + ");\n }\n return result;\n }\n\n void main() {\n ivec3 rc = getOutputCoords();\n vec4 result = dot2x2ARowBCol(rc);\n\n " + addBiasSnippet + "\n\n " + applyActivationSnippet + "\n\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
  // Complex multiplication: (Ar + i*Ai)(Br + i*Bi)
  //   = ArBr + i*ArBi + i*AiBr - AiBi
  // Yr = ArBr - AiBi
  // Yi = ArBi + AiBr
var COMPLEX_MULTIPLY = {
REAL: 'return areal * breal - aimag * bimag;',
IMAG: 'return areal * bimag + aimag * breal;'
};
var BinaryOpComplexProgram = function BinaryOpComplexProgram(op, aShape, bShape) {
this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
this.userCode = "\n float binaryOpComplex(\n float areal, float aimag, float breal, float bimag) {\n " + op + "\n }\n\n void main() {\n float areal = getARealAtOutCoords();\n float aimag = getAImagAtOutCoords();\n float breal = getBRealAtOutCoords();\n float bimag = getBImagAtOutCoords();\n setOutput(binaryOpComplex(areal, aimag, breal, bimag));\n }\n ";
};
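// Worked example (illustrative): (1 + 2i) * (3 + 4i)
//   real: 1 * 3 - 2 * 4 = -5
//   imag: 1 * 4 + 2 * 3 = 10
// i.e. the REAL and IMAG snippets above evaluated element-wise.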
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MUL = 'return a * b;';
function multiply$4(args) {
var inputs = args.inputs,
backend = args.backend;
var a = inputs.a,
b = inputs.b;
var dtype = upcastType(a.dtype, b.dtype);
if (a.dtype === 'complex64') {
var aData = backend.texData.get(a.dataId);
var bData = backend.texData.get(b.dataId);
var realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
var imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
var _inputs = [{
dataId: aData.complexTensorInfos.real.dataId,
dtype: aData.complexTensorInfos.real.dtype,
shape: a.shape
}, {
dataId: aData.complexTensorInfos.imag.dataId,
dtype: aData.complexTensorInfos.imag.dtype,
shape: a.shape
}, {
dataId: bData.complexTensorInfos.real.dataId,
dtype: bData.complexTensorInfos.real.dtype,
shape: b.shape
}, {
dataId: bData.complexTensorInfos.imag.dataId,
dtype: bData.complexTensorInfos.imag.dtype,
shape: b.shape
}];
var realPart = backend.runWebGLProgram(realProgram, _inputs, 'float32');
var imagPart = backend.runWebGLProgram(imagProgram, _inputs, 'float32');
var complexOutput = complex$2({
inputs: {
real: realPart,
imag: imagPart
},
backend: backend
});
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(imagPart); // TODO(annxingyuan): CPU forwarding for complex inputs.
return complexOutput;
}
if (backend.shouldExecuteOnCPU([a, b])) {
var _aData = backend.texData.get(a.dataId);
var _bData = backend.texData.get(b.dataId);
var _cpuMultiply = multiplyImplCPU(a.shape, b.shape, _aData.values, _bData.values, dtype),
outValues = _cpuMultiply[0],
outShape = _cpuMultiply[1];
var out = backend.makeTensorInfo(outShape, dtype);
var outData = backend.texData.get(out.dataId);
outData.values = outValues;
return out;
}
var program;
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
program = new BinaryOpPackedProgram(MUL, a.shape, b.shape);
} else {
program = new BinaryOpProgram(MUL, a.shape, b.shape);
}
return backend.runWebGLProgram(program, [a, b], dtype);
}
var multiplyConfig$1 = {
kernelName: Multiply,
backendName: 'webgl',
kernelFunc: multiply$4
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function packedReshape(input, afterShape, backend) {
var input3DShape = [getBatchDim(input.shape)].concat(getRowsCols(input.shape));
var input3D = {
dtype: input.dtype,
shape: input3DShape,
dataId: input.dataId
};
var afterShapeAs3D = [getBatchDim(afterShape)].concat(getRowsCols(afterShape));
var program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
var preventEagerUnpackingOfOutput = true;
var customValues = [input3DShape];
var output = backend.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
return {
dataId: output.dataId,
shape: afterShape,
dtype: output.dtype
};
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
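// WebGL Reshape kernel. Resolves an implicit (-1) dimension, asserts that the
// element count is unchanged, and normally just aliases the existing dataId
// under the new shape; only packed textures whose layout makes the reshape
// non-free go through packedReshape above.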
function reshape$3(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var shape = attrs.shape;
var webglBackend = backend;
var xSize = sizeFromShape(x.shape);
var $shape = inferFromImplicitShape(shape, xSize);
var $xSize = sizeFromShape($shape);
assert(xSize === $xSize, function () {
return "The new shape (" + $shape + ") has " + $xSize + " elements and the old " + ("shape (" + x.shape + ") has " + xSize + " elements. The new shape and old ") + "shape must have the same number of elements.";
});
var xTexData = webglBackend.texData.get(x.dataId);
if (xTexData.isPacked && !isReshapeFree(x.shape, $shape) && !(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape))) {
return packedReshape(x, $shape, webglBackend);
}
webglBackend.incRef(x.dataId);
return {
dataId: x.dataId,
shape: $shape,
dtype: x.dtype
};
}
var reshapeConfig$1 = {
kernelName: Reshape,
backendName: 'webgl',
kernelFunc: reshape$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
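// GPGPU program that computes a windowed sum/mean over a [batchSize, inSize]
// input, producing [batchSize, outSize]. Each output element accumulates
// `windowSize` inputs in vec4 chunks; when `divisor` is supplied the values
// are pre-scaled by 1/divisor so the result is a mean rather than a sum.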
var MeanProgram = function MeanProgram(reduceInfo, divisor) {
this.variableNames = ['x'];
var windowSize = reduceInfo.windowSize,
batchSize = reduceInfo.batchSize,
inSize = reduceInfo.inSize,
outSize = reduceInfo.outSize;
this.outputShape = [batchSize, outSize];
var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
var windowSizeVec4Remainder = windowSize % 4;
var updateSnippet = "sumValue += dot(values, ones);";
if (divisor != null) {
var denominator = 1 / divisor;
updateSnippet = "sumValue += dot(values * " + (isInt(denominator) ? denominator.toPrecision(2) : denominator) + ", ones);";
}
var checkOutOfBounds = '';
if (inSize % windowSize > 0) {
checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return 0.0;\n }\n ";
}
this.userCode = "\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n " + checkOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n float sumValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1), 0.0, 0.0);\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2), 0.0);\n\n " + updateSnippet + "\n }\n setOutput(sumValue);\n }\n ";
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
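// Single-stage reduction shader shared by sum, prod, min, max, all and any.
// It reduces each window of `windowSize` elements of a [batchSize, inSize]
// input to one output value, reading four elements per loop iteration and
// handling the 1-3 element remainder separately.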
var ReduceProgram = function ReduceProgram(reduceInfo, reduceType) {
this.variableNames = ['x'];
var windowSize = reduceInfo.windowSize,
batchSize = reduceInfo.batchSize,
inSize = reduceInfo.inSize,
outSize = reduceInfo.outSize;
this.outputShape = [batchSize, outSize];
var initializationValue = '0.0';
var compareOp = "";
if (reduceType === 'prod') {
initializationValue = '1.0';
} else if (reduceType === 'min') {
// WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
initializationValue = '1.0 / 1e-20';
compareOp = "min";
} else if (reduceType === 'max') {
// WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
initializationValue = '-1.0 / 1e-20';
compareOp = "max";
}
var returnValue = reduceType + "(" + reduceType + "(" + reduceType + "(" + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
if (reduceType === 'sum') {
returnValue = "sumValue";
} else if (reduceType === 'prod') {
returnValue = "prodValue";
} else if (reduceType === 'all') {
returnValue = "allValue";
} else if (reduceType === 'any') {
returnValue = "anyValue";
}
var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
var windowSizeVec4Remainder = windowSize % 4;
var updateSnippet = "\n if (" + (reduceType === 'sum') + ") {\n sumValue += dot(values, ones);\n } else if (" + (reduceType === 'prod') + ") {\n vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n prodValue *= tmp[0] * tmp[1];\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n if (" + (reduceType === 'min') + " || " + (reduceType === 'max') + ") {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n bvec4 isNaN = isnan(values);\n if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {\n minMaxValue = vec4(NAN);\n }\n }\n }\n ";
var vecType = "vec4";
if (reduceType === 'all') {
initializationValue = '1.0';
updateSnippet = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n ";
vecType = "bvec4";
} else if (reduceType === 'any') {
initializationValue = '0.0';
updateSnippet = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n ";
vecType = "bvec4";
}
var checkOutOfBounds = '';
if (inSize % windowSize > 0) {
checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n ";
}
this.userCode = "\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n " + checkOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float prodValue = 1.0;\n float sumValue = 0.0;\n float allValue = 1.0;\n float anyValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Returns the sequence of reduction stages ({inSize, windowSize, outSize})
// needed to shrink the inner dimension of `inShape` down to a single value,
// one entry per pass of the reduction shader.
function getReductionStages(inShape) {
var stages = [];
while (stages.length === 0 || stages[stages.length - 1].outSize !== 1) {
var outSize = stages.length ? stages[stages.length - 1].outSize : inShape[1];
var windowSize = computeOptimalWindowSize(outSize);
stages.push({
inSize: outSize,
windowSize: windowSize,
outSize: Math.ceil(outSize / windowSize)
});
}
return stages;
}
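// Runs the staged reduction described above: each stage shrinks the inner
// dimension by its window size, and stages are chained until a single value
// per batch remains. Intermediate results are disposed as soon as the next
// stage has consumed them.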
function reduce(x, dtype, reductionType, backend) {
var reductionStages = getReductionStages(x.shape);
var result = x;
for (var i = 0; i < reductionStages.length; i++) {
var _reductionStages$i = reductionStages[i],
inSize = _reductionStages$i.inSize,
windowSize = _reductionStages$i.windowSize,
outSize = _reductionStages$i.outSize;
var program = void 0;
var previousResult = void 0;
if (reductionType === 'mean') {
program = i === 0 ? new MeanProgram({
windowSize: windowSize,
inSize: inSize,
batchSize: x.shape[0],
outSize: outSize
}, inSize) : new MeanProgram({
windowSize: windowSize,
inSize: inSize,
batchSize: x.shape[0],
outSize: outSize
});
} else {
program = new ReduceProgram({
windowSize: windowSize,
inSize: inSize,
batchSize: x.shape[0],
outSize: outSize
}, reductionType);
}
previousResult = result;
result = backend.runWebGLProgram(program, [result], dtype);
if (previousResult.dataId !== x.dataId) {
backend.disposeIntermediateTensorInfo(previousResult);
}
}
return result;
}
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
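// Unpacked transpose shader: each output coordinate reads the input element at
// the coordinates permuted according to `newDim`. Ranks above 6 are rejected
// by getSwitchedCoords below.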
var TransposeProgram = function TransposeProgram(aShape, newDim) {
this.variableNames = ['A'];
var outputShape = new Array(aShape.length);
for (var i = 0; i < outputShape.length; i++) {
outputShape[i] = aShape[newDim[i]];
}
this.outputShape = outputShape;
this.rank = outputShape.length;
var dtype = getCoordsDataType(this.rank);
var switched = getSwitchedCoords(newDim);
this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + switched + "));\n }\n ";
};
function getSwitchedCoords(newDim) {
var rank = newDim.length;
if (rank > 6) {
throw Error("Transpose for rank " + rank + " is not yet supported");
}
var originalOrder = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
var switchedCoords = new Array(rank);
for (var i = 0; i < newDim.length; i++) {
switchedCoords[newDim[i]] = originalOrder[i];
}
return switchedCoords.join();
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TransposePackedProgram = function TransposePackedProgram(aShape, newDim) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
var outputShape = new Array(aShape.length);
for (var i = 0; i < outputShape.length; i++) {
outputShape[i] = aShape[newDim[i]];
}
this.outputShape = outputShape;
this.rank = outputShape.length;
if (this.rank > 6) {
throw Error("Packed transpose for rank " + this.rank + " is not yet supported.");
}
var dtype = getCoordsDataType(this.rank);
var outputOrder = getVecChannels('rc', this.rank);
var switchedOrder = new Array(this.rank);
for (var _i = 0; _i < newDim.length; _i++) {
switchedOrder[newDim[_i]] = outputOrder[_i];
}
var innerDims = "vec2(" + switchedOrder.slice(-2).join() + ")";
var nextColumn = "++" + outputOrder[this.rank - 1] + " < " + outputShape[this.rank - 1];
var getc = "getChannel(getA(" + switchedOrder.join() + "), " + innerDims + ")";
this.userCode = "\n void main() {\n " + dtype + " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result[0] = " + getc + ";\n if(" + nextColumn + ") {\n result[1] = " + getc + ";\n }\n --" + outputOrder[this.rank - 1] + ";\n if(++" + outputOrder[this.rank - 2] + " < " + outputShape[this.rank - 2] + ") {\n result[2] = " + getc + ";\n if(" + nextColumn + ") {\n result[3] = " + getc + ";\n }\n }\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function transposeImpl$1(x, perm, backend) {
var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new TransposePackedProgram(x.shape, perm) : new TransposeProgram(x.shape, perm);
return backend.runWebGLProgram(program, [x], x.dtype);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
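// Shared implementation for the Sum kernel: permutes the reduced axes to be
// innermost (transposing if necessary), flattens the input to
// [batch, reduceSize], runs the staged 'sum' reduction, and reshapes back to
// the expected output shape (expanded when keepDims is true).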
function sumImpl(x, axis, keepDims, backend) {
var reductionIndices = axis;
var xRank = x.shape.length;
var origAxes = parseAxisParam(reductionIndices, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, xRank);
var sumInputIsTransposed = permutedAxes != null;
var sumInput = x;
if (sumInputIsTransposed) {
sumInput = transposeImpl$1(x, permutedAxes, backend);
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('sum', axes, xRank);
var _backend_util$compute = computeOutAndReduceShapes(sumInput.shape, axes),
sumOutShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var outShape = sumOutShape;
if (keepDims) {
// rather than reshape at the end, set the target shape here.
outShape = expandShapeToKeepDim(sumOutShape, origAxes);
}
var inSize = sizeFromShape(reduceShape);
var xSize = sizeFromShape(x.shape);
var batchSize = xSize / inSize;
var reshapedInput = reshape$3({
inputs: {
x: sumInput
},
attrs: {
shape: [batchSize, inSize]
},
backend: backend
});
var outType = sumOutType(x.dtype);
var reduced = reduce(reshapedInput, outType, 'sum', backend);
var out = reshape$3({
inputs: {
x: reduced
},
attrs: {
shape: outShape
},
backend: backend
});
backend.disposeIntermediateTensorInfo(reshapedInput);
backend.disposeIntermediateTensorInfo(reduced);
if (sumInputIsTransposed) {
backend.disposeIntermediateTensorInfo(sumInput);
}
return out;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sum$4(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
return sumImpl(x, axis, keepDims, backend);
}
var sumConfig$1 = {
kernelName: Sum,
backendName: 'webgl',
kernelFunc: sum$4
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
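// WebGL Transpose kernel. Inputs that qualify for CPU forwarding are handled
// by transposeImplCPU; otherwise the (packed or unpacked) transpose shader is
// run via transposeImpl$1.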
function transpose$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var perm = attrs.perm;
var webglBackend = backend;
var xRank = x.shape.length;
var newShape = new Array(xRank);
for (var i = 0; i < newShape.length; i++) {
newShape[i] = x.shape[perm[i]];
}
var out;
if (webglBackend.shouldExecuteOnCPU([x])) {
var xTexData = webglBackend.texData.get(x.dataId);
var values = xTexData.values;
var outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
out = webglBackend.makeTensorInfo(newShape, x.dtype);
var outData = webglBackend.texData.get(out.dataId);
outData.values = outValues;
} else {
out = transposeImpl$1(x, perm, webglBackend);
}
return out;
}
var transposeConfig$1 = {
kernelName: Transpose,
backendName: 'webgl',
kernelFunc: transpose$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Empirically determined minimal shared dimension in matMul before we forward
// to a.mul(b).sum() in order to take advantage of GPU parallelism. See
// https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
var MATMUL_SHARED_DIM_THRESHOLD = 1000;
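// Core WebGL batched matMul. When one operand is effectively a vector
// (outerShapeA or outerShapeB is 1), the shared dimension exceeds
// MATMUL_SHARED_DIM_THRESHOLD and no fused bias/activation is requested, the
// product is computed as multiply + sum to exploit the parallel reduction;
// otherwise MatMulPackedProgram is used, with optional bias, PReLU weights and
// leaky-relu alpha appended as extra inputs. Illustrative op-level entry point
// (not shown here): tf.matMul(a, b) on the 'webgl' backend.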
function batchMatMulImpl(_ref) {
var a = _ref.a,
b = _ref.b,
transposeA = _ref.transposeA,
transposeB = _ref.transposeB,
backend = _ref.backend,
_ref$bias = _ref.bias,
bias = _ref$bias === void 0 ? null : _ref$bias,
_ref$preluActivationW = _ref.preluActivationWeights,
preluActivationWeights = _ref$preluActivationW === void 0 ? null : _ref$preluActivationW,
_ref$leakyreluAlpha = _ref.leakyreluAlpha,
leakyreluAlpha = _ref$leakyreluAlpha === void 0 ? 0 : _ref$leakyreluAlpha,
_ref$activation = _ref.activation,
activation = _ref$activation === void 0 ? null : _ref$activation;
var aRank = a.shape.length;
var bRank = b.shape.length;
var innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
var innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
var outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
var outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
var outerDimsA = a.shape.slice(0, -2);
var outerDimsB = b.shape.slice(0, -2);
var batchDimA = sizeFromShape(outerDimsA);
var batchDimB = sizeFromShape(outerDimsB);
var batchDimsCompatible = batchDimA === batchDimB || batchDimA === 1 || batchDimB === 1;
assert(aRank >= 2 && bRank >= 2 && batchDimsCompatible, function () {
return "Error in matMul: the input batch dimensions must either be the " + "same or at least one input batch dimension must be 1. Got input " + ("batch dimensions of (" + outerDimsA + ") and (" + outerDimsB + ").");
});
var outShapeOuterDims = batchDimA > batchDimB ? a.shape.slice(0, -2) : b.shape.slice(0, -2);
var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
assert(innerShapeA === innerShapeB, function () {
return "Error in matMul: inner shapes (" + innerShapeA + ") and (" + (innerShapeB + ") of Tensors with shapes " + a.shape + " and ") + (b.shape + " and transposeA=" + transposeA) + (" and transposeB=" + transposeB + " must match.");
});
var a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] : [batchDimA, outerShapeA, innerShapeA];
var b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] : [batchDimB, innerShapeB, outerShapeB]; // The rest of the implementation is designed to operate on rank-3 tensors
var a3d = reshape$3({
inputs: {
x: a
},
backend: backend,
attrs: {
shape: a3dShape
}
});
var b3d = reshape$3({
inputs: {
x: b
},
backend: backend,
attrs: {
shape: b3dShape
}
});
var intermediates = [a3d, b3d];
var batchDim = Math.max(batchDimA, batchDimB);
var sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
var hasBias = bias != null;
var hasPreluActivationWeights = preluActivationWeights != null;
var hasLeakyreluAlpha = activation === 'leakyrelu';
var fusedActivation = activation != null ? mapActivationToShaderProgram(activation, true) : null;
var containsFusedOps = hasBias || hasPreluActivationWeights || hasLeakyreluAlpha || fusedActivation != null;
var out; // Since the matrices are vectors, it is faster to call mul().sum()
// because sum() is O(sqrt(N)) due to divide-and-conquer.
if ((outerShapeA === 1 || outerShapeB === 1) && sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) {
var aVec = a3d;
var bVec = b3d;
if (transposeA) {
aVec = transpose$2({
inputs: {
x: a3d
},
backend: backend,
attrs: {
perm: [0, 2, 1]
}
});
intermediates.push(aVec);
}
if (transposeB) {
bVec = transpose$2({
inputs: {
x: b3d
},
backend: backend,
attrs: {
perm: [0, 2, 1]
}
});
intermediates.push(bVec);
}
var shouldReshapeA = outerShapeB !== 1;
var shouldReshapeB = outerShapeB === 1;
var aVec3d = aVec;
if (shouldReshapeA) {
aVec3d = reshape$3({
inputs: {
x: aVec
},
backend: backend,
attrs: {
shape: [batchDim, sharedDim, 1]
}
});
intermediates.push(aVec3d);
}
var axis = outerShapeB === 1 ? 2 : 1;
var bVec3d = bVec;
if (shouldReshapeB) {
bVec3d = reshape$3({
inputs: {
x: bVec
},
backend: backend,
attrs: {
shape: [batchDim, 1, sharedDim]
}
});
intermediates.push(bVec3d);
}
var product = multiply$4({
inputs: {
a: aVec3d,
b: bVec3d
},
backend: backend
});
out = sum$4({
inputs: {
x: product
},
backend: backend,
attrs: {
axis: axis,
keepDims: true
}
});
intermediates.push(product);
} else {
var dtype = upcastType(a.dtype, b.dtype);
var program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
var inputs = [a3d, b3d];
if (bias != null) {
inputs.push(bias);
}
if (hasPreluActivationWeights) {
inputs.push(preluActivationWeights);
}
if (hasLeakyreluAlpha) {
var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
inputs.push($leakyreluAlpha);
intermediates.push($leakyreluAlpha);
}
out = backend.runWebGLProgram(program, inputs, dtype);
}
var outReshaped = reshape$3({
inputs: {
x: out
},
backend: backend,
attrs: {
shape: outShape
}
});
intermediates.push(out);
for (var _i = 0, _intermediates = intermediates; _i < _intermediates.length; _i++) {
var i = _intermediates[_i];
backend.disposeIntermediateTensorInfo(i);
}
return outReshaped;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
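// Fused matMul kernel: simply forwards the fused bias / PReLU weights /
// activation attributes into batchMatMulImpl above.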
function _fusedMatMul$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var a = inputs.a,
b = inputs.b,
bias = inputs.bias,
preluActivationWeights = inputs.preluActivationWeights;
var transposeA = attrs.transposeA,
transposeB = attrs.transposeB,
activation = attrs.activation,
leakyreluAlpha = attrs.leakyreluAlpha;
return batchMatMulImpl({
a: a,
b: b,
transposeA: transposeA,
transposeB: transposeB,
backend: backend,
bias: bias,
preluActivationWeights: preluActivationWeights,
leakyreluAlpha: leakyreluAlpha,
activation: activation
});
}
var _fusedMatMulConfig$1 = {
kernelName: _FusedMatMul,
backendName: 'webgl',
kernelFunc: _fusedMatMul$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
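// Abs kernel: non-complex inputs that qualify for CPU forwarding use
// simpleAbsImplCPU; otherwise a (packed or unpacked) unary shader evaluating
// `abs(x)` is run.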
var ABS$1 = "return abs(x);";
function abs$a(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x; // TODO: handle cases when x is complex. Once the cpu implementation
// can handle complex values, refactor to use unaryKernelFunc.
if (backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
var xData = backend.texData.get(x.dataId);
var outValues = simpleAbsImplCPU(xData.values);
return backend.makeTensorInfo(x.shape, x.dtype, outValues);
}
var program;
if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
program = new UnaryOpPackedProgram(x.shape, ABS$1);
} else {
program = new UnaryOpProgram(x.shape, ABS$1);
}
return backend.runWebGLProgram(program, [x], x.dtype);
}
var absConfig$1 = {
kernelName: Abs,
backendName: 'webgl',
kernelFunc: abs$a
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ACOS = CHECK_NAN_SNIPPET + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return acos(x);\n";
var acos$2 = unaryKernelFunc$1({
opSnippet: ACOS
});
var acosConfig$1 = {
kernelName: Acos,
backendName: 'webgl',
kernelFunc: acos$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ACOSH = CHECK_NAN_SNIPPET + "\n if (x < 1.0) return NAN;\nreturn log(x + sqrt(x * x - 1.0));";
var acosh$2 = unaryKernelFunc$1({
opSnippet: ACOSH
});
var acoshConfig$1 = {
kernelName: Acosh,
backendName: 'webgl',
kernelFunc: acosh$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ADD = 'return a + b;';
var addKernelFunc = binaryKernelFunc$1({
opSnippet: ADD,
packedOpSnippet: ADD,
supportsComplex: true,
cpuKernelImpl: addImplCPU
});
var addConfig$1 = {
kernelName: Add,
backendName: 'webgl',
kernelFunc: addKernelFunc
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var AddNProgram = function AddNProgram(outputShape, shapes) {
this.outputShape = [];
this.outputShape = outputShape;
this.variableNames = shapes.map(function (_, i) {
return "T" + i;
});
var snippets = []; // Get target elements from every input tensor.
this.variableNames.forEach(function (variable) {
snippets.push("float v" + variable + " = get" + variable + "AtOutCoords();");
}); // Calculate the sum of all elements.
var operation = this.variableNames.map(function (variable) {
return "v" + variable;
}).join(' + ');
this.userCode = "\n void main() {\n " + snippets.join('\n ') + "\n\n float result = " + operation + ";\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var AddNPackedProgram = function AddNPackedProgram(outputShape, shapes) {
this.outputShape = [];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = outputShape;
this.variableNames = shapes.map(function (_, i) {
return "T" + i;
});
var snippets = []; // Get target elements from every input tensor.
this.variableNames.forEach(function (variable) {
snippets.push("vec4 v" + variable + " = get" + variable + "AtOutCoords();");
}); // Calculate the sum of all elements.
var operation = this.variableNames.map(function (variable) {
return "v" + variable;
}).join(' + ');
this.userCode = "\n void main() {\n " + snippets.join('\n ') + "\n\n vec4 result = " + operation + ";\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
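// AddN kernel: a single input is returned via Identity; when the input count
// exceeds WEBGL_MAX_TEXTURES_IN_SHADER the list is split in half and summed
// recursively; otherwise one AddN(Packed)Program adds all tensors in a single
// shader pass.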
function addN$2(args) {
var inputs = args.inputs,
backend = args.backend;
var tensors = inputs;
if (tensors.length === 1) {
return identity$2({
inputs: {
x: tensors[0]
},
backend: backend
});
} // Limit the number of uploaded textures for optimization.
if (tensors.length > env().get('WEBGL_MAX_TEXTURES_IN_SHADER')) {
var midIndex = Math.floor(tensors.length / 2);
var leftSide = addN$2({
inputs: tensors.slice(0, midIndex),
backend: backend
});
var rightSide = addN$2({
inputs: tensors.slice(midIndex),
backend: backend
});
return addN$2({
inputs: [leftSide, rightSide],
backend: backend
});
}
var dtype = tensors.map(function (t) {
return t.dtype;
}).reduce(function (d1, d2) {
return upcastType(d1, d2);
});
var shapes = tensors.map(function (t) {
return t.shape;
  }); // Shape equality across inputs is already validated at the op level.
var usePackedOp = env().getBool('WEBGL_PACK');
var program = usePackedOp ? new AddNPackedProgram(tensors[0].shape, shapes) : new AddNProgram(tensors[0].shape, shapes);
return backend.runWebGLProgram(program, tensors, dtype);
}
var addNConfig$1 = {
kernelName: AddN,
backendName: 'webgl',
kernelFunc: addN$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
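// All kernel (logical AND reduction). Mirrors the Sum kernel's structure:
// permute the reduced axes innermost if needed, flatten to [-1, inSize], run
// the staged 'all' reduction, then reshape to the keepDims-aware output shape.
// The Any kernel below follows the same pattern with an 'any' reduction.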
function all$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
var xRank = x.shape.length;
var origAxes = parseAxisParam(axis, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, xRank);
var permutedX = x;
if (permutedAxes != null) {
permutedX = transpose$2({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('all', axes, xRank);
var _backend_util$compute = computeOutAndReduceShapes(permutedX.shape, axes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var inSize = sizeFromShape(reduceShape);
var a2D = reshape$3({
inputs: {
x: permutedX
},
backend: backend,
attrs: {
shape: [-1, inSize]
}
});
var reduced = reduce(a2D, a2D.dtype, 'all', backend);
var res;
if (keepDims) {
var newShape = expandShapeToKeepDim(outShape, origAxes);
res = reshape$3({
inputs: {
x: reduced
},
backend: backend,
attrs: {
shape: newShape
}
});
} else {
res = reshape$3({
inputs: {
x: reduced
},
backend: backend,
attrs: {
shape: outShape
}
});
}
backend.disposeIntermediateTensorInfo(a2D);
backend.disposeIntermediateTensorInfo(reduced);
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo(permutedX);
}
return res;
}
var allConfig$1 = {
kernelName: All,
backendName: 'webgl',
kernelFunc: all$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function any$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
var xRank = x.shape.length;
var origAxes = parseAxisParam(axis, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, xRank);
var permutedX = x;
if (permutedAxes != null) {
permutedX = transpose$2({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('any', axes, xRank);
var _backend_util$compute = computeOutAndReduceShapes(permutedX.shape, axes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var inSize = sizeFromShape(reduceShape);
var a2D = reshape$3({
inputs: {
x: permutedX
},
backend: backend,
attrs: {
shape: [-1, inSize]
}
});
var reduced = reduce(a2D, a2D.dtype, 'any', backend);
var res;
if (keepDims) {
var newShape = expandShapeToKeepDim(outShape, origAxes);
res = reshape$3({
inputs: {
x: reduced
},
backend: backend,
attrs: {
shape: newShape
}
});
} else {
res = reshape$3({
inputs: {
x: reduced
},
backend: backend,
attrs: {
shape: outShape
}
});
}
backend.disposeIntermediateTensorInfo(a2D);
backend.disposeIntermediateTensorInfo(reduced);
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo(permutedX);
}
return res;
}
var anyConfig$1 = {
kernelName: Any,
backendName: 'webgl',
kernelFunc: any$2
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ArgMinMaxProgram = function ArgMinMaxProgram(reduceInfo, op, firstPass) {
this.variableNames = ['A'];
var windowSize = reduceInfo.windowSize,
batchSize = reduceInfo.batchSize,
outSize = reduceInfo.outSize;
if (!firstPass) {
this.variableNames.push('bestIndicesA');
}
this.outputShape = [batchSize, outSize];
var compOp = op === 'max' ? '>' : '<';
var indexSnippet = firstPass ? 'inOffset + i;' : 'round(getBestIndicesA(batch, inOffset + i));';
this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n int bestIndex = inOffset;\n float bestValue = getA(batch, bestIndex);\n\n for (int i = 0; i < " + windowSize + "; i++) {\n int inIdx = " + indexSnippet + ";\n float candidate = getA(batch, inIdx);\n if (candidate " + compOp + " bestValue) {\n bestValue = candidate;\n bestIndex = inIdx;\n }\n }\n setOutput(float(bestIndex));\n }\n ";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ArgMinMaxPackedProgram = function ArgMinMaxPackedProgram(shape, windowSize, op, firstPass) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
assert(shape.length > 2, function () {
return "Packed arg" + (op.charAt(0).toUpperCase() + op.slice(1)) + " supports only inputs with rank above 2.";
});
var inSize = shape[shape.length - 1];
var outSize = Math.ceil(inSize / windowSize);
this.outputShape = shape.slice(0, -1);
if (outSize > 1) {
this.outputShape.push(outSize);
}
if (!firstPass) {
this.variableNames.push('bestIndicesA');
}
var outShape = this.outputShape;
var rank = outShape.length;
var dtype = getCoordsDataType(rank);
var coords = getChannels('coords', rank);
var sourceLocSetup;
var sourceRank;
if (outSize === 1) {
sourceRank = rank + 1;
var sourceLocDType = getCoordsDataType(sourceRank);
sourceLocSetup = "\n " + sourceLocDType + " sourceLocR = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocG = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 2] + ";\n " + sourceLocDType + " sourceLocA = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocB = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 2] + ";";
} else {
sourceRank = rank;
sourceLocSetup = "\n " + dtype + " sourceLocR = coords;\n ++" + coords[rank - 1] + ";\n " + dtype + " sourceLocG = coords;\n ++" + coords[rank - 2] + ";\n " + dtype + " sourceLocA = coords;\n --" + coords[rank - 1] + ";\n " + dtype + " sourceLocB = coords;\n --" + coords[rank - 2] + ";";
}
var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
var inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3.
var intChannels = channels.map(function (x) {
return 'int ' + x;
});
var srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
var srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
var srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
var srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
var compOp = op === 'max' ? 'greaterThan' : 'lessThan';
var fetchCandidateIdx = firstPass ? '' : "\n inIdx = round(vec4(getBestIndicesAChannel(" + srcRCoords.join() + "),\n getBestIndicesAChannel(" + srcGCoords.join() + "),\n getBestIndicesAChannel(" + srcBCoords.join() + "),\n getBestIndicesAChannel(" + srcACoords.join() + ")));";
var fetchValue = "vec4(\n getAChannel(" + srcRCoords.join() + "),\n hasNextCol ? getAChannel(" + srcGCoords.join() + ") : 0.,\n hasNextRow ? getAChannel(" + srcBCoords.join() + ") : 0.,\n hasNextRow && hasNextCol ? getAChannel(" + srcACoords.join() + ") : 0.)";
var getBestIndicesAChannelSnippet = firstPass ? '' : "\n float getBestIndicesAChannel(" + intChannels.join() + ") {\n return getChannel(getBestIndicesA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }";
this.userCode = "\n float getAChannel(" + intChannels.join() + ") {\n return getChannel(getA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }\n " + getBestIndicesAChannelSnippet + "\n void main() {\n " + dtype + " coords = getOutputCoords();\n bool hasNextCol = " + coords[rank - 1] + " < " + (outShape[rank - 1] - 1) + ";\n bool hasNextRow = " + coords[rank - 2] + " < " + (outShape[rank - 2] - 1) + ";\n " + sourceLocSetup + "\n ivec4 srcIdx = ivec4(sourceLocR" + inChannel + ", sourceLocG" + inChannel + ",\n sourceLocB" + inChannel + ", sourceLocA" + inChannel + ") * " + windowSize + ";\n ivec4 inIdx = srcIdx;\n vec4 bestIndex = vec4(inIdx);\n vec4 bestValue = " + fetchValue + ";\n\n for (int i = 0; i < " + windowSize + "; i++) {\n inIdx = srcIdx;\n " + fetchCandidateIdx + "\n vec4 candidate = " + fetchValue + ";\n bvec4 nan = isnan(candidate);\n bvec4 replace = bvec4(\n vec4(" + compOp + "(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n bestValue = vec4(replace.x ? candidate.x : bestValue.x,\n replace.y ? candidate.y : bestValue.y,\n replace.z ? candidate.z : bestValue.z,\n replace.w ? candidate.w : bestValue.w);\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n srcIdx++;\n }\n setOutput(bestIndex);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
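// Helpers for ArgMin/ArgMax. argReduce and argReducePacked repeatedly reduce
// the innermost dimension, threading the best indices from the previous pass
// back in as `bestIndicesA`, until a single index per row remains.
// argMinMaxReduce chooses the packed or unpacked path based on
// WEBGL_PACK_REDUCE and the input rank, unpacking/reshaping as required.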
function argReduce(backend, x, reduceType, bestIndicesA) {
if (bestIndicesA === void 0) {
bestIndicesA = null;
}
var batchSize = x.shape[0];
var inSize = x.shape[1];
if (bestIndicesA != null) {
batchSize = bestIndicesA.shape[0];
inSize = bestIndicesA.shape[1];
}
var windowSize = computeOptimalWindowSize(inSize);
var reduceInfo = {
windowSize: windowSize,
inSize: inSize,
batchSize: batchSize,
outSize: Math.ceil(inSize / windowSize)
};
var program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);
var inputs = [x];
if (bestIndicesA != null) {
inputs.push(bestIndicesA);
}
var output = backend.runWebGLProgram(program, inputs, 'int32'); // No need to run another GPGPU program.
if (output.shape[1] === 1) {
return output;
}
var result = argReduce(backend, x, reduceType, output);
backend.disposeIntermediateTensorInfo(output);
return result;
}
function argReducePacked(backend, x, reduceType, bestIndicesA) {
if (bestIndicesA === void 0) {
bestIndicesA = null;
}
var inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;
var inSize = inShape[inShape.length - 1];
var windowSize = computeOptimalWindowSize(inSize);
var program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null);
var inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];
var output = backend.runWebGLProgram(program, inputs, 'int32');
if (output.shape.length === x.shape.length) {
var result = argReducePacked(backend, x, reduceType, output);
backend.disposeIntermediateTensorInfo(output);
return result;
}
return output;
}
function argMinMaxReduce(backend, x, axis, reduceType) {
var axes = [axis];
assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.shape.length);
if (!env().getBool('WEBGL_PACK_REDUCE') || x.shape.length <= 2) {
var intermediateTensorInfos = []; // Eagerly unpack x input since it is passed in to all the shaders which
// require unpacked inputs.
var xtexData = backend.texData.get(x.dataId);
var xIsPacked = xtexData !== null && xtexData.isPacked;
var xUnPacked = x;
if (xIsPacked) {
xUnPacked = backend.unpackTensor(x);
intermediateTensorInfos.push(xUnPacked);
}
var _backend_util$compute = computeOutAndReduceShapes(xUnPacked.shape, axes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var inSize = sizeFromShape(reduceShape);
var a2D = reshape$3({
inputs: {
x: xUnPacked
},
backend: backend,
attrs: {
shape: [-1, inSize]
}
});
intermediateTensorInfos.push(a2D);
var reduced = argReduce(backend, a2D, reduceType);
intermediateTensorInfos.push(reduced);
var reshaped = reshape$3({
inputs: {
x: reduced
},
backend: backend,
attrs: {
shape: outShape
}
});
intermediateTensorInfos.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return reshaped;
}
return argReducePacked(backend, x, reduceType);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
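// ArgMax kernel: transposes the requested axis to be innermost when necessary,
// then delegates to argMinMaxReduce with 'max'. The ArgMin kernel below is
// identical except that it reduces with 'min'.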
function argMax$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis;
var axes = parseAxisParam(axis, x.shape);
var permutedAxes = getAxesPermutation(axes, x.shape.length);
var $x = x;
var intermediateTensorInfos = [];
if (permutedAxes != null) {
$x = transpose$2({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
intermediateTensorInfos.push($x);
axes = getInnerMostAxes(axes.length, $x.shape.length);
}
assertAxesAreInnerMostDims('argMax', [axes[0]], $x.shape.length);
var out = argMinMaxReduce(backend, $x, axes[0], 'max');
intermediateTensorInfos.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return out;
}
var argMaxConfig$1 = {
kernelName: ArgMax,
backendName: 'webgl',
kernelFunc: argMax$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function argMin$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis;
var axes = parseAxisParam(axis, x.shape);
var permutedAxes = getAxesPermutation(axes, x.shape.length);
var $x = x;
var intermediateTensorInfos = [];
if (permutedAxes != null) {
$x = transpose$2({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
intermediateTensorInfos.push($x);
axes = getInnerMostAxes(axes.length, $x.shape.length);
}
assertAxesAreInnerMostDims('argMin', [axes[0]], $x.shape.length);
var out = argMinMaxReduce(backend, $x, axes[0], 'min');
intermediateTensorInfos.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return out;
}
var argMinConfig$1 = {
kernelName: ArgMin,
backendName: 'webgl',
kernelFunc: argMin$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ASIN = CHECK_NAN_SNIPPET + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return asin(x);\n";
var asin$2 = unaryKernelFunc$1({
opSnippet: ASIN
});
var asinConfig$1 = {
kernelName: Asin,
backendName: 'webgl',
kernelFunc: asin$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ASINH = CHECK_NAN_SNIPPET + "return log(x + sqrt(x * x + 1.0));";
var asinh$3 = unaryKernelFunc$1({
opSnippet: ASINH
});
var asinhConfig$1 = {
kernelName: Asinh,
backendName: 'webgl',
kernelFunc: asinh$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ATAN = CHECK_NAN_SNIPPET + "\n return atan(x);\n";
var atan$2 = unaryKernelFunc$1({
opSnippet: ATAN
});
var atanConfig$1 = {
kernelName: Atan,
backendName: 'webgl',
kernelFunc: atan$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
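/*
 * Atan2 is a binary kernel: binaryKernelFunc$1 broadcasts the two inputs and
 * evaluates the snippet per element. The packed variant operates on vec4
 * texels (four values at a time) and forces the result to NaN wherever either
 * operand is NaN. Roughly equivalent public API call: tf.atan2(a, b).
 */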
var ATAN2 = CHECK_NAN_SNIPPET_BINARY + "\n return atan(a, b);\n";
var ATAN2_PACKED = "\n vec4 result = atan(a, b);\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + CHECK_NAN_SNIPPET_BINARY_PACKED + "\n return result;\n";
var atan2$2 = binaryKernelFunc$1({
opSnippet: ATAN2,
packedOpSnippet: ATAN2_PACKED
});
var atan2Config$1 = {
kernelName: Atan2,
backendName: 'webgl',
kernelFunc: atan2$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
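// atanh(x) = 0.5 * (log(1 + x) - log(1 - x)); inputs outside [-1, 1] return NaN.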
var ATANH = CHECK_NAN_SNIPPET + "\n if ((x < -1.0) || (x > 1.0)) return NAN;\nreturn (log(1.0 + x) - log(1.0 - x)) / 2.0;";
var atanh$2 = unaryKernelFunc$1({
opSnippet: ATANH
});
var atanhConfig$1 = {
kernelName: Atanh,
backendName: 'webgl',
kernelFunc: atanh$2
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
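/*
 * Pool2DProgram builds the GLSL for 2D max/avg pooling over NHWC inputs.
 * When computePositions is true it instead emits the index of the winning
 * element (flattened or per-window), which is what maxPoolWithArgmax and the
 * max-pool gradient need. The main loop unrolls the filter width in groups of
 * four so each iteration reduces a vec4 of neighbouring columns.
 */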
var Pool2DProgram = function Pool2DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) {
if (flattenPositions === void 0) {
flattenPositions = false;
}
if (includeBatchInIndex === void 0) {
includeBatchInIndex = false;
}
this.variableNames = ['x'];
if (poolType === 'avg' && computePositions) {
throw new Error('Cannot compute positions for average pool.');
}
var filterWidth = convInfo.filterWidth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
this.outputShape = convInfo.outShape;
var isAvgPool = poolType === 'avg';
var batchFlattenPositionStr = "((batch * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d";
var flattenPositionStr = "(xR * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d";
var initializationValue = '0.0';
if (!isAvgPool) {
// WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
initializationValue = '-1.0 / 1e-20';
}
if (computePositions) {
var _compareOp = '>=';
this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n float avgValue = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float value = getX(batch, xR, xC, d);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value " + _compareOp + " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = " + (flattenPositions ? includeBatchInIndex ? batchFlattenPositionStr : flattenPositionStr : "wR * " + effectiveFilterWidth + " + wC") + ";\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n ";
return;
}
var compareOp = 'max';
var returnValue = poolType + "(" + poolType + "(" + poolType + "(" + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
if (poolType === 'avg') {
returnValue = "avgValue / count";
}
var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
var filterWidthVec4Remainder = filterWidth % 4;
var updateSnippet = "\n if (" + isAvgPool + ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n ";
this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xR, int xC, int d) {\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xR, xC, d);\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n getValue(batch, xR, xC + 3 * " + dilationWidth + ", d)\n );\n\n " + updateSnippet + "\n }\n\n int xC = xCCorner + " + filterWidthNearestVec4 + ";\n if (" + (filterWidthVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n }\n setOutput(" + returnValue + ");\n }\n ";
};
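// Pool3DProgram is the 3D analogue of Pool2DProgram: it pools over depth,
// height and width of 5D NDHWC inputs using the same vec4 unrolling scheme.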
var Pool3DProgram = function Pool3DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) {
if (flattenPositions === void 0) {
flattenPositions = false;
}
if (includeBatchInIndex === void 0) {
includeBatchInIndex = false;
}
this.variableNames = ['x'];
if (poolType === 'avg' && computePositions) {
throw new Error('Cannot compute positions for average pool.');
}
var filterWidth = convInfo.filterWidth;
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padFront = convInfo.padInfo.front;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
this.outputShape = convInfo.outShape;
var isAvgPool = poolType === 'avg';
var initializationValue = '0.0';
if (!isAvgPool) {
// WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
initializationValue = '-1.0 / 1e-20';
}
if (computePositions) {
var _compareOp2 = '>=';
this.userCode = "\n const ivec3 strides =\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float value = getX(batch, xD, xR, xC, ch);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value " + _compareOp2 + " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = " + (flattenPositions ? includeBatchInIndex ? "(((batch * " + convInfo.inDepth + " + xD) * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + ch" : "((xD * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + ch" : "wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +\n wR * " + effectiveFilterWidth + " + wC") + ";\n }\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n ";
return;
}
var compareOp = 'max';
var returnValue = poolType + "(" + poolType + "(" + poolType + "(" + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
if (poolType === 'avg') {
returnValue = "avgValue / count";
}
var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
var filterWidthVec4Remainder = filterWidth % 4;
var updateSnippet = "\n if (" + isAvgPool + ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n ";
this.userCode = "\n const ivec3 strides =\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xD, int xR, int xC, int ch) {\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xD, xR, xC, ch);\n }\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).\n // ? = to be determined\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 3 * " + dilationWidth + ", ch)\n );\n\n " + updateSnippet + "\n }\n\n int xC = xCCorner + " + filterWidthNearestVec4 + ";\n if (" + (filterWidthVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n }\n setOutput(" + returnValue + ");\n }\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
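/*
 * AvgPool kernel. Dilation is always 1 here, so a 1x1 filter whose output
 * shape equals the input shape degenerates to an identity copy; otherwise a
 * Pool2DProgram in 'avg' mode is run. Roughly equivalent public API call:
 *   tf.avgPool(x, filterSize, strides, pad);
 */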
function avgPool$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
assertNotComplex$1(x, 'avgPool');
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
var dilations = 1;
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in avgPool: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape)) {
return identity$2({
inputs: {
x: x
},
backend: backend
});
}
var avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false);
return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
}
var avgPoolConfig$1 = {
kernelName: AvgPool,
backendName: 'webgl',
kernelFunc: avgPool$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function avgPool3D$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode,
dataFormat = attrs.dataFormat;
var dilations = [1, 1, 1];
var convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
var avgPoolProgram = new Pool3DProgram(convInfo, 'avg', false);
return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
}
var avgPool3DConfig$1 = {
kernelName: AvgPool3D,
backendName: 'webgl',
kernelFunc: avgPool3D$1
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
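/*
 * Backprop programs for average pooling: each dy value is distributed evenly
 * over the input positions that fed it, i.e. scaled by
 * 1 / (filterHeight * filterWidth) (times filterDepth in the 3D case). The
 * fract() checks skip output positions that a given input coordinate never
 * contributed to when the stride is larger than 1.
 */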
var AvgPool2DBackpropProgram = function AvgPool2DBackpropProgram(convInfo) {
this.variableNames = ['dy'];
this.outputShape = convInfo.inShape;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var avgMultiplier = 1 / (filterHeight * filterWidth);
this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC+= " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n setOutput(dotProd);\n }\n ";
};
var AvgPool3DBackpropProgram = function AvgPool3DBackpropProgram(convInfo) {
this.variableNames = ['dy'];
this.outputShape = convInfo.inShape;
var filterDepth = convInfo.filterDepth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function avgPool3DGrad$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
input = inputs.input;
var x = input;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
var dilations = [1, 1, 1];
var convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
var avgPoolBackpropProgram = new AvgPool3DBackpropProgram(convInfo);
return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
}
var avgPoolGrad3DConfig = {
kernelName: AvgPool3DGrad,
backendName: 'webgl',
kernelFunc: avgPool3DGrad$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function avgPoolGrad$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
input = inputs.input;
var x = input;
assertNotComplex$1([dy, input], 'avgPoolGrad');
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad;
var convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
var avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo);
return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
}
var avgPoolGradConfig$2 = {
kernelName: AvgPoolGrad,
backendName: 'webgl',
kernelFunc: avgPoolGrad$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
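/*
 * BatchMatMul is a thin wrapper over batchMatMulImpl, which handles
 * broadcasting of the batch dimensions and dispatches the actual matmul
 * program. Roughly equivalent public API call:
 *   tf.matMul(a, b, transposeA, transposeB);
 */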
function batchMatMul$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var a = inputs.a,
b = inputs.b;
var transposeA = attrs.transposeA,
transposeB = attrs.transposeB;
return batchMatMulImpl({
a: a,
b: b,
transposeA: transposeA,
transposeB: transposeB,
backend: backend
});
}
var batchMatMulConfig$1 = {
kernelName: BatchMatMul,
backendName: 'webgl',
kernelFunc: batchMatMul$1
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
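/*
 * BatchNormProgram evaluates, per element,
 *   out = (x - mean) * scale / sqrt(variance + epsilon) + offset
 * with offset defaulting to 0 and scale to 1 when they are not supplied. The
 * dot(vec3(x, -mean, offset), vec3(inv, inv, 1)) form below is that same
 * expression with inv = scale * inversesqrt(variance + epsilon).
 */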
var BatchNormProgram = function BatchNormProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
this.outputShape = [];
this.variableNames = ['x', 'mean', 'variance'];
assertAndGetBroadcastShape(xShape, meanShape);
assertAndGetBroadcastShape(xShape, varianceShape);
var offsetSnippet = '0.0';
if (offsetShape != null) {
assertAndGetBroadcastShape(xShape, offsetShape);
this.variableNames.push('offset');
offsetSnippet = 'getOffsetAtOutCoords()';
}
var scaleSnippet = '1.0';
if (scaleShape != null) {
assertAndGetBroadcastShape(xShape, scaleShape);
this.variableNames.push('scale');
scaleSnippet = 'getScaleAtOutCoords()';
}
this.outputShape = xShape;
this.userCode = "\n void main() {\n float x = getXAtOutCoords();\n float mean = getMeanAtOutCoords();\n float variance = getVarianceAtOutCoords();\n float offset = " + offsetSnippet + ";\n float scale = " + scaleSnippet + ";\n float inv = scale * inversesqrt(variance + float(" + varianceEpsilon + "));\n setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var BatchNormPackedProgram = function BatchNormPackedProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
this.packedInputs = true;
this.packedOutput = true;
this.variableNames = ['x', 'mean', 'variance'];
assertAndGetBroadcastShape(xShape, meanShape);
assertAndGetBroadcastShape(xShape, varianceShape);
var offsetSnippet = 'vec4(0.0)';
if (offsetShape != null) {
assertAndGetBroadcastShape(xShape, offsetShape);
this.variableNames.push('offset');
offsetSnippet = 'getOffsetAtOutCoords()';
}
var scaleSnippet = 'vec4(1.0)';
if (scaleShape != null) {
assertAndGetBroadcastShape(xShape, scaleShape);
this.variableNames.push('scale');
scaleSnippet = 'getScaleAtOutCoords()';
}
this.outputShape = xShape;
this.userCode = "\n void main() {\n vec4 offset = " + offsetSnippet + ";\n vec4 scale = " + scaleSnippet + ";\n\n vec4 x = getXAtOutCoords();\n vec4 mean = getMeanAtOutCoords();\n vec4 variance = getVarianceAtOutCoords();\n\n vec4 inv = scale * inversesqrt(variance + vec4(" + varianceEpsilon + "));\n\n setOutput((x - mean) * inv + offset);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
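/*
 * FusedBatchNorm kernel: validates that mean, variance, offset and scale all
 * have the same rank, defaults varianceEpsilon to 0.001, and chooses the
 * packed program when the WEBGL_PACK_NORMALIZATION flag is enabled. Roughly
 * equivalent public API call:
 *   tf.batchNorm(x, mean, variance, offset, scale, varianceEpsilon);
 */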
var batchNorm$2 = function batchNorm(_ref) {
var inputs = _ref.inputs,
backend = _ref.backend,
attrs = _ref.attrs;
var x = inputs.x,
mean = inputs.mean,
variance = inputs.variance,
offset = inputs.offset,
scale = inputs.scale;
assert(mean.shape.length === variance.shape.length, function () {
return 'Batch normalization gradient requires mean and variance to have ' + 'equal ranks.';
});
assert(offset == null || mean.shape.length === offset.shape.length, function () {
return 'Batch normalization gradient requires mean and offset to have ' + 'equal ranks.';
});
assert(scale == null || mean.shape.length === scale.shape.length, function () {
return 'Batch normalization gradient requires mean and scale to have ' + 'equal ranks.';
});
var varianceEpsilon = attrs.varianceEpsilon;
if (varianceEpsilon == null) {
varianceEpsilon = 0.001;
}
var finalInputs = [x, mean, variance];
var offsetShape = null;
if (offset != null) {
offsetShape = offset.shape;
finalInputs.push(offset);
}
var scaleShape = null;
if (scale != null) {
scaleShape = scale.shape;
finalInputs.push(scale);
}
var program = env().getBool('WEBGL_PACK_NORMALIZATION') ? new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) : new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
var output = backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype);
return output;
};
var batchNormConfig$1 = {
kernelName: FusedBatchNorm,
backendName: 'webgl',
kernelFunc: batchNorm$2
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SliceProgram = function SliceProgram(destSize) {
this.variableNames = ['source'];
this.outputShape = destSize;
this.rank = destSize.length;
var dtype = getCoordsDataType(this.rank);
this.customUniforms = [{
name: 'start',
arrayIndex: this.rank,
type: 'int'
}];
var sourceCoords = getCoords(this.rank);
var body;
var coordSum = destSize.map(function (_, i) {
return "sourceLoc." + coords[i] + " = start[" + i + "] + coords." + coords[i] + ";";
});
body = "\n " + dtype + " sourceLoc;\n " + dtype + " coords = getOutputCoords();\n " + coordSum.join('\n') + "\n ";
this.userCode = "\n void main() {\n " + body + "\n setOutput(getSource(" + sourceCoords + "));\n }\n ";
};
var coords = ['x', 'y', 'z', 'w', 'u', 'v'];
function getCoords(rank) {
if (rank === 1) {
return 'sourceLoc';
} else if (rank <= 6) {
return coords.slice(0, rank).map(function (x) {
return 'sourceLoc.' + x;
}).join(',');
} else {
throw Error("Slicing for rank " + rank + " is not yet supported");
}
}
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SlicePackedProgram = function SlicePackedProgram(destSize) {
this.variableNames = ['source'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = destSize;
this.rank = destSize.length;
this.customUniforms = [{
name: 'start',
arrayIndex: this.rank,
type: 'int'
}];
var dtype = getCoordsDataType(this.rank);
var coords = getChannels('coords', this.rank);
var sourceLoc = getChannels('sourceLoc', this.rank);
var innerDims = this.rank === 1 ? 'sourceLoc' : "vec2(" + sourceLoc.slice(-2).join() + ")";
var getChannel = "getChannel(getSource(" + sourceLoc.join() + "), " + innerDims + ")";
var upperRow = "\n result.x = " + getChannel + ";\n if (++" + coords[this.rank - 1] + " < " + destSize[this.rank - 1] + ") {\n ++" + sourceLoc[this.rank - 1] + ";\n result.y = " + getChannel + ";\n --" + sourceLoc[this.rank - 1] + ";\n }\n ";
var lowerRow = this.rank === 1 ? '' : "\n --" + coords[this.rank - 1] + ";\n if (++" + coords[this.rank - 2] + " < " + destSize[this.rank - 2] + ") {\n ++" + sourceLoc[this.rank - 2] + ";\n result.z = " + getChannel + ";\n if (++" + coords[this.rank - 1] + " < " + destSize[this.rank - 1] + ") {\n ++" + sourceLoc[this.rank - 1] + ";\n result.w = " + getChannel + ";\n }\n }\n ";
var sourceLocSetup = this.rank <= 4 ? "sourceLoc = coords +\n " + dtype + "(" + destSize.map(function (_, i) {
return "start[" + i + "]";
}).join() + ");" : destSize.map(function (_, i) {
return sourceLoc[i] + " = " + coords[i] + " + start[" + i + "];";
}).join('\n');
this.userCode = "\n void main() {\n " + dtype + " coords = getOutputCoords();\n " + dtype + " sourceLoc;\n " + sourceLocSetup + "\n vec4 result = vec4(0.);\n " + upperRow + "\n " + lowerRow + "\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
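/*
 * Slice kernel. Three paths:
 *  - string dtype (or data the backend prefers to keep on the CPU): sliced on
 *    the CPU via sliceImplCPU;
 *  - contiguous slice of an unpacked tensor: shallowSlice() below shares the
 *    original texture and only records a flat offset plus a ref count, so no
 *    data is copied;
 *  - everything else: a (packed) SliceProgram copies the requested window.
 * Roughly equivalent public API call: tf.slice(x, begin, size).
 */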
function shallowSlice(x, begin, size, backend) {
var xTexData = backend.texData.get(x.dataId);
var t = backend.makeTensorInfo(size, x.dtype);
var newTexData = backend.texData.get(t.dataId); // Copy texture data from the original tensor.
Object.assign(newTexData, xTexData);
newTexData.refCount = 1;
newTexData.shape = size;
newTexData.dtype = x.dtype;
var flatOffset = computeFlatOffset(begin, computeStrides(x.shape));
if (xTexData.slice) {
// We are slicing an already sliced tensor, so we have to accumulate
// the offset.
flatOffset += xTexData.slice.flatOffset;
}
newTexData.slice = {
flatOffset: flatOffset,
// Point to the original dataId, which is used to do ref counting.
origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId
}; // Increase the ref count for that data bucket.
var refCount = backend.dataRefCount.get(newTexData.slice.origDataId) || 1;
backend.dataRefCount.set(newTexData.slice.origDataId, refCount + 1);
return t;
}
function slice$4(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var begin = attrs.begin,
size = attrs.size;
var _slice_util$parseSlic = parseSliceParams(x, begin, size),
$begin = _slice_util$parseSlic[0],
$size = _slice_util$parseSlic[1];
assertParamsValid(x, $begin, $size);
if (sizeFromShape($size) === 0) {
return backend.makeTensorInfo($size, x.dtype, []);
} // Run on cpu if dtype is string. For string, the backend represents it
// as Uint8Array[], where each Uint8Array is a character. Given that the
// computation is only on the outer array, uploading the whole data onto
// gpu is wasteful. Also, currently webgl doesn't have a design to
// upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
// just run the kernel on cpu if dtype is string.
if (backend.shouldExecuteOnCPU([x]) || x.dtype === 'string') {
var xTexData = backend.texData.get(x.dataId);
var outValues = sliceImplCPU(xTexData.values, $begin, $size, x.shape, x.dtype);
return backend.makeTensorInfo($size, x.dtype, outValues);
}
var _backend$texData$get = backend.texData.get(x.dataId),
isPacked = _backend$texData$get.isPacked;
var isContinous = isSliceContinous(x.shape, $begin, $size);
if (isPacked || !isContinous) {
var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new SlicePackedProgram($size) : new SliceProgram($size);
var customValues = [$begin];
return backend.runWebGLProgram(program, [x], x.dtype, customValues);
}
backend.uploadToGPU(x.dataId);
return shallowSlice(x, $begin, $size, backend);
}
var sliceConfig$1 = {
kernelName: Slice,
backendName: 'webgl',
kernelFunc: slice$4
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
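/*
 * BatchToSpaceND is expressed with existing kernels rather than its own
 * shader: reshape -> transpose -> reshape -> slice, following the standard
 * decomposition of the op (getReshaped / getPermuted / getSliceBeginCoords
 * compute the intermediate shapes). Only ranks up to 4 are supported here.
 * Roughly equivalent public API call: tf.batchToSpaceND(x, blockShape, crops).
 */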
var batchToSpaceND$2 = function batchToSpaceND(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var blockShape = attrs.blockShape,
crops = attrs.crops;
assert(x.shape.length <= 4, function () {
return 'batchToSpaceND for rank > 4 with a WebGL backend not ' + 'implemented yet';
});
var prod = blockShape.reduce(function (a, b) {
return a * b;
});
var reshaped = getReshaped(x.shape, blockShape, prod);
var permuted = getPermuted(reshaped.length, blockShape.length);
var reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
var sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
var sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
var toDispose = [];
var reshapedIntermediate = reshape$3({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: reshaped
}
});
var transposedIntermediate = transpose$2({
inputs: {
x: reshapedIntermediate
},
backend: backend,
attrs: {
perm: permuted
}
});
var reshapedIntermediate2 = reshape$3({
inputs: {
x: transposedIntermediate
},
backend: backend,
attrs: {
shape: reshapedPermuted
}
});
var sliced = slice$4({
inputs: {
x: reshapedIntermediate2
},
backend: backend,
attrs: {
begin: sliceBeginCoords,
size: sliceSize
}
});
toDispose.push(reshapedIntermediate);
toDispose.push(transposedIntermediate);
toDispose.push(reshapedIntermediate2);
toDispose.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return sliced;
};
var batchToSpaceNDConfig$1 = {
kernelName: BatchToSpaceND,
backendName: 'webgl',
kernelFunc: batchToSpaceND$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
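/*
 * Bincount has no WebGL shader: the kernel reads both inputs back to the CPU
 * with readSync and runs bincountImplCPU, so it is synchronous and its cost is
 * dominated by the GPU -> CPU download. Roughly equivalent public API call:
 *   tf.bincount(x, weights, size);
 */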
function bincount$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
weights = inputs.weights;
var size = attrs.size;
var xVals = backend.readSync(x.dataId);
var weightsVals = backend.readSync(weights.dataId);
var outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
return backend.makeTensorInfo([size], weights.dtype, outVals);
}
var bincountConfig$1 = {
kernelName: Bincount,
backendName: 'webgl',
kernelFunc: bincount$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var NOT_EQUAL$1 = "return float(a != b);";
var notEqual$2 = binaryKernelFunc$1({
opSnippet: NOT_EQUAL$1,
cpuKernelImpl: notEqualImplCPU,
dtype: 'bool'
});
var notEqualConfig$1 = {
kernelName: NotEqual,
backendName: 'webgl',
kernelFunc: notEqual$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function real$2(args) {
var inputs = args.inputs,
backend = args.backend;
var input = inputs.input;
var inputData = backend.texData.get(input.dataId);
return identity$2({
inputs: {
x: inputData.complexTensorInfos.real
},
backend: backend
});
}
var realConfig$1 = {
kernelName: Real,
backendName: 'webgl',
kernelFunc: real$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TO_INT = "return float(int(x));";
function int(input, backend) {
var program = new UnaryOpProgram(input.shape, TO_INT);
var output = backend.runWebGLProgram(program, [input], 'int32');
return {
dataId: output.dataId,
shape: output.shape,
dtype: output.dtype
};
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
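/*
 * Cast kernel. The strategy depends on the target dtype:
 *  - to complex64: pair the float-cast input with an all-zero imaginary part;
 *  - from complex64: cast only the real component;
 *  - lossless casts (no encoding loss): reuse the same data bucket and just
 *    relabel the dtype;
 *  - to int32: run the TO_INT shader above;
 *  - to bool: compare against zero with the NotEqual kernel.
 * Roughly equivalent public API call: tf.cast(x, dtype).
 */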
function cast$3(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var dtype = attrs.dtype; // Casting to complex64.
if (dtype === 'complex64') {
if (x.dtype === 'complex64') {
return identity$2({
inputs: {
x: x
},
backend: backend
});
} // TODO(annxingyuan): Import kernel function once zeros is modularized.
var zerosTensor = zeros(x.shape);
var floatX = cast$3({
inputs: {
x: x
},
backend: backend,
attrs: {
dtype: 'float32'
}
});
var result = complex$2({
inputs: {
real: floatX,
imag: zerosTensor
},
backend: backend
});
zerosTensor.dispose();
backend.disposeIntermediateTensorInfo(floatX);
return result;
} // Casting from complex64
if (x.dtype === 'complex64') {
var realPart = real$2({
inputs: {
input: x
},
backend: backend
});
var _result = cast$3({
inputs: {
x: realPart
},
backend: backend,
attrs: {
dtype: dtype
}
});
backend.disposeIntermediateTensorInfo(realPart);
return _result;
}
if (!hasEncodingLoss(x.dtype, dtype)) {
// We don't change the underlying data, since we cast to higher
// precision.
var _result2 = identity$2({
inputs: {
x: x
},
backend: backend
});
return {
dataId: _result2.dataId,
shape: _result2.shape,
dtype: dtype
};
}
if (dtype === 'int32') {
return int(x, backend);
}
if (dtype === 'bool') {
var zerosTensorInfo = backend.makeTensorInfo([], 'bool', getTypedArrayFromDType('bool', 1));
var binaryInputs = {
a: x,
b: zerosTensorInfo
};
var _result3 = notEqual$2({
inputs: binaryInputs,
backend: backend
});
backend.disposeIntermediateTensorInfo(zerosTensorInfo);
return _result3;
}
throw new Error("Error in Cast: failed to cast " + x.dtype + " to " + dtype);
}
var castConfig$1 = {
kernelName: Cast,
backendName: 'webgl',
kernelFunc: cast$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var CEIL = "return ceil(x);";
var ceil$5 = unaryKernelFunc$1({
opSnippet: CEIL,
packedOpSnippet: CEIL,
cpuKernelImpl: ceilImplCPU
});
var ceilConfig$1 = {
kernelName: Ceil,
backendName: 'webgl',
kernelFunc: ceil$5
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
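// ClipProgram clamps every element to [minVal, maxVal]; the bounds are passed
// as uniforms so the same compiled shader can be reused for different ranges,
// and NaN inputs are passed through unchanged.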
var ClipProgram = function ClipProgram(aShape) {
this.variableNames = ['A'];
this.customUniforms = [{
name: 'minVal',
type: 'float'
}, {
name: 'maxVal',
type: 'float'
}];
this.outputShape = aShape;
this.userCode = "\n\n void main() {\n float value = getAAtOutCoords();\n if (isnan(value)) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, minVal, maxVal));\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ClipPackedProgram = function ClipPackedProgram(aShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [{
name: 'minVal',
type: 'float'
}, {
name: 'maxVal',
type: 'float'
}];
this.outputShape = aShape;
this.userCode = "\n void main() {\n vec4 value = getAAtOutCoords();\n\n if (any(isnan(value))) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, vec4(minVal), vec4(maxVal)));\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function clipByValue$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var clipValueMin = attrs.clipValueMin,
clipValueMax = attrs.clipValueMax;
var program;
if (env().getBool('WEBGL_PACK_CLIP')) {
program = new ClipPackedProgram(x.shape);
} else {
program = new ClipProgram(x.shape);
}
var customValues = [[clipValueMin], [clipValueMax]];
return backend.runWebGLProgram(program, [x], x.dtype, customValues);
}
var clipByValueConfig = {
kernelName: ClipByValue,
backendName: 'webgl',
kernelFunc: clipByValue$1
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
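// ComplexAbsProgram computes sqrt(re*re + im*im) in an underflow-safe way by
// factoring out mx = max(|re|, |im|): mx * length(vec2(1, min(|re|, |im|) / mx)).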
var ComplexAbsProgram = function ComplexAbsProgram(shape) {
this.variableNames = ['real', 'imag'];
this.outputShape = shape;
this.userCode = "\n void main() {\n float re = abs(getRealAtOutCoords());\n float im = abs(getImagAtOutCoords());\n float mx = max(re, im);\n\n // sadly the length function in glsl is not underflow-safe\n // (at least not on Intel GPUs). So the safe solution is\n // to ensure underflow-safety in all cases.\n setOutput(\n mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))\n );\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Returns a TensorInfo with the complex tensor's shape and the dataId of its
// underlying part. We need to do this because a reshaped complex tensor is
// not reflected in its parts.
function makeComplexComponentTensorInfo(complexTensor, complexPart) {
return {
dataId: complexPart.dataId,
dtype: complexPart.dtype,
shape: complexTensor.shape
};
}
function complexAbs$1(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
var xData = backend.texData.get(x.dataId);
var program = new ComplexAbsProgram(x.shape);
var programInputs = [makeComplexComponentTensorInfo(x, xData.complexTensorInfos.real), makeComplexComponentTensorInfo(x, xData.complexTensorInfos.imag)];
return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype);
}
var complexAbsConfig$1 = {
kernelName: ComplexAbs,
backendName: 'webgl',
kernelFunc: complexAbs$1
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat().
var ConcatProgram = function ConcatProgram(shapes) {
this.outputShape = [];
this.outputShape = computeOutShape$1(shapes, 1
/* axis */
);
this.variableNames = shapes.map(function (_, i) {
return "T" + i;
});
var offsets = new Array(shapes.length - 1);
offsets[0] = shapes[0][1];
for (var i = 1; i < offsets.length; i++) {
offsets[i] = offsets[i - 1] + shapes[i][1];
}
var snippets = ["if (yC < " + offsets[0] + ") setOutput(getT0(yR, yC));"];
for (var _i = 1; _i < offsets.length; _i++) {
var shift = offsets[_i - 1];
snippets.push("else if (yC < " + offsets[_i] + ") " + ("setOutput(getT" + _i + "(yR, yC-" + shift + "));"));
}
var lastIndex = offsets.length;
var lastShift = offsets[offsets.length - 1];
snippets.push("else setOutput(getT" + lastIndex + "(yR, yC-" + lastShift + "));");
this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int yR = coords.x;\n int yC = coords.y;\n\n " + snippets.join('\n ') + "\n }\n ";
};
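// Example of the generated routing code (illustrative): for input shapes
// [[2, 3], [2, 4]] the offsets array is [3], so the snippets above become
//   if (yC < 3) setOutput(getT0(yR, yC));
//   else setOutput(getT1(yR, yC-3));
// i.e. every output column reads from whichever input tensor owns that range.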
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ConcatPackedProgram = function ConcatPackedProgram(shapes, axis) {
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = [];
this.outputShape = computeOutShape$1(shapes, axis);
var shape = this.outputShape;
var rank = shape.length;
var dtype = getCoordsDataType(rank);
var coords = getChannels('coords', rank);
var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
this.variableNames = shapes.map(function (_, i) {
return "T" + i;
});
var offsets = new Array(shapes.length - 1);
offsets[0] = shapes[0][axis];
for (var i = 1; i < offsets.length; i++) {
offsets[i] = offsets[i - 1] + shapes[i][axis];
}
var channel = channels[axis];
var lastChannels = channels.slice(-2);
var allChannels = channels.join();
var getValueSnippet = "if (" + channel + " < " + offsets[0] + ") {\n return getChannel(\n getT0(" + allChannels + "), vec2(" + lastChannels.join() + "));\n }";
for (var _i = 1; _i < offsets.length; _i++) {
var _shift = offsets[_i - 1]; // Note: the >= comparison below may seem unnecessary given the check
// above but is needed to workaround branch execution issues on some
// devices. It makes all the conditions exclusive without relying on
// execution order.
getValueSnippet += "\n if (" + channel + " < " + offsets[_i] + " && " + channel + " >= " + offsets[_i - 1] + ") {\n return getChannel(\n getT" + _i + "(" + shiftedChannels(channels, channel, _shift) + "),\n vec2(" + shiftedChannels(lastChannels, channel, _shift) + "));\n }";
}
var lastIndex = offsets.length;
var shift = offsets[offsets.length - 1];
getValueSnippet += "\n return getChannel(\n getT" + lastIndex + "(" + shiftedChannels(channels, channel, shift) + "),\n vec2(" + shiftedChannels(lastChannels, channel, shift) + "));";
this.userCode = "\n float getValue(" + channels.map(function (x) {
return 'int ' + x;
}) + ") {\n " + getValueSnippet + "\n }\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n vec4 result = vec4(getValue(" + coords + "), 0., 0., 0.);\n\n " + coords[rank - 1] + " = " + coords[rank - 1] + " + 1;\n if (" + coords[rank - 1] + " < " + shape[rank - 1] + ") {\n result.g = getValue(" + coords + ");\n }\n\n " + coords[rank - 2] + " = " + coords[rank - 2] + " + 1;\n if (" + coords[rank - 2] + " < " + shape[rank - 2] + ") {\n result.a = getValue(" + coords + ");\n }\n\n " + coords[rank - 1] + " = " + coords[rank - 1] + " - 1;\n if (" + coords[rank - 2] + " < " + shape[rank - 2] + " &&\n " + coords[rank - 1] + " < " + shape[rank - 1] + ") {\n result.b = getValue(" + coords + ");\n }\n setOutput(result);\n }\n ";
};
/**
* Return an expression for coordinates into a vector where a given channel
* will be offset by [shift].
*
* @param channels the channels to consider
* @param channel the channel we want shifted
* @param shift the amount to subtract from the channel.
*
* @returns a string of the form 'x, y-[shift], z' where any one channel can
* have the shift applied.
*/
function shiftedChannels(channels, channel, shift) {
var channelIdx = channels.indexOf(channel);
var res = channels.map(function (c, idx) {
if (idx === channelIdx) {
return c + " - " + shift;
} else {
return c;
}
});
return res.join();
}
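// Usage example (illustrative): shiftedChannels(['x', 'y', 'z'], 'y', 4)
// returns 'x,y - 4,z', which is spliced into the getT<i> calls above so that
// only the concatenation axis is shifted back into the source tensor's range.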
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function imag$2(args) {
var inputs = args.inputs,
backend = args.backend;
var input = inputs.input;
var inputData = backend.texData.get(input.dataId);
return identity$2({
inputs: {
x: inputData.complexTensorInfos.imag
},
backend: backend
});
}
var imagConfig$1 = {
kernelName: Imag,
backendName: 'webgl',
kernelFunc: imag$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
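// Dispatch order for concatImpl$1 below: complex64 inputs are split into
// real/imag parts and concatenated recursively; string inputs (and inputs the
// backend prefers to keep on CPU) use concatImplCPU on reshaped 2-D views;
// more inputs than 'WEBGL_MAX_TEXTURES_IN_SHADER' are split in half and
// concatenated pairwise; when 'WEBGL_PACK_ARRAY_OPERATIONS' is enabled and
// rank > 1, ConcatPackedProgram handles the inputs; everything else is
// flattened to 2-D and run through ConcatProgram.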
function concatImpl$1(inputs, axis, backend) {
var dtype = inputs[0].dtype;
if (dtype === 'complex64') {
var reals = inputs.map(function (t) {
return real$2({
inputs: {
input: t
},
backend: backend
});
});
var imags = inputs.map(function (t) {
return imag$2({
inputs: {
input: t
},
backend: backend
});
});
var realConcated = concatImpl$1(reals, axis, backend);
var imagConcated = concatImpl$1(imags, axis, backend);
var _result = complex$2({
inputs: {
real: realConcated,
imag: imagConcated
},
backend: backend
});
reals.forEach(function (r) {
return backend.disposeIntermediateTensorInfo(r);
});
imags.forEach(function (i) {
return backend.disposeIntermediateTensorInfo(i);
});
backend.disposeIntermediateTensorInfo(realConcated);
backend.disposeIntermediateTensorInfo(imagConcated);
return _result;
}
var runOnCpu = backend.shouldExecuteOnCPU(inputs); // Run on cpu if dtype is string. For string, the backend represents it
// as Uint8Array[], where each Uint8Array is a character. Given that the
// computation is only on the outer array, uploading the whole data onto
// gpu is wasteful. Also, currently webgl doesn't have a design to
// upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
// just run the kernel on cpu if dtype is string.
if (dtype === 'string') {
runOnCpu = true;
}
if (runOnCpu) {
// Any concat of n-dimensional tensors across any axis can be reduced to
// a concatenation of two-dimensional tensors across the axis 1 by first
// partitioning the axes of the original tensors into those less than the
// axis to be concatenated and the rest. Then reshape the tensors
// into a two-dimensional tensor by collapsing these two sets of axes and
// concatenate the resulting matrices across the axis 1, finally reshaping
// the result to have the proper shape.
var _tensors2D = inputs.map(function (t) {
var innerSize = sizeFromShape(t.shape.slice(axis));
var shape = [-1, innerSize];
return reshape$3({
inputs: {
x: t
},
backend: backend,
attrs: {
shape: shape
}
});
});
var inputsValShapes = _tensors2D.map(function (t) {
return {
vals: backend.readSync(t.dataId),
shape: t.shape
};
}); // Concats 2d tensors along axis=1.
var _outShape = computeOutShape$1(_tensors2D.map(function (t) {
return t.shape;
}), 1
/* axis */
);
var simplyConcat = _tensors2D[0].shape[0] === 1;
var outVals = concatImplCPU(inputsValShapes, _outShape, dtype, simplyConcat);
var finalOutShape = computeOutShape$1(inputs.map(function (t) {
return t.shape;
}), axis);
var outInfo = backend.makeTensorInfo(finalOutShape, dtype, outVals);
_tensors2D.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return outInfo;
}
if (inputs.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
var midIndex = Math.floor(inputs.length / 2);
var leftSide = concatImpl$1(inputs.slice(0, midIndex), axis, backend);
var rightSide = concatImpl$1(inputs.slice(midIndex), axis, backend);
var _result2 = concatImpl$1([leftSide, rightSide], axis, backend);
backend.disposeIntermediateTensorInfo(leftSide);
backend.disposeIntermediateTensorInfo(rightSide);
return _result2;
}
if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && inputs[0].shape.length > 1) {
var _program = new ConcatPackedProgram(inputs.map(function (t) {
return t.shape;
}), axis);
return backend.runWebGLProgram(_program, inputs, dtype);
}
var _computeTensors2D = computeTensors2D(inputs, axis, backend),
tensors2D = _computeTensors2D.tensors2D,
outShape = _computeTensors2D.outShape;
var program = new ConcatProgram(tensors2D.map(function (t) {
return t.shape;
}));
var result = backend.runWebGLProgram(program, tensors2D, dtype);
tensors2D.forEach(function (r) {
return backend.disposeIntermediateTensorInfo(r);
});
var reshapedResult = reshape$3({
inputs: {
x: result
},
attrs: {
shape: outShape
},
backend: backend
});
backend.disposeIntermediateTensorInfo(result);
return reshapedResult;
}
function computeTensors2D(inputs, axis, backend) {
// Any concat of n-dimensional tensors across any axis can be reduced to
// a concatenation of two-dimensional tensors across the axis 1 by first
// partitioning the axes of the original tensors into those less than the
// axis to be concatenated and the rest. Then reshape the tensors
// into a two-dimensional tensor by collapsing these two sets of axes and
// concatenate the resulting matrices across the axis 1, finally reshaping
// the result to have the proper shape.
var outShape = computeOutShape$1(inputs.map(function (t) {
return t.shape;
}), axis);
var tensors2D = inputs.map(function (x) {
return reshape$3({
inputs: {
x: x
},
attrs: {
shape: [-1, sizeFromShape(x.shape.slice(axis))]
},
backend: backend
});
});
return {
tensors2D: tensors2D,
outShape: outShape
};
}
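// Worked example (illustrative): concatenating shapes [2, 3, 4] and [2, 3, 5]
// along axis 2 reshapes the inputs to [6, 4] and [6, 5] (innerSize = 4 and 5),
// runs the 2-D concat along axis 1 to get [6, 9], and the caller then
// reshapes the result back to the final outShape [2, 3, 9].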
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function concat$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var axis = attrs.axis;
var $axis = parseAxisParam(axis, inputs[0].shape)[0];
var outShape = computeOutShape$1(inputs.map(function (t) {
return t.shape;
}), $axis);
if (sizeFromShape(outShape) === 0) {
return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
} // Keep only non-empty tensors (ignore tensors with 0 in their shape).
var $inputs = inputs.filter(function (t) {
return sizeFromShape(t.shape) > 0;
});
if ($inputs.length === 1) {
return identity$2({
inputs: {
x: $inputs[0]
},
backend: backend
});
}
var shapes = $inputs.map(function (t) {
return t.shape;
});
assertParamsConsistent(shapes, $axis);
return concatImpl$1($inputs, $axis, backend);
}
var concatConfig$1 = {
kernelName: Concat,
backendName: 'webgl',
kernelFunc: concat$2
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Conv2DProgram = function Conv2DProgram(convInfo, addBias, activation, hasPreluActivationWeights, hasLeakyreluAlpha) {
if (addBias === void 0) {
addBias = false;
}
if (activation === void 0) {
activation = null;
}
if (hasPreluActivationWeights === void 0) {
hasPreluActivationWeights = false;
}
if (hasLeakyreluAlpha === void 0) {
hasLeakyreluAlpha = false;
}
this.variableNames = ['x', 'W'];
this.outputShape = convInfo.outShape;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
var inputDepthVec4Remainder = convInfo.inChannels % 4;
var isChannelsLast = convInfo.dataFormat === 'channelsLast';
var rowDim = isChannelsLast ? 1 : 2;
var colDim = isChannelsLast ? 2 : 3;
var channelDim = isChannelsLast ? 3 : 1;
var activationSnippet = '',
applyActivationSnippet = '';
if (activation) {
if (hasPreluActivationWeights) {
activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
} else if (hasLeakyreluAlpha) {
activationSnippet = "float activation(float a) {\n float b = getLeakyreluAlphaAtOutCoords();\n " + activation + "\n }";
} else {
activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n ";
}
applyActivationSnippet = "result = activation(result);";
}
var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}
if (hasPreluActivationWeights) {
this.variableNames.push('preluActivationWeights');
}
if (hasLeakyreluAlpha) {
this.variableNames.push('leakyreluAlpha');
}
this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[" + channelDim + "];\n\n ivec2 xRCCorner =\n ivec2(coords[" + rowDim + "], coords[" + colDim + "]) * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec4 xValues = vec4(\n getX(batch, d1, xR, xC),\n getX(batch, d1 + 1, xR, xC),\n getX(batch, d1 + 2, xR, xC),\n getX(batch, d1 + 3, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n\n if (" + isChannelsLast + ") {\n dotProd +=\n getX(batch, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else {\n dotProd +=\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC) *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n }\n\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 wValues = vec2(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec2 xValues = vec2(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec2 xValues = vec2(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 wValues = vec3(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec3 xValues = vec3(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec3 xValues = vec3(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 2, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n }\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
};
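// Note on the channel tiling above (illustrative numbers): for
// convInfo.inChannels = 7, inputDepthNearestVec4 is 4 and
// inputDepthVec4Remainder is 3, so the inner d1 loop consumes channels 0-3
// with a single vec4 dot product and the remainder branch handles channels
// 4-6 with a vec3.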
var Conv3DProgram = function Conv3DProgram(convInfo) {
this.variableNames = ['x', 'W'];
this.outputShape = convInfo.outShape;
var padFront = convInfo.padInfo.front;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var filterDepth = convInfo.filterDepth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
var inputDepthVec4Remainder = convInfo.inChannels % 4;
this.userCode = "\n const ivec3 strides = ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d2 = coords.u;\n\n ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xFCorner = xFRCCorner.x;\n int xRCorner = xFRCCorner.y;\n int xCCorner = xFRCCorner.z;\n\n // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get\n // y(yF, yR, yC, d2). ? = to be determined. : = across all\n // values in that axis.\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n int xF = xFCorner + wF * " + dilationDepth + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 xValues = vec4(\n getX(batch, xF, xR, xC, d1),\n getX(batch, xF, xR, xC, d1 + 1),\n getX(batch, xF, xR, xC, d1 + 2),\n getX(batch, xF, xR, xC, d1 + 3)\n );\n vec4 wValues = vec4(\n getW(wF, wR, wC, d1, d2),\n getW(wF, wR, wC, d1 + 1, d2),\n getW(wF, wR, wC, d1 + 2, d2),\n getW(wF, wR, wC, d1 + 3, d2)\n );\n\n dotProd += dot(xValues, wValues);\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n dotProd +=\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 xValues = vec2(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n vec2 wValues = vec2(\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n dotProd += dot(xValues, wValues);\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 xValues = vec3(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n vec3 wValues = vec3(\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Im2ColPackedProgram = function Im2ColPackedProgram(outputShape, convInfo) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [{
name: 'inputShape',
type: 'ivec3'
}, {
name: 'pad',
type: 'ivec2'
}, {
name: 'stride',
type: 'ivec2'
}, {
name: 'dilation',
type: 'ivec2'
}, {
name: 'inChannels',
type: 'int'
}, {
name: 'itemsPerBlockRow',
type: 'int'
}, {
name: 'outWidth',
type: 'int'
}];
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
var dataFormat = convInfo.dataFormat;
var glsl = getGlslDifferences();
var isChannelsLast = dataFormat === 'channelsLast';
var rowDim = isChannelsLast ? 0 : 1;
var colDim = isChannelsLast ? 1 : 2;
var boundsCheckingSnippet = this.enableShapeUniforms ? 'if(blockIndex < outShape[1] && pos < outShape[0]) {' : "if(blockIndex < " + outputShape[1] + " && pos < " + outputShape[0] + ") {";
var unrolled = "";
for (var row = 0; row <= 1; row++) {
for (var col = 0; col <= 1; col++) {
unrolled += "\n blockIndex = rc.y + " + col + ";\n pos = rc.x + " + row + ";\n\n " + boundsCheckingSnippet + "\n offsetY = int(blockIndex / outWidth) * stride[0] - pad[0];\n d0 = offsetY + dilation[0] * (pos / itemsPerBlockRow);\n\n if(d0 < inputShape[" + rowDim + "] && d0 >= 0) {\n // Use custom imod instead mod. On Intel GPU, mod may generate\n // unexpected value.\n // https://github.com/tensorflow/tfjs/issues/5447\n offsetX = imod(blockIndex, outWidth) * stride[1] - pad[1];\n d1 = offsetX + dilation[1] * (imod(pos, itemsPerBlockRow) /\n inChannels);\n\n if(d1 < inputShape[" + colDim + "] && d1 >= 0) {\n\n ch = imod(pos, inChannels);\n\n if (" + isChannelsLast + ") {\n innerDims = vec2(d1, ch);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(d0, int(innerDims.x),\n int(innerDims.y)), innerDims);\n } else {\n innerDims = vec2(d0, d1);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(ch, int(innerDims.x),\n int(innerDims.y)), innerDims);\n }\n }\n }\n }\n ";
}
}
this.userCode = "\n void main() {\n ivec2 rc = getOutputCoords();\n\n vec4 result = vec4(0);\n\n int blockIndex, pos, offsetY, d0, offsetX, d1, ch;\n vec2 innerDims;\n\n " + unrolled + "\n\n " + glsl.output + " = result;\n }\n ";
};
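// Each invocation of the shader above fills one packed 2x2 texel: the row/col
// unrolling writes result[row * 2 + col] for output position
// (rc.x + row, rc.y + col), gathering the corresponding input pixel and
// channel, so the im2col matrix consumed by the matmul below is produced
// directly in packed layout.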
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// For 1x1 kernels with unit stride and dilation, the convolution
// can be expressed as matrix multiplication (without need for memory
// remapping).
function conv2dByMatMul(_ref) {
var x = _ref.x,
filter = _ref.filter,
convInfo = _ref.convInfo,
backend = _ref.backend,
_ref$bias = _ref.bias,
bias = _ref$bias === void 0 ? null : _ref$bias,
_ref$preluActivationW = _ref.preluActivationWeights,
preluActivationWeights = _ref$preluActivationW === void 0 ? null : _ref$preluActivationW,
_ref$leakyreluAlpha = _ref.leakyreluAlpha,
leakyreluAlpha = _ref$leakyreluAlpha === void 0 ? 0 : _ref$leakyreluAlpha,
_ref$activation = _ref.activation,
activation = _ref$activation === void 0 ? null : _ref$activation;
// Reshapes the conv2D input to 2D tensors, uses matMul and then reshapes the
// result from 2D back to 4D.
var xShape = x.shape;
var xTexData = backend.texData.get(x.dataId);
var sharedMatMulDim = convInfo.inChannels;
var outerShapeX = xShape[0] * xShape[1] * xShape[2];
var outerShapeFilter = convInfo.outChannels;
var isChannelsLast = convInfo.dataFormat === 'channelsLast';
var transposeA = false;
var transposeB = false;
var out;
var intermediates = []; // TODO: Once reduction ops are packed, batchMatMul will always be packed
// and we can remove this condition.
var batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) && sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD; // The algorithm in the if condition assumes (1) the output will be packed,
// (2) x is packed, (3) x isChannelsLast, (4) x's packed texture is already
// on GPU, (5) col is odd, (6) the width, height and inChannels are the same
// for xTexData.shape and xShape.
var canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked && isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 && arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3));
if (canOptimize) {
// We avoid the expensive packed 2x2 reshape by padding the col count to the
// next even number. When col is odd, the result of the packed batchMatMul is
// the same (has the same texture layout and values in the texture) as it is
// for the next even col. We make the odd-cols tensor look like an even-cols
// tensor before the operation and, after the batchMatMul, fix the even-cols
// result to have an odd number of cols.
var targetShape = xShape[0] * xShape[1] * (xShape[2] + 1);
var xReshaped = {
dataId: x.dataId,
shape: [1, targetShape, convInfo.inChannels],
dtype: x.dtype
}; // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.
// Decrementing col count, after batchMatMul->...->compileProgram leads to
// invalid col count within the reference in GPGPUBinary.inShapeInfos.
// Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos
// in compileProgram method, but that would affect compilation of all
// programs - instead, provide a copy here, with even col count, before
// calling batchMatMul->...->compileProgram and after that, the original
// xTexData.shape is restored.
var originalXTexDataShape = xTexData.shape;
xTexData.shape = xTexData.shape.slice();
xTexData.shape[xTexData.shape.length - 2]++;
assert(isReshapeFree(xTexData.shape, xReshaped.shape), function () {
return "packed reshape " + xTexData.shape + " to " + xReshaped.shape + " isn't free";
});
var filterReshaped = reshape$3({
inputs: {
x: filter
},
backend: backend,
attrs: {
shape: [1, convInfo.inChannels, convInfo.outChannels]
}
});
intermediates.push(filterReshaped);
var pointwiseConv = batchMatMulImpl({
a: xReshaped,
b: filterReshaped,
backend: backend,
transposeA: transposeA,
transposeB: transposeB,
bias: bias,
activation: activation,
preluActivationWeights: preluActivationWeights,
leakyreluAlpha: leakyreluAlpha
});
var pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);
assert(pointwiseConvTexData.isPacked, function () {
return 'batchMatMul result is expected to be packed';
}); // Restore the input shape to original.
xTexData.shape = originalXTexDataShape; // Set the output shape - there is no need for expensive reshape as data
// layout is already correct.
pointwiseConvTexData.shape = convInfo.outShape;
out = identity$2({
inputs: {
x: pointwiseConv
},
backend: backend
});
out.shape = convInfo.outShape;
intermediates.push(pointwiseConv);
} else {
var _targetShape = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] : xShape[0] * xShape[2] * xShape[3];
var _xReshaped = reshape$3({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: [1, _targetShape, convInfo.inChannels]
}
});
var _filterReshaped = reshape$3({
inputs: {
x: filter
},
backend: backend,
attrs: {
shape: [1, convInfo.inChannels, convInfo.outChannels]
}
});
var result = batchMatMulImpl({
a: _xReshaped,
b: _filterReshaped,
transposeA: transposeA,
transposeB: transposeB,
backend: backend,
bias: bias,
activation: activation,
preluActivationWeights: preluActivationWeights,
leakyreluAlpha: leakyreluAlpha
});
out = reshape$3({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: convInfo.outShape
}
});
intermediates.push(_xReshaped);
intermediates.push(_filterReshaped);
intermediates.push(result);
}
for (var _i = 0, _intermediates = intermediates; _i < _intermediates.length; _i++) {
var i = _intermediates[_i];
backend.disposeIntermediateTensorInfo(i);
}
return out;
} // Implements the im2row algorithm as outlined in "High Performance
// Convolutional Neural Networks for Document Processing" (Suvisoft, 2006)
function conv2dWithIm2Row(_ref2) {
var x = _ref2.x,
filter = _ref2.filter,
convInfo = _ref2.convInfo,
backend = _ref2.backend,
_ref2$bias = _ref2.bias,
bias = _ref2$bias === void 0 ? null : _ref2$bias,
_ref2$preluActivation = _ref2.preluActivationWeights,
preluActivationWeights = _ref2$preluActivation === void 0 ? null : _ref2$preluActivation,
_ref2$leakyreluAlpha = _ref2.leakyreluAlpha,
leakyreluAlpha = _ref2$leakyreluAlpha === void 0 ? 0 : _ref2$leakyreluAlpha,
_ref2$activation = _ref2.activation,
activation = _ref2$activation === void 0 ? null : _ref2$activation;
// Rearranges conv2d input so each block to be convolved over forms the
// column of a new matrix with shape [filterWidth * filterHeight *
// inChannels, outHeight * outWidth]. The filter is also rearranged so each
// output channel forms a row of a new matrix with shape [outChannels,
// filterWidth * filterHeight * inChannels]. The convolution is then
// computed by multiplying these matrices and reshaping the result.
var filterWidth = convInfo.filterWidth,
filterHeight = convInfo.filterHeight,
inChannels = convInfo.inChannels,
outWidth = convInfo.outWidth,
outHeight = convInfo.outHeight,
dataFormat = convInfo.dataFormat;
var isChannelsLast = dataFormat === 'channelsLast';
var sharedDim = filterWidth * filterHeight * inChannels;
var numCols = outHeight * outWidth;
var x2ColShape = [sharedDim, numCols];
var transposeA = true;
var transposeB = false;
var intermediates = [];
var xSqueezed = reshape$3({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: x.shape.slice(1)
}
});
var w2Row = reshape$3({
inputs: {
x: filter
},
backend: backend,
attrs: {
shape: [1, sharedDim, sizeFromShape(filter.shape) / sharedDim]
}
});
intermediates.push(xSqueezed);
intermediates.push(w2Row);
var im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo);
var customValues = [xSqueezed.shape, [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inChannels], [convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth]];
var im2Col = backend.runWebGLProgram(im2ColProgram, [xSqueezed], 'float32', customValues);
var im2ColReshaped = reshape$3({
inputs: {
x: im2Col
},
backend: backend,
attrs: {
shape: [1, x2ColShape[0], x2ColShape[1]]
}
});
intermediates.push(im2Col);
intermediates.push(im2ColReshaped);
var hasBias = bias != null;
var hasPreluActivationWeights = preluActivationWeights != null;
var hasLeakyreluAlpha = activation === 'leakyrelu';
var fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
var matmulProgram = new MatMulPackedProgram(im2ColReshaped.shape, w2Row.shape, [1, numCols, convInfo.outChannels], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
var inputs = [im2ColReshaped, w2Row];
if (bias) {
inputs.push(bias);
}
if (hasPreluActivationWeights) {
inputs.push(preluActivationWeights);
}
if (hasLeakyreluAlpha) {
var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
inputs.push($leakyreluAlpha);
intermediates.push($leakyreluAlpha);
}
var product = backend.runWebGLProgram(matmulProgram, inputs, 'float32');
var outShape = isChannelsLast ? [1, outHeight, outWidth, convInfo.outChannels] : [1, convInfo.outChannels, outHeight, outWidth];
var out = reshape$3({
inputs: {
x: product
},
backend: backend,
attrs: {
shape: outShape
}
});
intermediates.push(product);
for (var _i2 = 0, _intermediates2 = intermediates; _i2 < _intermediates2.length; _i2++) {
var i = _intermediates2[_i2];
backend.disposeIntermediateTensorInfo(i);
}
return out;
}
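// Worked dimensions (illustrative): for a [1, 30, 30, 16] input, a
// [3, 3, 16, 32] filter and 'valid' padding, sharedDim = 3 * 3 * 16 = 144 and
// numCols = 28 * 28 = 784, so im2Col has shape [144, 784], the packed matmul
// (with transposeA) produces [1, 784, 32], and the final reshape yields the
// [1, 28, 28, 32] output.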
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
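// Kernel dispatch (summary): 1x1 filters with unit stride and dilation (and
// SAME/VALID padding) go through conv2dByMatMul, batch-1 inputs go through
// conv2dWithIm2Row when the 'WEBGL_CONV_IM2COL' flag is set, and everything
// else falls back to the direct Conv2DProgram shader.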
function conv2d$4(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter;
var strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dilations = attrs.dilations,
dimRoundingMode = attrs.dimRoundingMode;
var $dataFormat = convertConv2DDataFormat(dataFormat);
var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false
/* depthwise */
, $dataFormat);
var out;
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
out = conv2dByMatMul({
x: x,
filter: filter,
convInfo: convInfo,
backend: backend
});
} else if (env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
out = conv2dWithIm2Row({
x: x,
filter: filter,
convInfo: convInfo,
backend: backend
});
} else {
var program = new Conv2DProgram(convInfo);
out = backend.runWebGLProgram(program, [x, filter], 'float32');
}
var outReshaped = reshape$3({
inputs: {
x: out
},
backend: backend,
attrs: {
shape: convInfo.outShape
}
});
backend.disposeIntermediateTensorInfo(out);
return outReshaped;
}
var conv2DConfig$1 = {
kernelName: Conv2D,
backendName: 'webgl',
kernelFunc: conv2d$4
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Conv2DDerFilterProgram = function Conv2DDerFilterProgram(convInfo) {
this.variableNames = ['x', 'dy'];
this.outputShape = convInfo.filterShape;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var isChannelsLast = convInfo.dataFormat === 'channelsLast';
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int d2 = coords.w;\n\n // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n if (" + isChannelsLast + ") {\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n } else {\n float dyValue = getDy(b, d2, yR, yC);\n float xValue = getX(b, d1, xR, xC);\n dotProd += (xValue * dyValue);\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n ";
};
var Conv2DDerInputProgram = function Conv2DDerInputProgram(convInfo) {
this.variableNames = ['dy', 'W'];
this.outputShape = convInfo.inShape;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var isChannelsLast = convInfo.dataFormat === 'channelsLast';
var padTop = filterHeight - 1 - convInfo.padInfo.top;
var padLeft = filterWidth - 1 - convInfo.padInfo.left;
var rowDim = isChannelsLast ? 1 : 2;
var colDim = isChannelsLast ? 2 : 3;
var channelDim = isChannelsLast ? 3 : 1;
this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[" + channelDim + "];\n\n ivec2 dyCorner = ivec2(coords[" + rowDim + "], coords[" + colDim + "]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n\n if (" + isChannelsLast + ") {\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n } else {\n float xValue = getDy(batch, d2, idyR, idyC);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n ";
};
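// In Conv2DDerInputProgram above, dy is combined with a spatially flipped
// view of the filter (wRPerm = filterHeight - 1 - wR), with pads adjusted to
// filterSize - 1 - pad; the fract(dyR) > 0.0 checks skip positions that fall
// between strided outputs.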
var Conv3DDerFilterProgram = function Conv3DDerFilterProgram(convInfo) {
this.variableNames = ['x', 'dy'];
this.outputShape = convInfo.filterShape;
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var padFront = convInfo.padInfo.front;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
this.userCode = "\n void main() {\n ivec5 coords = getOutputCoords();\n int wF = coords.x;\n int wR = coords.y;\n int wC = coords.z;\n int d1 = coords.w;\n int d2 = coords.u;\n\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yF = 0; yF < " + convInfo.outDepth + "; yF++) {\n int xF = wF + yF * " + strideDepth + " - " + padFront + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yF, yR, yC, d2);\n float xValue = getX(b, xF, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ";
};
var Conv3DDerInputProgram = function Conv3DDerInputProgram(convInfo) {
this.variableNames = ['dy', 'W'];
this.outputShape = convInfo.inShape;
var filterDepth = convInfo.filterDepth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var padFront = filterDepth - 1 - convInfo.padInfo.front;
var padTop = filterHeight - 1 - convInfo.padInfo.top;
var padLeft = filterWidth - 1 - convInfo.padInfo.left;
this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.u;\n\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyFCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n float dyF = float(dyFCorner + wF) / " + strideDepth + ".0;\n\n if (dyF < 0.0 || dyF >= " + convInfo.outDepth + ".0 || fract(dyF) > 0.0) {\n continue;\n }\n int idyF = int(dyF);\n\n int wFPerm = " + filterDepth + " - 1 - wF;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n float xValue = getDy(batch, idyF, idyR, idyC, d2);\n float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n }\n }\n }\n setOutput(dotProd);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv2DBackpropFilter$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
dy = inputs.dy;
var strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dimRoundingMode = attrs.dimRoundingMode,
filterShape = attrs.filterShape;
var $dataFormat = convertConv2DDataFormat(dataFormat);
var convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1
/* dilations */
, pad, dimRoundingMode, false
/* depthwise */
, $dataFormat);
var program = new Conv2DDerFilterProgram(convInfo);
return backend.runWebGLProgram(program, [x, dy], 'float32');
}
var conv2DBackpropFilterConfig$1 = {
kernelName: Conv2DBackpropFilter,
backendName: 'webgl',
kernelFunc: conv2DBackpropFilter$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv2DBackpropInput$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
filter = inputs.filter;
var inputShape = attrs.inputShape,
strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dimRoundingMode = attrs.dimRoundingMode;
var $dataFormat = convertConv2DDataFormat(dataFormat);
var convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1
/* dilations */
, pad, dimRoundingMode, false, $dataFormat);
var program = new Conv2DDerInputProgram(convInfo);
return backend.runWebGLProgram(program, [dy, filter], 'float32');
}
var conv2DBackpropInputConfig$1 = {
kernelName: Conv2DBackpropInput,
backendName: 'webgl',
kernelFunc: conv2DBackpropInput$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv3D$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter;
var strides = attrs.strides,
pad = attrs.pad,
dilations = attrs.dilations;
var convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
var program = new Conv3DProgram(convInfo);
return backend.runWebGLProgram(program, [x, filter], 'float32');
}
var conv3DConfig$1 = {
kernelName: Conv3D,
backendName: 'webgl',
kernelFunc: conv3D$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv3DBackpropFilterV2$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
dy = inputs.dy;
var strides = attrs.strides,
pad = attrs.pad,
filterShape = attrs.filterShape;
var convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1
/* dilations */
, pad);
var program = new Conv3DDerFilterProgram(convInfo);
return backend.runWebGLProgram(program, [x, dy], 'float32');
}
var conv3DBackpropFilterV2Config$1 = {
kernelName: Conv3DBackpropFilterV2,
backendName: 'webgl',
kernelFunc: conv3DBackpropFilterV2$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function conv3DBackpropInput$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
filter = inputs.filter;
var pad = attrs.pad,
strides = attrs.strides,
inputShape = attrs.inputShape;
var convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1
/* dilations */
, pad);
var program = new Conv3DDerInputProgram(convInfo);
return backend.runWebGLProgram(program, [dy, filter], 'float32');
}
var conv3DBackpropInputConfig = {
kernelName: Conv3DBackpropInputV2,
backendName: 'webgl',
kernelFunc: conv3DBackpropInput$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var COS = CHECK_NAN_SNIPPET_UNARY + "\n return cos(x);\n";
var cos$2 = unaryKernelFunc$1({
opSnippet: COS
});
var cosConfig$1 = {
kernelName: Cos,
backendName: 'webgl',
kernelFunc: cos$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
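// The COSH snippet below evaluates cosh(x) via a single exponential: with
// e2x holding exp(-x), (exp(-x) + 1 / exp(-x)) / 2 === (exp(-x) + exp(x)) / 2,
// which is the definition of cosh(x).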
var COSH = "\n float e2x = exp(-x);\n return (e2x + 1.0 / e2x) / 2.0;\n";
var cosh$2 = unaryKernelFunc$1({
opSnippet: COSH
});
var coshConfig$1 = {
kernelName: Cosh,
backendName: 'webgl',
kernelFunc: cosh$2
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var CropAndResizeProgram = function CropAndResizeProgram(imageShape, boxShape, cropSize, method, extrapolationValue) {
this.variableNames = ['Image', 'Boxes', 'BoxInd'];
this.outputShape = [];
var batch = imageShape[0],
imageHeight = imageShape[1],
imageWidth = imageShape[2],
depth = imageShape[3];
var numBoxes = boxShape[0];
var cropHeight = cropSize[0],
cropWidth = cropSize[1];
this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
var methodId = method === 'bilinear' ? 1 : 0;
var inputHeightFloat = imageHeight - 1 + ".0",
inputWidthFloat = imageWidth - 1 + ".0";
var _ref = cropHeight > 1 ? ["" + (imageHeight - 1) / (cropHeight - 1), '(y2-y1) * height_ratio', "y1*" + inputHeightFloat + " + float(y)*(height_scale)"] : ['0.0', '0.0', "0.5 * (y1+y2) * " + inputHeightFloat],
heightRatio = _ref[0],
heightScale = _ref[1],
inY = _ref[2];
var _ref2 = cropWidth > 1 ? ["" + (imageWidth - 1) / (cropWidth - 1), '(x2-x1) * width_ratio', "x1*" + inputWidthFloat + " + float(x)*(width_scale)"] : ['0.0', '0.0', "0.5 * (x1+x2) * " + inputWidthFloat],
widthRatio = _ref2[0],
widthScale = _ref2[1],
inX = _ref2[2]; // Reference implementation
// tslint:disable-next-line:max-line-length
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
this.userCode = "\n const float height_ratio = float(" + heightRatio + ");\n const float width_ratio = float(" + widthRatio + ");\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int y = coords[1];\n int x = coords[2];\n int d = coords[3];\n\n // get box vals\n float y1 = getBoxes(b,0);\n float x1 = getBoxes(b,1);\n float y2 = getBoxes(b,2);\n float x2 = getBoxes(b,3);\n\n // get image in batch index\n int bInd = round(getBoxInd(b));\n if(bInd < 0 || bInd >= " + batch + ") {\n return;\n }\n\n float height_scale = " + heightScale + ";\n float width_scale = " + widthScale + ";\n\n float in_y = " + inY + ";\n if( in_y < 0.0 || in_y > " + inputHeightFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n float in_x = " + inX + ";\n if( in_x < 0.0 || in_x > " + inputWidthFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n\n vec2 sourceFracIndexCR = vec2(in_x,in_y);\n if(" + methodId + " == 1) {\n // Compute the four integer indices.\n ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);\n ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));\n\n float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);\n float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);\n float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);\n float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);\n\n vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);\n\n float top = topLeft + (topRight - topLeft) * fracCR.x;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;\n float newValue = top + (bottom - top) * fracCR.y;\n setOutput(newValue);\n } else {\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestCR = ivec2(floor(\n sourceFracIndexCR + vec2(0.5,0.5)));\n float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);\n setOutput(newValue);\n }\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var cropAndResize$2 = function cropAndResize(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var image = inputs.image,
boxes = inputs.boxes,
boxInd = inputs.boxInd;
var cropSize = attrs.cropSize,
method = attrs.method,
extrapolationValue = attrs.extrapolationValue;
var program = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue);
return backend.runWebGLProgram(program, [image, boxes, boxInd], 'float32');
};
var cropAndResizeConfig$1 = {
kernelName: CropAndResize,
backendName: 'webgl',
kernelFunc: cropAndResize$2
};
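// Usage sketch for the CropAndResize kernel registered above: boxes are normalized
// [y1, x1, y2, x2] coordinates and boxInd selects the source image in the batch. This assumes the
// global `tf` namespace this bundle exports; kept as comments so nothing executes at load time.
//   const image = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
//   const boxes = tf.tensor2d([[0, 0, 1, 1]]);        // the whole image
//   const boxInd = tf.tensor1d([0], 'int32');
//   tf.image.cropAndResize(image, boxes, boxInd, [3, 3], 'bilinear').print(); // 3x3 bilinear resample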
var CumSumProgram = function CumSumProgram(shape, exclusive, reverse) {
this.variableNames = ['x'];
this.customUniforms = [{
name: 'index',
type: 'float'
}];
this.outputShape = shape;
var rank = shape.length;
var val = exclusive ? '0.0' : "getX(" + getCoords$1(rank, 'coords') + ")";
var length = shape[shape.length - 1];
var condition = '';
var idxString = ''; // When exclusive is set, the cumsum op becomes a roll op that copies the
// value from the previous index based on the direction specified by the
// reverse flag.
if (exclusive) {
condition = reverse ? "end != " + (length - 1) : 'end != 0';
idxString = reverse ? 'end + 1' : 'end - 1';
} else {
condition = reverse ? "end + pow2 < " + length : 'end >= pow2';
idxString = reverse ? 'end + pow2' : 'end - pow2';
}
this.userCode = "\n void main() {\n " + getCoordsDataType(rank) + " coords = getOutputCoords();\n int end = " + getFinalCoord(rank, 'coords') + ";\n float val = " + val + ";\n int pow2 = int(pow(2.0, index));\n if (" + condition + ") {\n int idx = " + idxString + ";\n " + getFinalCoord(rank, 'coords') + " = idx;\n val += getX(" + getCoords$1(rank, 'coords') + ");\n }\n setOutput(val);\n }\n ";
};
function getCoords$1(rank, name) {
if (rank === 1) {
return "" + name;
} else if (rank === 2) {
return name + ".x, " + name + ".y";
} else if (rank === 3) {
return name + ".x, " + name + ".y, " + name + ".z";
} else if (rank === 4) {
return name + ".x, " + name + ".y, " + name + ".z, " + name + ".w";
} else {
throw Error("Cumulative sum for rank " + rank + " is not yet supported");
}
}
function getFinalCoord(rank, name) {
if (rank === 1) {
return "" + name;
} else if (rank === 2) {
return name + ".y";
} else if (rank === 3) {
return name + ".z";
} else if (rank === 4) {
return name + ".w";
} else {
throw Error("Cumulative sum for rank " + rank + " is not yet supported");
}
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function cumsum$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
exclusive = attrs.exclusive,
reverse = attrs.reverse;
var xRank = x.shape.length;
var permutation = getAxesPermutation([axis], xRank);
var permutedX = x;
if (permutation != null) {
permutedX = transpose$2({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutation
}
});
}
var permutedAxis = getInnerMostAxes(1, xRank)[0];
if (permutedAxis !== xRank - 1) {
throw new Error("WebGL cumsum shader expects an inner-most axis=" + (x.shape.length - 1) + " " + ("but got axis=" + axis));
}
var size = permutedX.shape[permutedAxis];
var result = identity$2({
inputs: {
x: permutedX
},
backend: backend
}); // Use cumsum parallel algorithm, ref:
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
for (var i = 0; i <= Math.ceil(Math.log2(size)) - 1; i++) {
var program = new CumSumProgram(permutedX.shape, false, reverse);
var customValues = [[i]];
var prevResult = result;
result = backend.runWebGLProgram(program, [result], result.dtype, customValues);
backend.disposeIntermediateTensorInfo(prevResult);
} // For exclusive cumsum, shift the final result one step in the direction of the sum
// and place 0 at the front index.
if (exclusive) {
var _program = new CumSumProgram(permutedX.shape, exclusive, reverse);
var _prevResult = result;
result = backend.runWebGLProgram(_program, [result], result.dtype);
backend.disposeIntermediateTensorInfo(_prevResult);
}
if (permutation != null) {
var reversePermutation = getUndoAxesPermutation(permutation);
var reverseTransposedResult = transpose$2({
inputs: {
x: result
},
backend: backend,
attrs: {
perm: reversePermutation
}
});
backend.disposeIntermediateTensorInfo(result);
backend.disposeIntermediateTensorInfo(permutedX);
return reverseTransposedResult;
}
return result;
}
var cumsumConfig$1 = {
kernelName: Cumsum,
backendName: 'webgl',
kernelFunc: cumsum$2
};
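// Usage sketch for the Cumsum kernel registered above; the shader runs ceil(log2(n)) passes of the
// parallel prefix sum over the inner-most axis. Assumes the global `tf` namespace this bundle exports.
//   const x = tf.tensor1d([1, 2, 3, 4]);
//   tf.cumsum(x).print();                  // [1, 3, 6, 10]
//   tf.cumsum(x, 0, true).print();         // exclusive: [0, 1, 3, 6]
//   tf.cumsum(x, 0, false, true).print();  // reverse:   [10, 9, 7, 4]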
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function denseBincount$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
weights = inputs.weights;
var size = attrs.size,
binaryOutput = attrs.binaryOutput;
if (x.shape.length === 1) {
var xVals = backend.readSync(x.dataId);
var weightsVals = backend.readSync(weights.dataId);
var outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
return backend.makeTensorInfo([size], weights.dtype, outVals);
} else if (x.shape.length === 2) {
var xBuf = backend.bufferSync(x);
var weightsBuf = backend.bufferSync(weights);
var outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput);
return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
}
throw new Error("Error in denseBincount: input must be at most rank 2, but got rank" + (x.shape.length + "."));
}
var denseBincountConfig$1 = {
kernelName: DenseBincount,
backendName: 'webgl',
kernelFunc: denseBincount$2
};
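// Usage sketch for the DenseBincount kernel registered above; as the code shows, rank-1 and rank-2
// inputs are handled by the CPU bincount implementations. Assumes the global `tf` namespace this
// bundle exports.
//   const x = tf.tensor1d([1, 1, 2], 'int32');
//   const weights = tf.tensor1d([]);          // empty weights => plain occurrence counts
//   tf.denseBincount(x, weights, 3).print();  // [0, 2, 1]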
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DepthToSpaceProgram = /*#__PURE__*/function () {
function DepthToSpaceProgram(outputShape, blockSize, dataFormat) {
this.variableNames = ['x'];
this.outputShape = [];
this.outputShape = outputShape;
this.blockSize = blockSize;
this.dataFormat = dataFormat;
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int h = " + this.getHeightCoordString() + ";\n int w = " + this.getWidthCoordString() + ";\n int d = " + this.getDepthCoordString() + ";\n\n int in_h = h / " + blockSize + ";\n int offset_h = imod(h, " + blockSize + ");\n int in_w = w / " + blockSize + ";\n int offset_w = imod(w, " + blockSize + ");\n int offset_d = (offset_h * " + blockSize + " + offset_w) *\n " + this.getOutputDepthSize() + ";\n int in_d = d + offset_d;\n\n float result = " + this.getInputSamplingString() + ";\n setOutput(result);\n }\n ";
}
var _proto = DepthToSpaceProgram.prototype;
_proto.getHeightCoordString = function getHeightCoordString() {
if (this.dataFormat === 'NHWC') {
return "coords[1]";
} else {
return "coords[2]";
}
};
_proto.getWidthCoordString = function getWidthCoordString() {
if (this.dataFormat === 'NHWC') {
return "coords[2]";
} else {
return "coords[3]";
}
};
_proto.getDepthCoordString = function getDepthCoordString() {
if (this.dataFormat === 'NHWC') {
return "coords[3]";
} else {
return "coords[1]";
}
};
_proto.getOutputDepthSize = function getOutputDepthSize() {
if (this.dataFormat === 'NHWC') {
return this.outputShape[3];
} else {
return this.outputShape[1];
}
};
_proto.getInputSamplingString = function getInputSamplingString() {
if (this.dataFormat === 'NHWC') {
return "getX(b, in_h, in_w, in_d)";
} else {
return "getX(b, in_d, in_h, in_w)";
}
};
return DepthToSpaceProgram;
}();
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function depthToSpace$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var blockSize = attrs.blockSize,
dataFormat = attrs.dataFormat;
assert(blockSize > 1, function () {
return "blockSize should be > 1 for depthToSpace, but was: " + blockSize;
});
var batchSize = x.shape[0];
var inputHeight = dataFormat === 'NHWC' ? x.shape[1] : x.shape[2];
var inputWidth = dataFormat === 'NHWC' ? x.shape[2] : x.shape[3];
var inputDepth = dataFormat === 'NHWC' ? x.shape[3] : x.shape[1];
var outputHeight = inputHeight * blockSize;
var outputWidth = inputWidth * blockSize;
var outputDepth = inputDepth / (blockSize * blockSize);
var outputShape = dataFormat === 'NHWC' ? [batchSize, outputHeight, outputWidth, outputDepth] : [batchSize, outputDepth, outputHeight, outputWidth];
var program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat);
return backend.runWebGLProgram(program, [x], x.dtype);
}
var depthToSpaceConfig$1 = {
kernelName: DepthToSpace,
backendName: 'webgl',
kernelFunc: depthToSpace$2
};
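// Usage sketch for the DepthToSpace kernel registered above: depth is rearranged into
// blockSize x blockSize spatial blocks. Assumes the global `tf` namespace this bundle exports.
//   const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]);  // NHWC
//   tf.depthToSpace(x, 2).print();                      // shape [1, 2, 2, 1]: [[1, 2], [3, 4]]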
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DepthwiseConv2DProgram = function DepthwiseConv2DProgram(convInfo, addBias, activation, hasPreluActivation, hasLeakyReluAlpha) {
if (addBias === void 0) {
addBias = false;
}
if (activation === void 0) {
activation = null;
}
if (hasPreluActivation === void 0) {
hasPreluActivation = false;
}
if (hasLeakyReluAlpha === void 0) {
hasLeakyReluAlpha = false;
}
this.variableNames = ['x', 'W'];
this.customUniforms = [{
name: 'pads',
type: 'ivec2'
}, {
name: 'strides',
type: 'ivec2'
}, {
name: 'dilations',
type: 'ivec2'
}, {
name: 'inDims',
type: 'ivec2'
}];
this.outputShape = convInfo.outShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var channelMul = convInfo.outChannels / convInfo.inChannels;
var activationSnippet = '',
applyActivationSnippet = '';
if (activation) {
if (hasPreluActivation) {
activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
} else if (hasLeakyReluAlpha) {
activationSnippet = "float activation(float a) {\n float b = getLeakyreluAlphaAtOutCoords();\n " + activation + "\n }";
} else {
activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n ";
}
applyActivationSnippet = "result = activation(result);";
}
var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}
if (hasPreluActivation) {
this.variableNames.push('preluActivationWeights');
}
if (hasLeakyReluAlpha) {
this.variableNames.push('leakyreluAlpha');
}
this.userCode = "\n " + activationSnippet + "\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / " + channelMul + ";\n int q = d2 - d1 * " + channelMul + ";\n\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * dilations[0];\n\n if (xR < 0 || xR >= inDims[0]) {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * dilations[1];\n\n if (xC < 0 || xC >= inDims[1]) {\n continue;\n }\n\n float xVal = getX(batch, xR, xC, d1);\n float wVal = getW(wR, wC, d1, q);\n dotProd += xVal * wVal;\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DepthwiseConvPacked2DProgram = function DepthwiseConvPacked2DProgram(convInfo, addBias, activation, hasPreluActivation, hasLeakyReluAlpha) {
if (addBias === void 0) {
addBias = false;
}
if (activation === void 0) {
activation = null;
}
if (hasPreluActivation === void 0) {
hasPreluActivation = false;
}
if (hasLeakyReluAlpha === void 0) {
hasLeakyReluAlpha = false;
}
this.variableNames = ['x', 'W'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [{
name: 'pads',
type: 'ivec2'
}, {
name: 'strides',
type: 'ivec2'
}, {
name: 'dilations',
type: 'ivec2'
}, {
name: 'inDims',
type: 'ivec2'
}];
this.outputShape = convInfo.outShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
var channelMul = convInfo.outChannels / convInfo.inChannels;
var padLeft = convInfo.padInfo.left;
var strideWidth = convInfo.strideWidth;
var dilationWidth = convInfo.dilationWidth;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var texelsAcross = filterWidth;
var mainLoop = "\n int xR; int xC; int xCOffset;\n vec4 wTexel; vec4 previous; vec4 final;";
for (var c = 0; c < filterWidth; c++) {
mainLoop += "\n vec4 xTexelC" + c * 2 + ";\n int xTexelC" + c * 2 + "Ready;\n vec4 xTexelC" + (c * 2 + 1) + ";\n int xTexelC" + (c * 2 + 1) + "Ready;\n vec4 xC" + c + ";";
}
/**
* This vectorized implementation works by gathering the values needed for
* each output channel's dot product into vec4's and then multiplying them
* all together (this happens in the final double for-loop below). Most of
* the main loop consists of constructing these vec4's with the minimum
* number of texture2D calls, which means making use of all four returned
* values from a texture2D call at once.
*/
for (var r = 0; r < filterHeight; r++) {
for (var _c = 0; _c < filterWidth; _c++) {
mainLoop += "\n xTexelC" + _c * 2 + " = vec4(0.0);\n xTexelC" + _c * 2 + "Ready = 0;\n xTexelC" + (_c * 2 + 1) + " = vec4(0.0);\n xTexelC" + (_c * 2 + 1) + "Ready = 0;\n xC" + _c + " = vec4(0.0);";
}
mainLoop += "\n xR = xRCorner + " + r + " * dilations[0];\n if (xR >=0 && xR < inDims[0]) {\n ";
for (var texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
var colIndex = texelC * 2;
mainLoop += "\n xC = xCCorner + " + colIndex * dilationWidth + ";\n ";
if (strideWidth === 1) {
if (colIndex < filterWidth) {
// If padding is odd, the outer texels have to be composed.
if (padLeft % 2 === 1) {
// TODO: Ensure vec4 previous does not result in a redundant sample,
// and avoid setting xTexelRC's that exceed the boundary in the
// first place rather than resetting them to vec4(0).
// To compute xCOffset:
// - If padding is odd, we must add 1 to ensure we ask for an
// even-numbered row.
// - We subtract 2 to access the previous texel.
mainLoop += "\n xCOffset = xC + 1;\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC" + colIndex + "Ready == 0) {\n xTexelC" + colIndex + " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC" + colIndex + ".zw = vec2(0.0);\n }\n xTexelC" + colIndex + "Ready = 1;\n }\n "; // This texel has been read in previous iteration if the dilation
// is 1.
if (dilationWidth === 1 && colIndex > 0) {
mainLoop += "\n xC" + colIndex + " = vec4(xTexelC" + (colIndex - 2) + ".zw, xTexelC" + colIndex + ".xy);\n ";
} else {
mainLoop += "\n xCOffset = xC + 1 - 2;\n\n if (xCOffset >= 0 && xCOffset < inDims[1]) {\n previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n previous.zw = vec2(0.0);\n }\n\n xC" + colIndex + " = vec4(previous.zw, xTexelC" + colIndex + ".xy);\n } else {\n xC" + colIndex + " = vec4(0.0, 0.0, xTexelC" + colIndex + ".xy);\n }\n ";
}
} else {
// Padding is even, so xRC corresponds to a single texel.
mainLoop += "\n if (xC >= 0 && xC < inDims[1] && xTexelC" + colIndex + "Ready == 0) {\n xTexelC" + colIndex + " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC" + colIndex + ".zw = vec2(0.0);\n }\n xTexelC" + colIndex + "Ready = 1;\n }\n\n xC" + colIndex + " = xTexelC" + colIndex + ";\n ";
}
if (colIndex + 1 < filterWidth) {
// If dilation is even, the second entry should match the first
// (either both are composed or both are single samples). But if
// dilation is odd, then the second entry should be the opposite
// of the first (if the first is composed, the second is a single
// sample, and vice versa.)
var nextTexelOffset = padLeft % 2 === 0 ? nearestLargerEven(dilationWidth) : dilationWidth;
if (dilationWidth % 2 === 0 && padLeft % 2 === 1 || dilationWidth % 2 !== 0 && padLeft % 2 !== 1) {
mainLoop += "\n xCOffset = xC + imod(pads[1], 2) + " + nextTexelOffset + ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC" + (colIndex + 1) + "Ready == 0) {\n xTexelC" + (colIndex + 1) + " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC" + (colIndex + 1) + ".zw = vec2(0.0);\n }\n xTexelC" + (colIndex + 1) + "Ready = 1;\n }\n "; // If dilation > 1 then the xRC's will not be able to share any
// values, so each xRC will require two unique calls to getX.
if (dilationWidth > 1) {
mainLoop += "\n xCOffset -= 2;\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC" + colIndex + "Ready == 0) {\n xTexelC" + colIndex + " = getX(batch, xR, xCOffset, d1);\n xTexelC" + colIndex + "Ready = 1;\n }\n ";
}
mainLoop += "\n xC" + (colIndex + 1) + " = vec4(xTexelC" + colIndex + ".zw, xTexelC" + (colIndex + 1) + ".xy);\n ";
} else {
// If dilation is 1 and padding is odd, we have already read the
// texel when constructing the previous x value. Here we can
// simply skip the texture read.
if (nextTexelOffset === 1) {
mainLoop += "\n xC" + (colIndex + 1) + " = xTexelC" + colIndex + ";\n ";
} else {
mainLoop += "\n xCOffset = xC + " + nextTexelOffset + ";\n\n if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC" + (colIndex + 1) + "Ready == 0) {\n xTexelC" + (colIndex + 1) + " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC" + (colIndex + 1) + ".zw = vec2(0.0);\n }\n xTexelC" + (colIndex + 1) + "Ready = 1;\n }\n\n xC" + (colIndex + 1) + " = xTexelC" + (colIndex + 1) + ";\n ";
}
}
}
}
} else {
// stride === 2
if (colIndex < filterWidth) {
// Depending on whether padLeft is even or odd, we want either the
// xy or zw channels from X texels for xC${colIndex}. If padLeft is
// even, xC${colIndex + 1} is simply the zw channels of texels we've
// already sampled. But if padLeft is odd, xC${colIndex + 1}.zw will
// need to come from the xy channels of a new texel, hence the
// `vec4 final` initialized below.
if (padLeft % 2 === 1) {
mainLoop += "\n xCOffset = xC + 1 - strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC" + colIndex + "Ready == 0) {\n xTexelC" + colIndex + " = getX(batch, xR, xCOffset, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC" + colIndex + ".zw = vec2(0.0);\n }\n xTexelC" + colIndex + "Ready = 1;\n }\n\n if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC" + (colIndex + 1) + "Ready == 0) {\n xTexelC" + (colIndex + 1) + " = getX(batch, xR, xC + 1, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xC + 2 >= inDims[1]) {\n xTexelC" + (colIndex + 1) + ".zw = vec2(0.0);\n }\n xTexelC" + (colIndex + 1) + "Ready = 1;\n }\n\n xC" + colIndex + " = vec4(xTexelC" + colIndex + ".zw, xTexelC" + (colIndex + 1) + ".zw);\n ";
if (colIndex + 1 < filterWidth) {
mainLoop += "\n final = vec4(0.0);\n xCOffset = xC + 1 + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1]) {\n final = getX(batch, xR, xCOffset, d1);\n }\n xC" + (colIndex + 1) + " = vec4(xTexelC" + (colIndex + 1) + ".xy, final.xy);\n ";
}
} else {
mainLoop += "\n if(xC >= 0 && xC < inDims[1] && xTexelC" + colIndex + "Ready == 0) {\n xTexelC" + colIndex + " = getX(batch, xR, xC, d1);\n if (xC + 1 >= inDims[1]) {\n xTexelC" + colIndex + ".zw = vec2(0.0);\n }\n xTexelC" + colIndex + "Ready = 1;\n }\n\n xCOffset = xC + strides[1];\n if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC" + (colIndex + 1) + "Ready == 0) {\n xTexelC" + (colIndex + 1) + " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= inDims[1]) {\n xTexelC" + (colIndex + 1) + ".zw = vec2(0.);\n }\n xTexelC" + (colIndex + 1) + "Ready = 1;\n }\n\n xC" + colIndex + " = vec4(\n xTexelC" + colIndex + ".xy, xTexelC" + (colIndex + 1) + ".xy);\n ";
if (colIndex + 1 < filterWidth) {
mainLoop += "\n xC" + (colIndex + 1) + " = vec4(xTexelC" + colIndex + ".zw, xTexelC" + (colIndex + 1) + ".zw);\n ";
}
}
}
} // Localize the dotProd accumulation within the loop: on a GPU with a limited
// cache, accumulating the sum across a large number of variables causes many
// cache misses (e.g. a 5x5 filter would have 50 variables).
if (colIndex < filterWidth) {
mainLoop += "\n wTexel = getW(" + r + ", " + colIndex + ", d1, q);\n dotProd += xC" + colIndex + " * vec4(wTexel.xz, wTexel.xz);\n ";
if (colIndex + 1 < filterWidth) {
mainLoop += "\n wTexel = getW(" + r + ", " + (colIndex + 1) + ", d1, q);\n dotProd += xC" + (colIndex + 1) + " * vec4(wTexel.xz, wTexel.xz);\n ";
}
}
}
mainLoop += "\n }\n ";
}
var activationSnippet = '',
applyActivationSnippet = '';
if (activation) {
if (hasPreluActivation) {
activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
} else if (hasLeakyReluAlpha) {
activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n " + activation + "\n }";
} else {
activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }";
}
applyActivationSnippet = "result = activation(result);";
}
var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}
if (hasPreluActivation) {
this.variableNames.push('preluActivationWeights');
}
if (hasLeakyReluAlpha) {
this.variableNames.push('leakyreluAlpha');
}
this.userCode = "\n " + activationSnippet + "\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / " + channelMul + ";\n int q = d2 - d1 * " + channelMul + ";\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.\n vec4 dotProd = vec4(0.000000000000001);\n\n " + mainLoop + "\n\n vec4 result = dotProd - vec4(0.000000000000001);\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function depthwiseConv2dNative$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter;
var strides = attrs.strides,
pad = attrs.pad,
dilations = attrs.dilations,
dimRoundingMode = attrs.dimRoundingMode;
var $dilations = dilations;
if ($dilations == null) {
$dilations = [1, 1];
}
assert(eitherStridesOrDilationsAreOne(strides, $dilations), function () {
return 'Error in depthwiseConv2d: Either strides or dilations must be ' + ("1. Got strides " + strides + " and dilations '" + $dilations + "'");
});
var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true
/* depthwise */
);
var program;
if (env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1) {
program = new DepthwiseConvPacked2DProgram(convInfo);
} else {
program = new DepthwiseConv2DProgram(convInfo);
}
var customValues = [[convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth]];
return backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
}
var depthwiseConv2dNativeConfig$1 = {
kernelName: DepthwiseConv2dNative,
backendName: 'webgl',
kernelFunc: depthwiseConv2dNative$1
};
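// Usage sketch for the DepthwiseConv2dNative kernel registered above; the packed program is chosen
// when WEBGL_PACK_DEPTHWISECONV is enabled, strideWidth <= 2 and the channel multiplier is 1.
// Assumes the global `tf` namespace this bundle exports.
//   const x = tf.ones([1, 4, 4, 2]);                    // NHWC input
//   const filter = tf.ones([2, 2, 2, 1]);               // [fh, fw, inChannels, channelMultiplier]
//   tf.depthwiseConv2d(x, filter, 1, 'valid').print();  // shape [1, 3, 3, 2], every value 4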
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DepthwiseConv2DDerFilterProgram = function DepthwiseConv2DDerFilterProgram(convInfo) {
this.variableNames = ['x', 'dy'];
this.outputShape = convInfo.filterShape;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var padTop = convInfo.padInfo.top;
var padLeft = convInfo.padInfo.left;
var channelMul = convInfo.outChannels / convInfo.inChannels;
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int dm = coords.w;\n int d2 = d1 * " + channelMul + " + dm;\n\n float dotProd = 0.0;\n\n // TO DO: Vec4 over the batch size\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n setOutput(dotProd);\n }\n ";
};
var DepthwiseConv2DDerInputProgram = function DepthwiseConv2DDerInputProgram(convInfo) {
this.variableNames = ['dy', 'W'];
this.outputShape = convInfo.inShape;
var filterHeight = convInfo.filterHeight;
var filterWidth = convInfo.filterWidth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var padTop = filterHeight - 1 - convInfo.padInfo.top;
var padLeft = filterWidth - 1 - convInfo.padInfo.left;
var channelMul = convInfo.outChannels / convInfo.inChannels;
this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n ivec2 dyCorner = coords.yz - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n float dotProd = 0.0;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n // TO DO: Vec4 over the channelMul\n for (int dm = 0; dm < " + channelMul + "; dm++) {\n int d2 = d1 * " + channelMul + " + dm;\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, dm);\n dotProd += xValue * wValue;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function depthwiseConv2dNativeBackpropFilter$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
dy = inputs.dy;
var strides = attrs.strides,
dilations = attrs.dilations,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode,
filterShape = attrs.filterShape;
var convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true
/* depthwise */
);
var program = new DepthwiseConv2DDerFilterProgram(convInfo);
return backend.runWebGLProgram(program, [x, dy], 'float32');
}
var depthwiseConv2dNativeBackpropFilterConfig$1 = {
kernelName: DepthwiseConv2dNativeBackpropFilter,
backendName: 'webgl',
kernelFunc: depthwiseConv2dNativeBackpropFilter$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function depthwiseConv2dNativeBackpropInput$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
filter = inputs.filter;
var strides = attrs.strides,
dilations = attrs.dilations,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode,
inputShape = attrs.inputShape;
var convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true
/* depthwise */
);
var program = new DepthwiseConv2DDerInputProgram(convInfo);
return backend.runWebGLProgram(program, [dy, filter], 'float32');
}
var depthwiseConv2dNativeBackpropInputConfig$1 = {
kernelName: DepthwiseConv2dNativeBackpropInput,
backendName: 'webgl',
kernelFunc: depthwiseConv2dNativeBackpropInput$2
};
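// The two kernels above (DerFilter and DerInput) are reached through gradient computation rather
// than called directly. A minimal sketch that exercises both, assuming the global `tf` namespace
// this bundle exports:
//   const x = tf.ones([1, 5, 5, 1]);
//   const f = tf.ones([3, 3, 1, 1]);
//   const g = tf.grads((x, f) => tf.depthwiseConv2d(x, f, 1, 'same'));
//   const [dx, df] = g([x, f]);  // dx uses the backprop-input kernel, df the backprop-filter kernel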
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var DiagProgram = function DiagProgram(size) {
this.variableNames = ['X'];
this.outputShape = [size, size];
this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;\n setOutput(val);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function diag$2(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
var outShape = [].concat(x.shape, x.shape);
var xSize = sizeFromShape(x.shape);
var flat = reshape$3({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: [xSize]
}
});
var program = new DiagProgram(xSize);
var res = backend.runWebGLProgram(program, [flat], flat.dtype);
var out = reshape$3({
inputs: {
x: res
},
backend: backend,
attrs: {
shape: outShape
}
});
backend.disposeIntermediateTensorInfo(flat);
backend.disposeIntermediateTensorInfo(res);
return out;
}
var diagConfig$1 = {
kernelName: Diag,
backendName: 'webgl',
kernelFunc: diag$2
};
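// Usage sketch for the Diag kernel registered above: the input is flattened, placed on the diagonal
// of a [size, size] output, and reshaped to x.shape.concat(x.shape). Assumes the global `tf`
// namespace this bundle exports.
//   tf.diag(tf.tensor1d([1, 2, 3])).print();  // [[1, 0, 0], [0, 2, 0], [0, 0, 3]]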
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var Dilation2DProgram = function Dilation2DProgram(convInfo) {
this.variableNames = ['x', 'W'];
this.outputShape = convInfo.outShape;
var inHeight = convInfo.inHeight,
inWidth = convInfo.inWidth,
padInfo = convInfo.padInfo,
strideHeight = convInfo.strideHeight,
strideWidth = convInfo.strideWidth,
filterHeight = convInfo.filterHeight,
filterWidth = convInfo.filterWidth,
dilationHeight = convInfo.dilationHeight,
dilationWidth = convInfo.dilationWidth;
var padTop = padInfo.top,
padLeft = padInfo.left;
this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float neg_infinity = -3.4e38;\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.w;\n ivec2 outTopLeftCorner =\n coords.yz * strides - pads;\n int hBeg = outTopLeftCorner.x;\n int wBeg = outTopLeftCorner.y;\n\n float curVal = neg_infinity;\n for (int h = 0; h < " + filterHeight + "; h++) {\n int hIn = hBeg + h * " + dilationHeight + ";\n\n if (hIn >= 0 && hIn < " + inHeight + ") {\n for (int w = 0; w < " + filterWidth + "; w++) {\n int wIn = wBeg + w * " + dilationWidth + ";\n\n if (wIn >= 0 && wIn < " + inWidth + ") {\n float xVal = getX(batch, hIn, wIn, d1);\n float wVal = getW(h, w, d1);\n\n float val = xVal + wVal;\n if (val > curVal) {\n curVal = val;\n }\n }\n }\n }\n }\n\n float result = curVal;\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function dilation2D(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter;
var strides = attrs.strides,
pad = attrs.pad,
dilations = attrs.dilations;
var convInfo = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC'
/* dataFormat */
, dilations);
var out;
var program = new Dilation2DProgram(convInfo);
out = backend.runWebGLProgram(program, [x, filter], 'float32');
var outReshaped = reshape$3({
inputs: {
x: out
},
backend: backend,
attrs: {
shape: convInfo.outShape
}
});
backend.disposeIntermediateTensorInfo(out);
return outReshaped;
}
var dilation2DConfig = {
kernelName: Dilation2D,
backendName: 'webgl',
kernelFunc: dilation2D
};
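// Usage sketch for the Dilation2D kernel registered above (grayscale morphological dilation: the
// max of x + w over each filter window). Assumes the global `tf` namespace this bundle exports.
//   const x = tf.ones([1, 4, 4, 1]);               // [batch, height, width, depth]
//   const filter = tf.zeros([2, 2, 1]);            // [fh, fw, depth]
//   tf.dilation2d(x, filter, 1, 'valid').print();  // shape [1, 3, 3, 1], every value 1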
function einsum$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var equation = attrs.equation;
var tensors = inputs;
var _backend_util$decodeE = decodeEinsumEquation(equation, tensors.length),
allDims = _backend_util$decodeE.allDims,
summedDims = _backend_util$decodeE.summedDims,
idDims = _backend_util$decodeE.idDims;
checkEinsumDimSizes(allDims.length, idDims, tensors);
var _backend_util$getEins = getEinsumComputePath(summedDims, idDims),
path = _backend_util$getEins.path,
steps = _backend_util$getEins.steps;
var nSteps = steps.length;
var out = null;
var numDimsRemaining = allDims.length;
var tensorsToDispose = [];
for (var i = 0; i < nSteps; ++i) {
for (var _iterator = _createForOfIteratorHelperLoose(steps[i]), _step; !(_step = _iterator()).done;) {
var idTerm = _step.value;
var _backend_util$getEins2 = getEinsumPermutation(numDimsRemaining, idDims[idTerm]),
perm = _backend_util$getEins2.permutationIndices,
dimsToExpand = _backend_util$getEins2.expandDims;
var x = void 0;
if (isIdentityPermutation(perm)) {
x = tensors[idTerm];
} else {
x = transpose$2({
inputs: {
x: tensors[idTerm]
},
backend: backend,
attrs: {
perm: perm
}
});
tensorsToDispose.push(x);
}
var targetShape = x.shape.slice();
for (var k = 0; k < dimsToExpand.length; ++k) {
targetShape.splice(dimsToExpand[k], 0, 1);
}
if (!arraysEqual(x.shape, targetShape)) {
x = reshape$3({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: targetShape
}
});
tensorsToDispose.push(x);
}
if (out === null) {
out = x;
} else {
// tslint:disable-next-line: no-unnecessary-type-assertion
out = multiply$4({
inputs: {
a: x,
b: out
},
backend: backend
});
tensorsToDispose.push(out);
}
}
if (i < nSteps - 1) {
if (path[i] >= 0) {
out = sum$4({
inputs: {
x: out
},
backend: backend,
attrs: {
axis: path[i] - (allDims.length - numDimsRemaining),
keepDims: false
}
});
tensorsToDispose.push(out);
}
numDimsRemaining--;
}
} // Clean up intermediate tensors.
for (var _i = 0, _tensorsToDispose = tensorsToDispose; _i < _tensorsToDispose.length; _i++) {
var tensorInfo = _tensorsToDispose[_i];
if (tensorInfo === out) {
continue;
}
backend.disposeIntermediateTensorInfo(tensorInfo);
}
return out;
}
var einsumConfig$1 = {
kernelName: Einsum,
backendName: 'webgl',
kernelFunc: einsum$2
};
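// Usage sketch for the Einsum kernel registered above; the decoded equation is executed as a
// sequence of transpose / reshape / multiply / sum steps, as the code shows. Assumes the global
// `tf` namespace this bundle exports.
//   const a = tf.tensor2d([[1, 2], [3, 4]]);
//   const b = tf.tensor2d([[5, 6], [7, 8]]);
//   tf.einsum('ij,jk->ik', a, b).print();  // matrix product: [[19, 22], [43, 50]]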
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ELU$3 = "return (x >= 0.0) ? x : (exp(x) - 1.0);";
var ELU_PACKED = "\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n";
var elu$4 = unaryKernelFunc$1({
opSnippet: ELU$3,
packedOpSnippet: ELU_PACKED
});
var eluConfig$1 = {
kernelName: Elu,
backendName: 'webgl',
kernelFunc: elu$4
};
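// Usage sketch for the Elu kernel registered above (x for x >= 0, exp(x) - 1 otherwise), assuming
// the global `tf` namespace this bundle exports.
//   tf.elu(tf.tensor1d([-1, 0, 2])).print();  // approximately [-0.632, 0, 2]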
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ELU_DER$1 = "return (b >= 1.0) ? a : a * (b + 1.0);";
var ELU_DER_PACKED = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n";
var eluGrad$1 = function eluGrad(args) {
var inputs = args.inputs,
backend = args.backend;
var dy = inputs.dy,
y = inputs.y;
var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape) : new BinaryOpProgram(ELU_DER$1, dy.shape, y.shape);
return backend.runWebGLProgram(program, [dy, y], dy.dtype);
};
var eluGradConfig$2 = {
kernelName: EluGrad,
backendName: 'webgl',
kernelFunc: eluGrad$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PACKED_EQUAL = "\n return vec4(equal(a, b));\n";
var EQUAL = "return float(a == b);";
var equal$2 = binaryKernelFunc$1({
opSnippet: EQUAL,
packedOpSnippet: PACKED_EQUAL,
dtype: 'bool',
cpuKernelImpl: equalImplCPU
});
var equalConfig$1 = {
kernelName: Equal,
backendName: 'webgl',
kernelFunc: equal$2
};
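// Usage sketch for the Equal kernel registered above (element-wise comparison with a 'bool'
// result), assuming the global `tf` namespace this bundle exports.
//   const a = tf.tensor1d([1, 2, 3]);
//   const b = tf.tensor1d([1, 5, 3]);
//   tf.equal(a, b).print();  // [true, false, true]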
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ERF = "\n // Error function is calculated approximately with elementary function.\n // See \"Handbook of Mathematical Functions with Formulas,\n // Graphs, and Mathematical Tables\", Abramowitz and Stegun.\n float p = " + ERF_P + ";\n float a1 = " + ERF_A1 + ";\n float a2 = " + ERF_A2 + ";\n float a3 = " + ERF_A3 + ";\n float a4 = " + ERF_A4 + ";\n float a5 = " + ERF_A5 + ";\n\n float sign = sign(x);\n x = abs(x);\n float t = 1.0 / (1.0 + p * x);\n return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));\n";
var erf$2 = unaryKernelFunc$1({
opSnippet: ERF
});
var erfConfig$1 = {
kernelName: Erf,
backendName: 'webgl',
kernelFunc: erf$2
};
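// Usage sketch for the Erf kernel registered above, which uses the Abramowitz & Stegun polynomial
// approximation shown in the shader. Assumes the global `tf` namespace this bundle exports.
//   tf.erf(tf.tensor1d([0, 0.5, 1])).print();  // approximately [0, 0.5205, 0.8427]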
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EXP = "return exp(x);";
var exp$5 = unaryKernelFunc$1({
opSnippet: EXP,
packedOpSnippet: EXP,
cpuKernelImpl: expImplCPU
});
var expConfig$1 = {
kernelName: Exp,
backendName: 'webgl',
kernelFunc: exp$5
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function expandDims$3(args) {
var inputs = args.inputs,
attrs = args.attrs,
backend = args.backend;
var dim = attrs.dim;
var input = inputs.input;
var inputRank = input.shape.length;
var newShape = input.shape.slice();
var $dim = dim;
if (dim < 0) {
// A negative axis counts back from the end of the rank.
assert(-(inputRank + 1) <= dim, function () {
return "Axis must be in the interval [" + -(inputRank + 1) + ", " + inputRank + "]";
});
$dim = inputRank + dim + 1;
}
newShape.splice($dim, 0, 1);
return reshape$3({
inputs: {
x: input
},
backend: backend,
attrs: {
shape: newShape
}
});
}
var expandDimsConfig$1 = {
kernelName: ExpandDims,
backendName: 'webgl',
kernelFunc: expandDims$3
};
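// Usage sketch for the ExpandDims kernel registered above, which is implemented as a reshape with a
// unit dimension spliced in. Assumes the global `tf` namespace this bundle exports.
//   const x = tf.tensor1d([1, 2, 3]);
//   tf.expandDims(x, 0).shape;   // [1, 3]
//   tf.expandDims(x, -1).shape;  // [3, 1]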
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var EXPM1 = "return exp(x) - 1.0;";
var expm1$2 = unaryKernelFunc$1({
opSnippet: EXPM1,
packedOpSnippet: EXPM1,
cpuKernelImpl: expm1ImplCPU
});
var expm1Config$1 = {
kernelName: Expm1,
backendName: 'webgl',
kernelFunc: expm1$2
};
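// Usage sketch for the Expm1 kernel registered above and the Exp kernel registered earlier;
// expm1(x) = exp(x) - 1 is computed in a single shader. Assumes the global `tf` namespace this
// bundle exports.
//   tf.exp(tf.tensor1d([0, 1])).print();    // [1, ~2.718]
//   tf.expm1(tf.tensor1d([0, 1])).print();  // [0, ~1.718]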
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var FFTProgram = function FFTProgram(component, inputShape, inverse) {
this.variableNames = ['real', 'imag'];
var innerDim = inputShape[1];
this.outputShape = inputShape;
var exponentMultiplierSnippet = inverse ? "2.0 * " + Math.PI : "-2.0 * " + Math.PI;
var resultDenominator = inverse ? innerDim + ".0" : '1.0';
var opString;
if (component === 'real') {
opString = 'return real * expR - imag * expI;';
} else if (component === 'imag') {
opString = 'return real * expI + imag * expR;';
} else {
throw new Error("FFT component must be either \"real\" or \"imag\", got " + component + ".");
}
this.userCode = "\n const float exponentMultiplier = " + exponentMultiplierSnippet + ";\n\n float unaryOpComplex(float real, float expR, float imag, float expI) {\n " + opString + "\n }\n\n float mulMatDFT(int batch, int index) {\n float indexRatio = float(index) / float(" + innerDim + ");\n float exponentMultiplierTimesIndexRatio =\n exponentMultiplier * indexRatio;\n\n float result = 0.0;\n\n for (int i = 0; i < " + innerDim + "; i++) {\n // x = (-2|2 * PI / N) * index * i;\n float x = exponentMultiplierTimesIndexRatio * float(i);\n float expR = cos(x);\n float expI = sin(x);\n float real = getReal(batch, i);\n float imag = getImag(batch, i);\n\n result +=\n unaryOpComplex(real, expR, imag, expI) / " + resultDenominator + ";\n }\n\n return result;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n setOutput(mulMatDFT(coords[0], coords[1]));\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fftImpl$1(x, inverse, backend) {
var xData = backend.texData.get(x.dataId);
var inputSize = sizeFromShape(x.shape); // Collapse all outer dimensions to a single batch dimension.
var innerDimensionSize = x.shape[x.shape.length - 1];
var batch = inputSize / innerDimensionSize;
var input2D = reshape$3({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: [batch, innerDimensionSize]
}
});
var xShape = input2D.shape;
var realProgram = new FFTProgram('real', xShape, inverse);
var imagProgram = new FFTProgram('imag', xShape, inverse);
var inputs = [{
dataId: xData.complexTensorInfos.real.dataId,
dtype: xData.complexTensorInfos.real.dtype,
shape: xShape
}, {
dataId: xData.complexTensorInfos.imag.dataId,
dtype: xData.complexTensorInfos.imag.dtype,
shape: xShape
}];
var realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
var imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
var complexOutput = complex$2({
inputs: {
real: realPart,
imag: imagPart
},
backend: backend
});
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(imagPart);
var complexOutputReshaped = reshape$3({
inputs: {
x: complexOutput
},
backend: backend,
attrs: {
shape: x.shape
}
});
backend.disposeIntermediateTensorInfo(input2D);
backend.disposeIntermediateTensorInfo(complexOutput);
return complexOutputReshaped;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fft$2(args) {
var inputs = args.inputs,
backend = args.backend;
var input = inputs.input;
return fftImpl$1(input, false
/* inverse */
, backend);
}
var fftConfig$1 = {
kernelName: FFT,
backendName: 'webgl',
kernelFunc: fft$2
};
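// Usage sketch for the FFT kernel registered above: the input must be complex, and the real/imag
// programs evaluate the DFT along the inner-most dimension. Assumes the global `tf` namespace this
// bundle exports.
//   const x = tf.complex([1, 1, 1, 1], [0, 0, 0, 0]);
//   tf.spectral.fft(x).print();  // DFT of a constant signal: real [4, 0, 0, 0], imag [0, 0, 0, 0]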
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var FillProgram = function FillProgram(shape, value) {
this.outputShape = [];
this.customUniforms = [{
name: 'value',
type: 'float'
}];
this.variableNames = ['x'];
this.outputShape = shape;
this.userCode = "\n void main() {\n // Input can be obtained from uniform value.\n setOutput(value);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function fill$2(args) {
var backend = args.backend,
attrs = args.attrs;
var shape = attrs.shape,
value = attrs.value;
var dtype = attrs.dtype;
dtype = dtype || inferDtype(value);
if (dtype === 'string') {
// String type should be handled in CPU memory.
var values = getArrayFromDType(dtype, sizeFromShape(shape));
values.fill(value);
return backend.makeTensorInfo(shape, dtype, values);
} else {
var program = new FillProgram(shape, value);
var customValues = [[value]];
return backend.runWebGLProgram(program, [], dtype, customValues);
}
}
var fillConfig$1 = {
kernelName: Fill,
backendName: 'webgl',
kernelFunc: fill$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var FlipLeftRightProgram = function FlipLeftRightProgram(imageShape) {
this.variableNames = ['Image'];
this.outputShape = [];
var imageWidth = imageShape[2];
this.outputShape = imageShape;
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n\n int coordX = " + imageWidth + " - x - 1;\n float outputValue;\n if(coordX >= 0 && coordX < " + imageWidth + ") {\n outputValue = getImage(coords[0], coords[1], coordX, coords[3]);\n } else {\n outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);\n }\n setOutput(outputValue);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var flipLeftRightConfig$1 = {
kernelName: FlipLeftRight,
backendName: 'webgl',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
backend = _ref.backend;
var image = inputs.image;
var webglBackend = backend;
var program = new FlipLeftRightProgram(image.shape);
var output = webglBackend.runWebGLProgram(program, [image], image.dtype);
return output;
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var FLOOR = "return floor(x);";
var floor$c = unaryKernelFunc$1({
opSnippet: FLOOR,
packedOpSnippet: FLOOR,
cpuKernelImpl: floorImplCPU
});
var floorConfig$1 = {
kernelName: Floor,
backendName: 'webgl',
kernelFunc: floor$c
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// We use native integer division to deal with floating point imprecision. Since
// we implement floor division and glsl implements truncated division, we
// correct for this by subtracting 1 from result when the result is negative and
// there is a remainder. (A JavaScript reference sketch follows floorDivConfig$1 below.)
var INT_DIV = "\n float s = sign(a) * sign(b);\n int ia = round(a);\n int ib = round(b);\n if (ib != 0) {\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n return float(idiv(ia, ib, s));\n } else {\n return NAN;\n }\n";
var INT_DIV_PACKED = "\n ivec4 ia = round(a);\n ivec4 ib = round(b);\n bvec4 cond = notEqual(ib, ivec4(0));\n ivec4 result = ivec4(0);\n vec4 s = sign(a) * sign(b);\n\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n if (cond[0]) {\n result[0] = idiv(ia[0], ib[0], s[0]);\n }\n if (cond[1]) {\n result[1] = idiv(ia[1], ib[1], s[1]);\n }\n if (cond[2]) {\n result[2] = idiv(ia[2], ib[2], s[2]);\n }\n if (cond[3]) {\n result[3] = idiv(ia[3], ib[3], s[3]);\n }\n return vec4(result);\n";
var floorDiv$2 = binaryKernelFunc$1({
opSnippet: INT_DIV,
packedOpSnippet: INT_DIV_PACKED,
dtype: 'int32'
});
var floorDivConfig$1 = {
kernelName: FloorDiv,
backendName: 'webgl',
kernelFunc: floorDiv$2
};
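// A minimal JavaScript sketch (not part of TF.js) of the correction the INT_DIV
// snippets above perform on the GPU: GLSL integer division truncates toward zero,
// so when the operands have opposite signs and the division leaves a remainder the
// quotient is stepped down by one to match floor semantics. The name
// `referenceFloorDiv` is hypothetical and the function is never called by this bundle.
function referenceFloorDiv(a, b) {
  var q = Math.trunc(a / b); // truncated division, like GLSL's integer '/'
  if (a % b !== 0 && (a < 0) !== (b < 0)) {
    q -= 1; // opposite signs with a remainder: step down to the floor
  }
  return q; // e.g. referenceFloorDiv(-7, 2) === -4 === Math.floor(-7 / 2)
}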
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var FromPixelsProgram = function FromPixelsProgram(outputShape) {
this.variableNames = ['A'];
var glsl = getGlslDifferences();
var height = outputShape[0],
width = outputShape[1];
this.outputShape = outputShape;
this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + width + ".0, " + height + ".0);\n\n vec4 values = " + glsl.texture2D + "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n setOutput(floor(value * 255.0 + 0.5));\n }\n ";
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var FromPixelsPackedProgram = function FromPixelsPackedProgram(outputShape) {
this.variableNames = ['A'];
this.packedInputs = false;
this.packedOutput = true;
var glsl = getGlslDifferences();
var height = outputShape[0],
width = outputShape[1];
this.outputShape = outputShape;
this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n\n vec4 result = vec4(0.);\n\n for(int row=0; row<=1; row++) {\n for(int col=0; col<=1; col++) {\n texC = coords[1] + row;\n depth = coords[2] + col;\n\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + width + ".0, " + height + ".0);\n vec4 values = " + glsl.texture2D + "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n result[row * 2 + col] = floor(value * 255.0 + 0.5);\n }\n }\n\n " + glsl.output + " = result;\n }\n ";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var fromPixelsConfig = {
kernelName: FromPixels,
backendName: 'webgl',
kernelFunc: fromPixels$1
};
var fromPixels2DContext$1;
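// FromPixels kernel. HTMLImageElement/HTMLVideoElement sources are first drawn onto
// a shared 2D canvas, the pixel data is uploaded to a byte texture, and a packed or
// unpacked program (depending on WEBGL_PACK) converts it into an int32 tensor of
// shape [height, width, numChannels].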
function fromPixels$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var pixels = inputs.pixels;
var numChannels = attrs.numChannels;
var isVideo = typeof HTMLVideoElement !== 'undefined' && pixels instanceof HTMLVideoElement;
var isImage = typeof HTMLImageElement !== 'undefined' && pixels instanceof HTMLImageElement;
var _ref = isVideo ? [pixels.videoWidth, pixels.videoHeight] : [pixels.width, pixels.height],
width = _ref[0],
height = _ref[1];
var texShape = [height, width];
var outShape = [height, width, numChannels];
if (isImage || isVideo) {
if (fromPixels2DContext$1 == null) {
fromPixels2DContext$1 = document.createElement('canvas').getContext('2d');
}
fromPixels2DContext$1.canvas.width = width;
fromPixels2DContext$1.canvas.height = height;
fromPixels2DContext$1.drawImage(pixels, 0, 0, width, height);
pixels = fromPixels2DContext$1.canvas;
}
var tempPixelHandle = backend.makeTensorInfo(texShape, 'int32'); // This is a byte texture with pixels.
backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;
backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels);
var program = env().getBool('WEBGL_PACK') ? new FromPixelsPackedProgram(outShape) : new FromPixelsProgram(outShape);
var res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32');
backend.disposeData(tempPixelHandle.dataId);
return res;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
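// Fused Conv2D kernel. Pointwise convolutions (1x1 filter, unit strides and
// dilations, SAME/VALID padding) are lowered to a matrix multiply; single-image
// batches use the im2row path when WEBGL_CONV_IM2COL is enabled; everything else
// runs the general Conv2DProgram. Bias, PReLU weights and the leakyrelu alpha are
// appended to the program inputs when present, and the result is reshaped to
// convInfo.outShape.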
function fusedConv2d(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter,
bias = inputs.bias,
preluActivationWeights = inputs.preluActivationWeights;
var strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dilations = attrs.dilations,
dimRoundingMode = attrs.dimRoundingMode,
activation = attrs.activation,
leakyreluAlpha = attrs.leakyreluAlpha;
var $dataFormat = convertConv2DDataFormat(dataFormat);
  var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
var out;
var intermediates = [];
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
out = conv2dByMatMul({
x: x,
filter: filter,
convInfo: convInfo,
backend: backend,
bias: bias,
activation: activation,
preluActivationWeights: preluActivationWeights,
leakyreluAlpha: leakyreluAlpha
});
} else if (env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
out = conv2dWithIm2Row({
x: x,
filter: filter,
convInfo: convInfo,
backend: backend,
bias: bias,
activation: activation,
preluActivationWeights: preluActivationWeights,
leakyreluAlpha: leakyreluAlpha
});
} else {
var hasBias = bias != null;
var hasPreluActivationWeights = preluActivationWeights != null;
var hasLeakyreluAlpha = activation === 'leakyrelu';
var fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
var program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
var _inputs = [x, filter];
if (bias) {
_inputs.push(bias);
}
if (preluActivationWeights) {
_inputs.push(preluActivationWeights);
}
if (hasLeakyreluAlpha) {
var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
_inputs.push($leakyreluAlpha);
intermediates.push($leakyreluAlpha);
}
out = backend.runWebGLProgram(program, _inputs, 'float32');
}
var outReshaped = reshape$3({
inputs: {
x: out
},
backend: backend,
attrs: {
shape: convInfo.outShape
}
});
intermediates.push(out);
intermediates.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return outReshaped;
}
var fusedConv2DConfig$1 = {
kernelName: FusedConv2D,
backendName: 'webgl',
kernelFunc: fusedConv2d
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
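// Fused DepthwiseConv2D kernel. Uses the packed depthwise program when
// WEBGL_PACK_DEPTHWISECONV is enabled, strideWidth <= 2 and the channel multiplier
// is 1, otherwise the unpacked program; padding, strides, dilations and the input
// size are passed as custom uniforms.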
function fusedDepthwiseConv2D$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
filter = inputs.filter,
bias = inputs.bias,
preluActivationWeights = inputs.preluActivationWeights;
var strides = attrs.strides,
pad = attrs.pad,
dilations = attrs.dilations,
dimRoundingMode = attrs.dimRoundingMode,
activation = attrs.activation,
leakyreluAlpha = attrs.leakyreluAlpha;
var intermediates = [];
var $dilations = dilations;
if ($dilations == null) {
$dilations = [1, 1];
}
assert(eitherStridesOrDilationsAreOne(strides, $dilations), function () {
return 'Error in depthwiseConv2d: Either strides or dilations must be ' + ("1. Got strides " + strides + " and dilations '" + $dilations + "'");
});
  var convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
var shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1;
var fusedActivation = activation ? mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) : null;
var programInputs = [x, filter];
var hasBias = bias != null;
var hasPreluActivationWeights = preluActivationWeights != null;
var hasLeakyreluAlpha = activation === 'leakyrelu';
if (hasBias) {
programInputs.push(bias);
}
if (hasPreluActivationWeights) {
programInputs.push(preluActivationWeights);
}
if (hasLeakyreluAlpha) {
var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
programInputs.push($leakyreluAlpha);
intermediates.push($leakyreluAlpha);
}
var program;
if (shouldPackDepthwiseConv) {
program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
} else {
program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
}
var customValues = [[convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth]];
var result = backend.runWebGLProgram(program, programInputs, 'float32', customValues);
intermediates.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return result;
}
var fusedDepthwiseConv2DConfig$1 = {
kernelName: FusedDepthwiseConv2D,
backendName: 'webgl',
kernelFunc: fusedDepthwiseConv2D$1
};
var GatherNDProgram = function GatherNDProgram(sliceDim, strides, shape) {
this.sliceDim = sliceDim;
this.strides = strides;
this.variableNames = ['x', 'indices'];
this.outputShape = shape;
var stridesType = getCoordsDataType(strides.length);
var dtype = getCoordsDataType(shape.length);
var strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides';
this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + this.strides + ");\n void main() {\n " + dtype + " coords = getOutputCoords();\n int flattenIndex = 0;\n for (int j = 0; j < " + this.sliceDim + "; j++) {\n int index = round(getIndices(coords[0], j));\n flattenIndex += index * " + strideString + ";\n }\n setOutput(getX(flattenIndex, coords[1]));\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
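// GatherNd kernel. Flattens `indices` to [numSlices, sliceRank] and `params` to
// [paramsSize / sliceSize, sliceSize], falls back to the CPU implementation for
// string tensors or when the backend prefers CPU execution, and otherwise gathers
// on the GPU before reshaping to the result shape.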
function gatherNd$1(args) {
var inputs = args.inputs,
backend = args.backend;
var params = inputs.params,
indices = inputs.indices;
var indicesShape = indices.shape;
var sliceRank = indicesShape[indicesShape.length - 1];
var paramsSize = sizeFromShape(params.shape);
var _backend_util$prepare = prepareAndValidate(params, indices),
resultShape = _backend_util$prepare[0],
numSlices = _backend_util$prepare[1],
sliceSize = _backend_util$prepare[2],
strides = _backend_util$prepare[3];
var flattenIndices = reshape$3({
inputs: {
x: indices
},
backend: backend,
attrs: {
shape: [numSlices, sliceRank]
}
});
var flattenX = reshape$3({
inputs: {
x: params
},
backend: backend,
attrs: {
shape: [sizeFromShape(params.shape) / sliceSize, sliceSize]
}
});
if (backend.shouldExecuteOnCPU([params, indices]) || params.dtype === 'string') {
var indicesData = backend.readSync(indices.dataId);
var paramsBuf = backend.bufferSync(params);
var outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
return backend.makeTensorInfo(resultShape, params.dtype, outValue.values);
}
var program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize]);
var res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype);
var reshaped = reshape$3({
inputs: {
x: res
},
backend: backend,
attrs: {
shape: resultShape
}
});
backend.disposeIntermediateTensorInfo(flattenIndices);
backend.disposeIntermediateTensorInfo(flattenX);
backend.disposeIntermediateTensorInfo(res);
return reshaped;
}
var gatherNdConfig$1 = {
kernelName: GatherNd,
backendName: 'webgl',
kernelFunc: gatherNd$1
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var GatherProgram = function GatherProgram(aShape, outputShape) {
this.variableNames = ['A', 'indices'];
this.outputShape = outputShape;
this.rank = outputShape.length;
var dtype = getCoordsDataType(this.rank);
var sourceCoords = getSourceCoords$1(aShape, 2);
this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + sourceCoords + "));\n }\n ";
}; // The input and output are always flattened into rank 4 tensors.
function getSourceCoords$1(aShape, axis) {
var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
var sourceCoords = [];
for (var i = 0; i < aShape.length; i++) {
if (i === 2) {
sourceCoords.push('int(getIndices(resRC.x, resRC.z))');
} else {
sourceCoords.push("" + currentCoords[i]);
}
}
return sourceCoords.join();
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
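// GatherV2 kernel. Normalises the gather axis, flattens x to
// [batchSize, outerSize, dimSize, sliceSize] and indices to
// [batchSize, indicesSize / batchSize], gathers on the CPU for string tensors or
// when the backend prefers CPU execution, and otherwise gathers on the GPU before
// reshaping to the final output shape.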
function gatherV2$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
indices = inputs.indices;
var axis = attrs.axis,
batchDims = attrs.batchDims;
var parsedAxis = parseAxisParam(axis, x.shape)[0];
var shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims);
var indicesSize = sizeFromShape(indices.shape);
var toDispose = [];
var flattenX = reshape$3({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: [shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize, shapeInfo.sliceSize]
}
});
var flattenIndex = reshape$3({
inputs: {
x: indices
},
backend: backend,
attrs: {
shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize]
}
});
toDispose.push(flattenX);
toDispose.push(flattenIndex);
var flattenOutputShape = [shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize, shapeInfo.sliceSize];
if (backend.shouldExecuteOnCPU([x, indices]) || x.dtype === 'string') {
var indicesBuf = backend.bufferSync(flattenIndex);
var xBuf = backend.bufferSync(flattenX);
var outBuf = gatherV2ImplCPU(xBuf, indicesBuf, flattenOutputShape);
toDispose.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
}
var program = new GatherProgram(flattenX.shape, flattenOutputShape);
var res = backend.runWebGLProgram(program, [flattenX, flattenIndex], flattenX.dtype);
toDispose.push(res);
var reshaped = reshape$3({
inputs: {
x: res
},
backend: backend,
attrs: {
shape: shapeInfo.outputShape
}
});
toDispose.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return reshaped;
}
var gatherV2Config$1 = {
kernelName: GatherV2,
backendName: 'webgl',
kernelFunc: gatherV2$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var GREATER = "return float(a > b);";
var GREATER_PACKED = "\n return vec4(greaterThan(a, b));\n";
var greater$3 = binaryKernelFunc$1({
opSnippet: GREATER,
packedOpSnippet: GREATER_PACKED,
cpuKernelImpl: greaterImplCPU,
dtype: 'bool'
});
var greaterConfig$1 = {
kernelName: Greater,
backendName: 'webgl',
kernelFunc: greater$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var GREATER_EQUAL = "return float(a >= b);";
var GREATER_EQUAL_PACKED = "\n return vec4(greaterThanEqual(a, b));\n";
var greaterEqual$2 = binaryKernelFunc$1({
opSnippet: GREATER_EQUAL,
packedOpSnippet: GREATER_EQUAL_PACKED,
dtype: 'bool',
cpuKernelImpl: greaterEqualImplCPU
});
var greaterEqualConfig$1 = {
kernelName: GreaterEqual,
backendName: 'webgl',
kernelFunc: greaterEqual$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function ifft$2(args) {
var inputs = args.inputs,
backend = args.backend;
var input = inputs.input;
  return fftImpl$1(input, true /* inverse */, backend);
}
var ifftConfig$1 = {
kernelName: IFFT,
backendName: 'webgl',
kernelFunc: ifft$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var IS_FINITE = "return float(!isnan(x) && !isinf(x));";
var isFinite$3 = unaryKernelFunc$1({
opSnippet: IS_FINITE,
dtype: 'bool'
});
var isFiniteConfig$1 = {
kernelName: IsFinite,
backendName: 'webgl',
kernelFunc: isFinite$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var IS_INF = "return float(isinf(x));";
var isInf$2 = unaryKernelFunc$1({
opSnippet: IS_INF,
dtype: 'bool'
});
var isInfConfig$1 = {
kernelName: IsInf,
backendName: 'webgl',
kernelFunc: isInf$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var IS_NAN = "return float(isnan(x));";
var isNaN$3 = unaryKernelFunc$1({
opSnippet: IS_NAN,
dtype: 'bool'
});
var isNaNConfig$1 = {
kernelName: IsNan,
backendName: 'webgl',
kernelFunc: isNaN$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LESS = "return float(a < b);";
var LESS_PACKED = "\n return vec4(lessThan(a, b));\n";
var less$3 = binaryKernelFunc$1({
opSnippet: LESS,
packedOpSnippet: LESS_PACKED,
cpuKernelImpl: lessImplCPU,
dtype: 'bool'
});
var lessConfig$1 = {
kernelName: Less,
backendName: 'webgl',
kernelFunc: less$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LESS_EQUAL = "return float(a <= b);";
var LESS_EQUAL_PACKED = "\n return vec4(lessThanEqual(a, b));\n";
var lessEqual$2 = binaryKernelFunc$1({
opSnippet: LESS_EQUAL,
packedOpSnippet: LESS_EQUAL_PACKED,
cpuKernelImpl: lessEqualImplCPU,
dtype: 'bool'
});
var lessEqualConfig$1 = {
kernelName: LessEqual,
backendName: 'webgl',
kernelFunc: lessEqual$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function linSpace$1(args) {
var backend = args.backend,
attrs = args.attrs;
var start = attrs.start,
stop = attrs.stop,
      num = attrs.num;
  // TODO: Use CPU implementation due to the precision problem in Safari.
var outVals = linSpaceImplCPU(start, stop, num);
return backend.makeTensorInfo([outVals.length], 'float32', outVals);
}
var linSpaceConfig$1 = {
kernelName: LinSpace,
backendName: 'webgl',
kernelFunc: linSpace$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LOG = "if (x < 0.0) return NAN;\n return log(x);";
var LOG_PACKED = "\n vec4 result = log(x);\n vec4 isNaN = vec4(lessThan(x, vec4(0.0)));\n result.r = isNaN.r == 1.0 ? NAN : result.r;\n result.g = isNaN.g == 1.0 ? NAN : result.g;\n result.b = isNaN.b == 1.0 ? NAN : result.b;\n result.a = isNaN.a == 1.0 ? NAN : result.a;\n\n return result;\n";
var log$c = unaryKernelFunc$1({
opSnippet: LOG,
packedOpSnippet: LOG_PACKED,
cpuKernelImpl: logImplCPU
});
var logConfig$1 = {
kernelName: Log,
backendName: 'webgl',
kernelFunc: log$c
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LOG1P = "return log(1.0 + x);";
var log1p$2 = unaryKernelFunc$1({
opSnippet: LOG1P
});
var log1pConfig$1 = {
kernelName: Log1p,
backendName: 'webgl',
kernelFunc: log1p$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LOGICAL_AND = "return float(a >= 1.0 && b >= 1.0);";
var LOGICAL_AND_PACKED = "\n return vec4(\n vec4(greaterThanEqual(a, vec4(1.0))) *\n vec4(greaterThanEqual(b, vec4(1.0))));\n";
var logicalAnd$2 = binaryKernelFunc$1({
opSnippet: LOGICAL_AND,
packedOpSnippet: LOGICAL_AND_PACKED,
dtype: 'bool'
});
var logicalAndConfig$1 = {
kernelName: LogicalAnd,
backendName: 'webgl',
kernelFunc: logicalAnd$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LOGICAL_NOT = "return float(!(x >= 1.0));";
var logicalNot$2 = unaryKernelFunc$1({
opSnippet: LOGICAL_NOT
});
var logicalNotConfig$1 = {
kernelName: LogicalNot,
backendName: 'webgl',
kernelFunc: logicalNot$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LOGICAL_OR = "return float(a >= 1.0 || b >= 1.0);";
var LOGICAL_OR_PACKED = "\n return min(\n vec4(greaterThanEqual(a, vec4(1.0))) +\n vec4(greaterThanEqual(b, vec4(1.0))),\n vec4(1.0));\n";
var logicalOr$2 = binaryKernelFunc$1({
opSnippet: LOGICAL_OR,
packedOpSnippet: LOGICAL_OR_PACKED,
dtype: 'bool'
});
var logicalOrConfig$1 = {
kernelName: LogicalOr,
backendName: 'webgl',
kernelFunc: logicalOr$2
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LRNProgram = function LRNProgram(xShape, radius, bias, alpha, beta) {
this.variableNames = ['x'];
this.outputShape = [];
var rad = radius;
var maxD = xShape[3] - 1;
this.outputShape = xShape; // optimize pow(bias + alpha * sum, -beta)
// src: https://github.com/tensorflow/tensorflow/..
// blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
// tensorflow/core/kernels/mkl_lrn_op.cc#L320
var powOperator;
var basis = "float(" + bias + ") + float(" + alpha + ") * sum";
if (beta === 0.5) {
powOperator = "inversesqrt(" + basis + ")";
} else if (beta === 1.0) {
powOperator = "1.0/(" + basis + ")";
} else {
powOperator = "exp(log(" + basis + ") * float(-" + beta + "));";
}
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n int d = coords[3];\n float x = getX(b, r, c, d);\n float sum = 0.0;\n for (int j = -" + rad + "; j <= " + rad + "; j++) {\n int idx = d + j;\n if (idx >= 0 && idx <= " + maxD + ") {\n float z = getX(b, r, c, idx);\n sum += z * z;\n }\n }\n float val = x * " + powOperator + ";\n setOutput(val);\n }\n ";
};
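// A minimal JavaScript sketch (not part of TF.js) of the pow() specialisation chosen
// above: pow(basis, -beta) is emitted as inversesqrt(basis) when beta == 0.5, as
// 1.0 / basis when beta == 1.0, and as exp(log(basis) * -beta) otherwise. The name
// `referenceLrnScale` is hypothetical and the function is never called by this bundle.
function referenceLrnScale(basis, beta) {
  if (beta === 0.5) {
    return 1 / Math.sqrt(basis); // inversesqrt(basis)
  } else if (beta === 1.0) {
    return 1 / basis;
  }
  return Math.exp(Math.log(basis) * -beta); // general pow(basis, -beta) for basis > 0
}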
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LRNPackedProgram = function LRNPackedProgram(xShape, radius, bias, alpha, beta) {
this.variableNames = ['x'];
this.outputShape = [];
this.packedInputs = true;
this.packedOutput = true;
var rad = radius;
var maxD = xShape[3] - 1;
this.outputShape = xShape; // optimize pow(bias + alpha * sum, -beta)
// src: https://github.com/tensorflow/tensorflow/..
// blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
// tensorflow/core/kernels/mkl_lrn_op.cc#L320
var powOperator;
var basis = "float(" + bias + ") + float(" + alpha + ") * sum";
if (beta === 0.5) {
powOperator = "inversesqrt(" + basis + ")";
} else if (beta === 1.0) {
powOperator = "1.0/(" + basis + ")";
} else {
powOperator = "exp(log(" + basis + ") * float(-" + beta + "));";
}
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords.x;\n int r = coords.y;\n int c = coords.z;\n int d = coords.w;\n\n bool hasNextCol = d < " + this.outputShape[3] + ";\n bool hasNextRow = c < " + this.outputShape[2] + ";\n\n vec4 sum = vec4(0.);\n vec4 xFragAtOutputCoords = getX(b, r, c, d);\n\n vec4 xAtOutputCoords = vec4(\n getChannel(xFragAtOutputCoords, vec2(c, d)),\n hasNextCol ?\n getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n hasNextRow ?\n getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n (hasNextRow && hasNextCol) ?\n getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n );\n\n int firstChannel = d - " + rad + ";\n vec2 cache = vec2(0.);\n if(firstChannel >= 0){\n vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n if(hasNextRow){\n cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n }\n }\n\n ivec2 depth = ivec2(d, d + 1);\n for (int j = - " + rad + "; j <= " + rad + "; j++) {\n ivec2 idx = depth + j;\n bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));\n bvec2 belowUpperBound = lessThanEqual(idx, ivec2(" + maxD + "));\n\n bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n\n if(depthInRange || depthPlusOneInRange){\n vec4 z = vec4(0.);\n vec4 xFragAtCurrentDepth;\n z.xz = cache.xy;\n if(depthPlusOneInRange && hasNextCol){\n xFragAtCurrentDepth = idx.y != d ?\n getX(b, r, c, idx.y) : xFragAtOutputCoords;\n z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n if(hasNextRow){\n z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n }\n }\n cache.xy = z.yw;\n sum += z * z;\n }\n }\n vec4 result = xAtOutputCoords * " + powOperator + ";\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var lrn = function lrn(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var depthRadius = attrs.depthRadius,
bias = attrs.bias,
alpha = attrs.alpha,
beta = attrs.beta;
var program = env().getBool('WEBGL_PACK_NORMALIZATION') ? new LRNPackedProgram(x.shape, depthRadius, bias, alpha, beta) : new LRNProgram(x.shape, depthRadius, bias, alpha, beta);
return backend.runWebGLProgram(program, [x], x.dtype);
}; // tslint:disable-next-line: variable-name
var LRNConfig = {
kernelName: LRN,
backendName: 'webgl',
kernelFunc: lrn
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var LRNGradProgram = function LRNGradProgram(inputShape, depthRadius, bias, alpha, beta) {
this.variableNames = ['inputImage', 'outputImage', 'dy'];
this.outputShape = [];
this.outputShape = inputShape;
this.depth = inputShape[3];
this.depthRadius = depthRadius;
this.bias = bias;
this.alpha = alpha;
this.beta = beta;
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n\n float result = 0.0;\n for (int d = 0; d < " + this.depth + "; ++d) {\n int depthBegin = int(max(0.0, float(d - " + depthRadius + ")));\n int depthEnd = int(min(float(" + this.depth + "),\n float(d + " + depthRadius + " + 1)));\n\n const int MIN_DEPTH_BEGIN = 0;\n const int MAX_DEPTH_END = " + this.depth + ";\n\n float norm = 0.0;\n for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd) {\n norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);\n }\n else {\n break;\n }\n }\n\n norm = float(" + alpha + ") * norm + float(" + bias + ");\n\n for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd){\n float dyi = -2.0 * float(" + alpha + ")\n * float(" + beta + ")\n * getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)\n / norm;\n if (k == d) {\n dyi += pow(norm, -1.0 * " + beta + ");\n }\n if (k == coords[3]) {\n dyi *= getDy(b, r, c, d);\n result += dyi;\n }\n }\n else {\n break;\n }\n }\n }\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var lrnGrad = function lrnGrad(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
y = inputs.y,
dy = inputs.dy;
var depthRadius = attrs.depthRadius,
bias = attrs.bias,
alpha = attrs.alpha,
beta = attrs.beta;
var program = new LRNGradProgram(x.shape, depthRadius, bias, alpha, beta);
return backend.runWebGLProgram(program, [x, y, dy], x.dtype);
}; // tslint:disable-next-line: variable-name
var LRNGradConfig = {
kernelName: LRNGrad,
backendName: 'webgl',
kernelFunc: lrnGrad
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
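// Shared helper for the Max kernel: reshapes x to [batchSize, reduceSize], runs the
// 'max' reduction program, then reshapes the reduced result to the requested output
// shape, disposing the intermediates.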
function maxImpl$1(x, reduceShape, outShape, backend) {
var inSize = sizeFromShape(reduceShape);
var xSize = sizeFromShape(x.shape);
var batchSize = xSize / inSize;
var reshapedInput = reshape$3({
inputs: {
x: x
},
attrs: {
shape: [batchSize, inSize]
},
backend: backend
});
var reduced = reduce(reshapedInput, x.dtype, 'max', backend);
var reshapedOutput = reshape$3({
inputs: {
x: reduced
},
attrs: {
shape: outShape
},
backend: backend
});
backend.disposeIntermediateTensorInfo(reshapedInput);
backend.disposeIntermediateTensorInfo(reduced);
return reshapedOutput;
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
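// Max kernel. When the reduction axes are not innermost the input is transposed
// first (on the CPU if its values already live there); the reduction itself then
// runs on the CPU or GPU accordingly, keeping reduced dimensions as size-1 axes
// when keepDims is set.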
function max$8(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var reductionIndices = attrs.reductionIndices,
keepDims = attrs.keepDims;
var xRank = x.shape.length;
var origAxes = parseAxisParam(reductionIndices, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, xRank);
var maxInputIsTransposed = permutedAxes != null;
var shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
var maxInput = x;
if (maxInputIsTransposed) {
if (shouldExecuteOnCPU) {
var xTexData = backend.texData.get(maxInput.dataId);
var values = xTexData.values;
var newShape = new Array(xRank);
for (var i = 0; i < newShape.length; i++) {
newShape[i] = x.shape[permutedAxes[i]];
}
var maxInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
maxInput = backend.makeTensorInfo(newShape, x.dtype);
var maxInputData = backend.texData.get(maxInput.dataId);
maxInputData.values = maxInputValues;
} else {
maxInput = transposeImpl$1(x, permutedAxes, backend);
}
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('max', axes, xRank);
var _backend_util$compute = computeOutAndReduceShapes(maxInput.shape, axes),
maxOutShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var outShape = maxOutShape;
if (keepDims) {
// rather than reshape at the end, set the target shape here.
outShape = expandShapeToKeepDim(maxOutShape, origAxes);
}
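  // e.g. reducing a [2, 3, 4] input over axis 1 with keepDims set produces
  // outShape [2, 1, 4] here instead of reshaping a [2, 4] result afterwards.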
var out;
if (shouldExecuteOnCPU) {
var _xTexData = backend.texData.get(maxInput.dataId);
var _values = _xTexData.values;
var outValues = maxImplCPU(_values, sizeFromShape(reduceShape), outShape, x.dtype);
out = backend.makeTensorInfo(outShape, x.dtype);
var outData = backend.texData.get(out.dataId);
outData.values = outValues;
} else {
out = maxImpl$1(maxInput, reduceShape, outShape, backend);
}
if (maxInputIsTransposed) {
backend.disposeIntermediateTensorInfo(maxInput);
}
return out;
}
var maxConfig$1 = {
kernelName: Max,
backendName: 'webgl',
kernelFunc: max$8
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MAXIMUM = CHECK_NAN_SNIPPET$1 + "\n return max(a, b);\n";
var MAXIMUM_PACKED = "\n vec4 result = vec4(max(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + CHECK_NAN_SNIPPET$2 + "\n return result;\n";
var maximum$4 = binaryKernelFunc$1({
opSnippet: MAXIMUM,
packedOpSnippet: MAXIMUM_PACKED,
cpuKernelImpl: maximumImplCPU
});
var maximumConfig$1 = {
kernelName: Maximum,
backendName: 'webgl',
kernelFunc: maximum$4
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
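// MaxPool kernel. A 1x1 window over an unchanged spatial shape reduces to an
// identity copy; otherwise the Pool2DProgram computes the windowed maximum.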
function maxPool$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
assertNotComplex$1(x, 'maxPool');
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
var dilations = 1;
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in maxPool: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape)) {
return identity$2({
inputs: {
x: x
},
backend: backend
});
}
var maxPoolProgram = new Pool2DProgram(convInfo, 'max', false);
return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype);
}
var maxPoolConfig$1 = {
kernelName: MaxPool,
backendName: 'webgl',
kernelFunc: maxPool$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxPool3d$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dataFormat = attrs.dataFormat,
dimRoundingMode = attrs.dimRoundingMode;
var dilations = [1, 1, 1];
var convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
var maxPoolProgram = new Pool3DProgram(convInfo, 'max', false);
return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype);
}
var maxPool3DConfig$1 = {
kernelName: MaxPool3D,
backendName: 'webgl',
kernelFunc: maxPool3d$1
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MaxPool2DBackpropProgram = function MaxPool2DBackpropProgram(convInfo) {
this.variableNames = ['dy', 'maxPos'];
this.outputShape = convInfo.inShape;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationHeight = convInfo.dilationHeight;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n int maxPosValue = " + lastIndex + " - int(getMaxPos(b, idyR, idyC, d));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue = wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n setOutput(dotProd);\n }\n ";
};
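// How the position mask above routes gradients: the maxPos texture holds, for each
// pooled output, the flattened offset of the maximum inside its pooling window.
// The shader walks the window from the gradient side, so it compares against
// lastIndex - maxPos; dy is accumulated only where curPosValue (wR * width + wC)
// matches, i.e. each dy value flows back to exactly one input position.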
var MaxPool3DBackpropProgram = function MaxPool3DBackpropProgram(convInfo) {
this.variableNames = ['dy', 'maxPos'];
this.outputShape = convInfo.inShape;
var strideDepth = convInfo.strideDepth;
var strideHeight = convInfo.strideHeight;
var strideWidth = convInfo.strideWidth;
var dilationDepth = convInfo.dilationDepth;
var dilationHeight = convInfo.dilationHeight;
var dilationWidth = convInfo.dilationWidth;
var effectiveFilterDepth = convInfo.effectiveFilterDepth;
var effectiveFilterHeight = convInfo.effectiveFilterHeight;
var effectiveFilterWidth = convInfo.effectiveFilterWidth;
var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
var lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;
this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n int maxPosValue = " + lastIndex + " -\n int(getMaxPos(batch, idyD, idyR, idyC, ch));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue =\n wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +\n wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxPool3DGrad$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
input = inputs.input;
var x = input;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
var dilations = [1, 1, 1];
var convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
  var maxPool3dPositionsProgram = new Pool3DProgram(convInfo, 'max', true /* get positions */);
var maxPool3dPositions = backend.runWebGLProgram(maxPool3dPositionsProgram, [x], x.dtype);
var maxPoolBackpropProgram = new MaxPool3DBackpropProgram(convInfo);
var result = backend.runWebGLProgram(maxPoolBackpropProgram, [dy, maxPool3dPositions], x.dtype);
backend.disposeIntermediateTensorInfo(maxPool3dPositions);
return result;
}
var maxPoolGrad3DConfig = {
kernelName: MaxPool3DGrad,
backendName: 'webgl',
kernelFunc: maxPool3DGrad$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxPoolGrad$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var dy = inputs.dy,
input = inputs.input,
output = inputs.output;
var x = input;
assertNotComplex$1([input, output], 'maxPoolGrad');
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
dimRoundingMode = attrs.dimRoundingMode;
  var convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
var getPositions = true;
var maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions);
var maxPoolPositions = backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype);
var maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo);
var result = backend.runWebGLProgram(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype);
backend.disposeIntermediateTensorInfo(maxPoolPositions);
return result;
}
var maxPoolGradConfig$2 = {
kernelName: MaxPoolGrad,
backendName: 'webgl',
kernelFunc: maxPoolGrad$2
};
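// maxPoolGrad runs two passes: Pool2DProgram with getPositions = true recomputes
// the argmax offsets from the original input x, and MaxPool2DBackpropProgram then
// scatters dy through those offsets; the intermediate positions tensor is disposed
// once the gradient has been produced.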
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, backend) {
var program = new Pool2DProgram(convInfo, 'max', false);
var poolOutput = backend.runWebGLProgram(program, [x], 'float32');
program = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex);
var indexOutput = backend.runWebGLProgram(program, [x], 'float32');
return [poolOutput, indexOutput];
}
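// maxPoolWithArgmaxImpl$1 runs the same Pool2DProgram twice over x: the first pass
// produces the pooled values, and the second pass (judging by the extra boolean
// arguments) requests flattened argmax positions, optionally folding the batch
// index into the returned value when includeBatchInIndex is set.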
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var maxPoolWithArgmaxConfig$1 = {
kernelName: MaxPoolWithArgmax,
backendName: 'webgl',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
attrs = _ref.attrs,
backend = _ref.backend;
var x = inputs.x;
var filterSize = attrs.filterSize,
strides = attrs.strides,
pad = attrs.pad,
includeBatchInIndex = attrs.includeBatchInIndex;
var webglBackend = backend;
assert(x.shape.length === 4, function () {
return "Error in maxPool: input must be rank 4 but got rank " + x.shape.length + ".";
});
var dilations = [1, 1];
assert(eitherStridesOrDilationsAreOne(strides, dilations), function () {
return 'Error in maxPool: Either strides or dilations must be 1. ' + ("Got strides " + strides + " and dilations '" + dilations + "'");
});
var convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad);
var _maxPoolWithArgmaxImp = maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, webglBackend),
result = _maxPoolWithArgmaxImp[0],
indexes = _maxPoolWithArgmaxImp[1];
return [result, indexes];
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function meanImpl(x, reduceShape, outShape, backend) {
var inSize = sizeFromShape(reduceShape);
var xSize = sizeFromShape(x.shape);
var batchSize = xSize / inSize;
var reshapedInput = reshape$3({
inputs: {
x: x
},
attrs: {
shape: [batchSize, inSize]
},
backend: backend
});
var reduced = reduce(reshapedInput, 'float32', 'mean', backend);
var reshapedOutput = reshape$3({
inputs: {
x: reduced
},
attrs: {
shape: outShape
},
backend: backend
});
backend.disposeIntermediateTensorInfo(reshapedInput);
backend.disposeIntermediateTensorInfo(reduced);
return reshapedOutput;
}
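// Worked shape example for meanImpl: with x of shape [2, 3, 4] reduced over its
// trailing axes, reduceShape = [3, 4], so inSize = 12, xSize = 24 and
// batchSize = 2. x is reshaped to [2, 12], reduced along the inner dimension with
// the 'mean' reducer, and the result is reshaped to the requested outShape.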
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var meanConfig$1 = {
kernelName: Mean,
backendName: 'webgl',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
attrs = _ref.attrs,
backend = _ref.backend;
var x = inputs.x;
var keepDims = attrs.keepDims,
axis = attrs.axis;
var webglBackend = backend;
var xRank = x.shape.length;
var origAxes = parseAxisParam(axis, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, xRank);
var meanInputIsTransposed = permutedAxes != null;
var shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
var intermediates = [];
var meanInput = x;
if (meanInputIsTransposed) {
if (shouldExecuteOnCPU) {
var xTexData = webglBackend.texData.get(meanInput.dataId);
var values = xTexData.values;
var newShape = new Array(xRank);
for (var i = 0; i < newShape.length; i++) {
newShape[i] = x.shape[permutedAxes[i]];
}
var meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
meanInput = webglBackend.makeTensorInfo(newShape, x.dtype);
var meanInputData = webglBackend.texData.get(meanInput.dataId);
meanInputData.values = meanInputValues;
} else {
meanInput = transposeImpl$1(x, permutedAxes, webglBackend);
}
intermediates.push(meanInput);
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('sum', axes, xRank);
var _backend_util$compute = computeOutAndReduceShapes(meanInput.shape, axes),
meanOutShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var outShape = meanOutShape;
if (keepDims) {
// rather than reshape at the end, set the target shape here.
outShape = expandShapeToKeepDim(meanOutShape, origAxes);
}
var out = meanImpl(meanInput, reduceShape, outShape, webglBackend);
for (var _i = 0, _intermediates = intermediates; _i < _intermediates.length; _i++) {
var _i2 = _intermediates[_i];
webglBackend.disposeIntermediateTensorInfo(_i2);
}
return out;
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function min$c(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
var xRank = x.shape.length;
var origAxes = parseAxisParam(axis, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, xRank);
var permutedX = x;
if (permutedAxes != null) {
permutedX = transpose$2({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
axes = getInnerMostAxes(axes.length, x.shape.length);
}
assertAxesAreInnerMostDims('min', axes, xRank);
var _backend_util$compute = computeOutAndReduceShapes(permutedX.shape, axes),
outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var inSize = sizeFromShape(reduceShape);
var a2D = reshape$3({
inputs: {
x: permutedX
},
backend: backend,
attrs: {
shape: [-1, inSize]
}
});
var reduced = reduce(a2D, a2D.dtype, 'min', backend);
var res;
if (keepDims) {
var newShape = expandShapeToKeepDim(outShape, origAxes);
res = reshape$3({
inputs: {
x: reduced
},
backend: backend,
attrs: {
shape: newShape
}
});
} else {
res = reshape$3({
inputs: {
x: reduced
},
backend: backend,
attrs: {
shape: outShape
}
});
}
backend.disposeIntermediateTensorInfo(a2D);
backend.disposeIntermediateTensorInfo(reduced);
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo(permutedX);
}
return res;
}
var minConfig$1 = {
kernelName: Min,
backendName: 'webgl',
kernelFunc: min$c
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MINIMUM = CHECK_NAN_SNIPPET$1 + "\n return min(a, b);\n";
var MINIMUM_PACKED = "\n vec4 result = vec4(min(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + CHECK_NAN_SNIPPET$2 + "\n return result;\n";
var minimum$4 = binaryKernelFunc$1({
opSnippet: MINIMUM,
packedOpSnippet: MINIMUM_PACKED,
cpuKernelImpl: minimumImplCPU
});
var minimumConfig$1 = {
kernelName: Minimum,
backendName: 'webgl',
kernelFunc: minimum$4
};
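// NaN handling in MINIMUM_PACKED above: isNaN is a per-lane 0/1 mask that is 1
// whenever either operand is NaN (min(isnan(a) + isnan(b), 1.0)); the shared
// CHECK_NAN_SNIPPET$2 then overwrites those lanes of the result with NaN, so the
// packed min matches the scalar CHECK_NAN_SNIPPET$1 behaviour.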
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MirrorPadProgram = function MirrorPadProgram(xShape, paddings, mode) {
this.variableNames = ['x'];
  this.outputShape = paddings.map(function (p, i) {
    return p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */;
  });
var rank = xShape.length;
var dtype = getCoordsDataType(rank);
var start = paddings.map(function (p) {
return p[0];
}).join(',');
var end = paddings.map(function (p, i) {
return p[0] + xShape[i];
}).join(',');
var unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
var offset = mode === 'reflect' ? 0 : 1;
if (rank === 1) {
this.userCode = "\n int start = " + start + ";\n int end = " + end + ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start) {\n outC = start * 2 - outC - " + offset + ";\n } else if(outC >= end) {\n outC = (end - 1) * 2 - outC + " + offset + ";\n }\n setOutput(getX(outC - start));\n }\n ";
return;
}
this.userCode = "\n " + dtype + " start = " + dtype + "(" + start + ");\n " + dtype + " end = " + dtype + "(" + end + ");\n\n void main() {\n " + dtype + " outC = getOutputCoords();\n for (int i = 0; i < " + rank + "; i++) {\n if (outC[i] < start[i]) {\n outC[i] = start[i] * 2 - outC[i] - " + offset + ";\n } else if(outC[i] >= end[i]) {\n outC[i] = (end[i] - 1) * 2 - outC[i] + " + offset + ";\n }\n }\n " + dtype + " coords = outC - start;\n setOutput(getX(" + unpackedCoords + "));\n }\n ";
};
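// Worked example for the offset above: mirror-padding tf.tensor1d([1, 2, 3]) with
// paddings [[2, 2]] gives start = 2 and end = 5. With mode 'reflect' (offset 0)
// the mirrored indices yield [3, 2, 1, 2, 3, 2, 1]; with 'symmetric' (offset 1)
// the edge values are repeated, yielding [2, 1, 1, 2, 3, 3, 2].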
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Example shader code for
* `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`
* ```
* const int start = int(2);
* const int end = int(5);
*
* void main() {
* int outputLoc = getOutputCoords();
* vec4 result = vec4(0.);
*
* int rc = outputLoc;
*
* int source = rc;
* if (source < start) {
* source = start * 2 - source - 0;
* } else if (source >= end) {
* source = (end - 1) * 2 - source + 0;
* }
* source -= start;
*
* result[0] = getChannel(getX(source), source);
* rc += 1;
* if(rc < 6) {
* int source = rc;
* if (source < start) {
* source = start * 2 - source - 0;
* } else if (source >= end) {
* source = (end - 1) * 2 - source + 0;
* }
* source -= start;
*
* result[1] = getChannel(getX(source), source);
* }
*
* setOutput(result);
* }
* ```
*/
var MirrorPadPackedProgram = function MirrorPadPackedProgram(xShape, paddings, mode) {
this.variableNames = ['x'];
this.packedInputs = true;
this.packedOutput = true;
  this.outputShape = paddings.map(function (p, i) {
    return p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */;
  });
var rank = xShape.length;
var dtype = getCoordsDataType(rank);
var start = paddings.map(function (p) {
return p[0];
}).join(',');
var end = paddings.map(function (p, i) {
return p[0] + xShape[i];
}).join(',');
var coords = getChannels('rc', rank);
var source = getChannels('source', rank);
var cLimit = coords[rank - 1] + " < " + this.outputShape[rank - 1];
var innerDims = rank === 1 ? 'source' : "vec2(" + source.slice(-2).join() + ")";
var offset = mode === 'reflect' ? 0 : 1;
var mainLoop = '';
if (rank === 1) {
var padSetup = "\n " + dtype + " source = rc;\n if (source < start) {\n source = start * 2 - source - " + offset + ";\n } else if (source >= end) {\n source = (end - 1) * 2 - source + " + offset + ";\n }\n source -= start;\n ";
mainLoop = "\n " + dtype + " rc = outputLoc;\n " + padSetup + "\n result[0] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n " + coords[rank - 1] + " += 1;\n if(" + cLimit + ") {\n " + padSetup + "\n result[1] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n ";
} else {
var _padSetup = "\n " + dtype + " source = rc;\n " + dtype + " lt = " + dtype + "(lessThan(source, start));\n " + dtype + " gte = " + dtype + "(greaterThanEqual(source, end));\n " + dtype + " orig = 1 - (lt + gte);\n source = orig * source +\n lt * (start * 2 - source - " + offset + ") +\n gte * ((end - 1) * 2 - source + " + offset + ");\n source -= start;\n ";
mainLoop = "\n " + dtype + " rc = outputLoc;\n " + _padSetup + "\n result[0] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n " + coords[rank - 1] + " += 1;\n if(" + cLimit + ") {\n " + _padSetup + "\n result[1] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n rc = outputLoc;\n " + coords[rank - 2] + " += 1;\n if(" + coords[rank - 2] + " < " + this.outputShape[rank - 2] + ") {\n " + _padSetup + "\n result[2] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n " + coords[rank - 1] + " += 1;\n if(" + cLimit + ") {\n " + _padSetup + "\n result[3] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n }\n ";
}
this.userCode = "\n const " + dtype + " start = " + dtype + "(" + start + ");\n const " + dtype + " end = " + dtype + "(" + end + ");\n\n void main() {\n " + dtype + " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n " + mainLoop + "\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var mirrorPadKernelFunc = function mirrorPadKernelFunc(_ref) {
var inputs = _ref.inputs,
backend = _ref.backend,
attrs = _ref.attrs;
var x = inputs.x;
var paddings = attrs.paddings,
mode = attrs.mode;
var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new MirrorPadPackedProgram(x.shape, paddings, mode) : new MirrorPadProgram(x.shape, paddings, mode);
var output = backend.runWebGLProgram(program, [x], x.dtype);
return output;
};
var mirrorPadConfig$1 = {
kernelName: MirrorPad,
backendName: 'webgl',
kernelFunc: mirrorPadKernelFunc
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MOD = "if (b == 0.0) return NAN;\n return mod(a, b);";
var MOD_PACKED = "\n vec4 result = mod(a, b);\n vec4 isNaN = vec4(equal(b, vec4(0.0)));\n " + CHECK_NAN_SNIPPET$2 + "\n return result;\n";
var mod$2 = binaryKernelFunc$1({
opSnippet: MOD,
packedOpSnippet: MOD_PACKED
});
var modConfig$1 = {
kernelName: Mod,
backendName: 'webgl',
kernelFunc: mod$2
};
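// Note on MOD above: GLSL mod(a, b) = a - b * floor(a / b), i.e. floored modulo,
// so mod(5.0, 3.0) = 2.0 and mod(-5.0, 3.0) = 1.0; the explicit b == 0.0 check
// returns NaN because the GLSL result would otherwise be undefined.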
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var MultinomialProgram = function MultinomialProgram(batchSize, numOutcomes, numSamples) {
this.variableNames = ['probs'];
this.customUniforms = [{
name: 'seed',
type: 'float'
}];
this.outputShape = [batchSize, numSamples];
this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n\n float r = random(seed);\n float cdf = 0.0;\n\n for (int i = 0; i < " + (numOutcomes - 1) + "; i++) {\n cdf += getProbs(batch, i);\n\n if (r < cdf) {\n setOutput(float(i));\n return;\n }\n }\n\n // If no other event happened, last event happened.\n setOutput(float(" + (numOutcomes - 1) + "));\n }\n ";
};
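// Worked example for the sampling loop above: with probs = [0.2, 0.5, 0.3] for a
// batch row and a uniform draw r = 0.6, the running cdf is 0.2 after outcome 0 and
// 0.7 after outcome 1, so the first index with r < cdf is 1 and the shader writes
// 1.0; if r never drops below the partial sums, the last outcome is emitted.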
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
// Without the equality check, div produces 0.9999 for a == b, which when floored
// can cause errors.
var DIV = "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;";
// We do the same as in ./binaryop_gpu, with vec4 and ivec4.
// On Linux, the vectorized implementation produces NaNs when a and b are 0.
var DIV_PACKED = "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n";
var realDiv = binaryKernelFunc$1({
opSnippet: DIV,
packedOpSnippet: DIV_PACKED,
checkOutOfBounds: true
});
var realDivConfig$1 = {
kernelName: RealDiv,
backendName: 'webgl',
kernelFunc: realDiv
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SUB = 'return a - b;';
var sub$2 = binaryKernelFunc$1({
opSnippet: SUB,
packedOpSnippet: SUB,
supportsComplex: true,
cpuKernelImpl: subImplCPU
});
var subConfig$1 = {
kernelName: Sub,
backendName: 'webgl',
kernelFunc: sub$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function softmax$3(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var logits = inputs.logits;
var dim = attrs.dim;
var axes = parseAxisParam([dim], logits.shape);
var maxLogit = max$8({
inputs: {
x: logits
},
backend: backend,
attrs: {
reductionIndices: axes,
keepDims: false
}
});
var expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
var maxLogitsReshaped = reshape$3({
inputs: {
x: maxLogit
},
backend: backend,
attrs: {
shape: expandedShape
}
});
var a = sub$2({
inputs: {
a: logits,
b: maxLogitsReshaped
},
backend: backend
});
var b = exp$5({
inputs: {
x: a
},
backend: backend
});
var sumExp = sum$4({
inputs: {
x: b
},
backend: backend,
attrs: {
axis: axes,
keepDims: false
}
});
var sumExpReshaped = reshape$3({
inputs: {
x: sumExp
},
backend: backend,
attrs: {
shape: expandedShape
}
});
var res = realDiv({
inputs: {
a: b,
b: sumExpReshaped
},
backend: backend
});
backend.disposeIntermediateTensorInfo(maxLogit);
backend.disposeIntermediateTensorInfo(maxLogitsReshaped);
backend.disposeIntermediateTensorInfo(a);
backend.disposeIntermediateTensorInfo(b);
backend.disposeIntermediateTensorInfo(sumExp);
backend.disposeIntermediateTensorInfo(sumExpReshaped);
return res;
}
var softmaxConfig$1 = {
kernelName: Softmax,
backendName: 'webgl',
kernelFunc: softmax$3
};
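// softmax$3 composes the numerically stable formulation
//   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
// along dim: max -> reshape -> subtract -> exp -> sum -> reshape -> divide, with
// every intermediate tensor disposed once the final quotient is produced.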
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function multinomial$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var logits = inputs.logits;
var numSamples = attrs.numSamples,
seed = attrs.seed,
normalized = attrs.normalized;
var probs = normalized ? logits : softmax$3({
inputs: {
logits: logits
},
backend: backend,
attrs: {
dim: logits.shape.length - 1
}
});
var batchSize = probs.shape[0];
var numOutcomes = probs.shape[1];
var program = new MultinomialProgram(batchSize, numOutcomes, numSamples);
var customValues = [[seed]];
var res = backend.runWebGLProgram(program, [probs], 'int32', customValues);
if (!normalized) {
backend.disposeIntermediateTensorInfo(probs);
}
return res;
}
var multinomialConfig$1 = {
kernelName: Multinomial,
backendName: 'webgl',
kernelFunc: multinomial$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var NEG = "return -x;";
// This doesn't use unaryKernelFunc because negImplCPU is not of type
// SimpleUnaryKernelImplCPU.
function neg$2(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
if (backend.shouldExecuteOnCPU([x])) {
var xData = backend.texData.get(x.dataId);
var _negImplCPU = negImplCPU(xData.values, x.shape, x.dtype),
outValues = _negImplCPU[0],
newShape = _negImplCPU[1];
return backend.makeTensorInfo(newShape, x.dtype, outValues);
}
var program;
if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
program = new UnaryOpPackedProgram(x.shape, NEG);
} else {
program = new UnaryOpProgram(x.shape, NEG);
}
return backend.runWebGLProgram(program, [x], x.dtype);
}
var negConfig$1 = {
kernelName: Neg,
backendName: 'webgl',
kernelFunc: neg$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var nonMaxSuppressionV3Impl$2 = nonMaxSuppressionV3Impl;
function nonMaxSuppressionV3$1(args) {
warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead');
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var boxes = inputs.boxes,
scores = inputs.scores;
var maxOutputSize = attrs.maxOutputSize,
iouThreshold = attrs.iouThreshold,
scoreThreshold = attrs.scoreThreshold;
var boxesVals = backend.readSync(boxes.dataId);
var scoresVals = backend.readSync(scores.dataId);
var _nonMaxSuppressionV3I = nonMaxSuppressionV3Impl$2(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold),
selectedIndices = _nonMaxSuppressionV3I.selectedIndices;
return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
}
var nonMaxSuppressionV3Config$1 = {
kernelName: NonMaxSuppressionV3,
backendName: 'webgl',
kernelFunc: nonMaxSuppressionV3$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var nonMaxSuppressionV4Impl$2 = nonMaxSuppressionV4Impl;
function nonMaxSuppressionV4$1(args) {
warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead');
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var boxes = inputs.boxes,
scores = inputs.scores;
var maxOutputSize = attrs.maxOutputSize,
iouThreshold = attrs.iouThreshold,
scoreThreshold = attrs.scoreThreshold,
padToMaxOutputSize = attrs.padToMaxOutputSize;
var boxesVals = backend.readSync(boxes.dataId);
var scoresVals = backend.readSync(scores.dataId);
var _nonMaxSuppressionV4I = nonMaxSuppressionV4Impl$2(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize),
selectedIndices = _nonMaxSuppressionV4I.selectedIndices,
validOutputs = _nonMaxSuppressionV4I.validOutputs;
return [backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))];
}
var nonMaxSuppressionV4Config$1 = {
kernelName: NonMaxSuppressionV4,
backendName: 'webgl',
kernelFunc: nonMaxSuppressionV4$1
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var nonMaxSuppressionV5Impl$2 = nonMaxSuppressionV5Impl;
function nonMaxSuppressionV5$1(args) {
warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead');
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var boxes = inputs.boxes,
scores = inputs.scores;
var maxOutputSize = attrs.maxOutputSize,
iouThreshold = attrs.iouThreshold,
scoreThreshold = attrs.scoreThreshold,
softNmsSigma = attrs.softNmsSigma;
var boxesVals = backend.readSync(boxes.dataId);
var scoresVals = backend.readSync(scores.dataId);
var maxOutputSizeVal = maxOutputSize;
var iouThresholdVal = iouThreshold;
var scoreThresholdVal = scoreThreshold;
var softNmsSigmaVal = softNmsSigma;
var _nonMaxSuppressionV5I = nonMaxSuppressionV5Impl$2(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal),
selectedIndices = _nonMaxSuppressionV5I.selectedIndices,
selectedScores = _nonMaxSuppressionV5I.selectedScores;
return [backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores))];
}
var nonMaxSuppressionV5Config$1 = {
kernelName: NonMaxSuppressionV5,
backendName: 'webgl',
kernelFunc: nonMaxSuppressionV5$1
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var OneHotProgram = function OneHotProgram(numIndices, depth, onValue, offValue) {
this.variableNames = ['indices'];
this.outputShape = [numIndices, depth];
this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int index = round(getIndices(coords.x));\n setOutput(mix(float(" + offValue + "), float(" + onValue + "),\n float(index == coords.y)));\n }\n ";
};
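// Worked example for OneHotProgram: indices = [1, 0], depth = 3, onValue = 1 and
// offValue = 0 produce the [2, 3] output [[0, 1, 0], [1, 0, 0]]; each row emits
// onValue exactly where the column equals the (rounded) index value.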
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var oneHot$3 = function oneHot(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var indices = inputs.indices;
var depth = attrs.depth,
onValue = attrs.onValue,
offValue = attrs.offValue;
var indicesSize = sizeFromShape(indices.shape);
var program = new OneHotProgram(indicesSize, depth, onValue, offValue);
var reshaped = reshape$3({
inputs: {
x: indices
},
backend: backend,
attrs: {
shape: [indicesSize]
}
});
var result = backend.runWebGLProgram(program, [reshaped], indices.dtype);
backend.disposeIntermediateTensorInfo(reshaped);
var outShape = [].concat(indices.shape, [depth]);
var out = reshape$3({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: outShape
}
});
backend.disposeIntermediateTensorInfo(result);
return out;
};
var oneHotConfig$1 = {
kernelName: OneHot,
backendName: 'webgl',
kernelFunc: oneHot$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function zerosLike$3(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
if (x.dtype === 'complex64') {
var realPart = real$2({
inputs: {
input: x
},
backend: backend
});
var r = zerosLike$3({
inputs: {
x: realPart
},
backend: backend
});
var imagPart = imag$2({
inputs: {
input: x
},
backend: backend
});
var i = zerosLike$3({
inputs: {
x: imagPart
},
backend: backend
});
var result = complex$2({
inputs: {
real: r,
imag: i
},
backend: backend
});
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(r);
backend.disposeIntermediateTensorInfo(imagPart);
backend.disposeIntermediateTensorInfo(i);
return result;
} else {
return fill$2({
attrs: {
shape: x.shape,
dtype: x.dtype,
value: x.dtype === 'string' ? '' : 0
},
backend: backend
});
}
}
var zerosLikeConfig$1 = {
kernelName: ZerosLike,
backendName: 'webgl',
kernelFunc: zerosLike$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function onesLike$3(args) {
var inputs = args.inputs,
backend = args.backend;
var x = inputs.x;
if (x.dtype === 'string') {
throw new Error('onesLike is not supported under string dtype');
} else if (x.dtype === 'complex64') {
var realPart = real$2({
inputs: {
input: x
},
backend: backend
});
var r = onesLike$3({
inputs: {
x: realPart
},
backend: backend
});
var imagPart = imag$2({
inputs: {
input: x
},
backend: backend
});
var i = zerosLike$3({
inputs: {
x: imagPart
},
backend: backend
});
var result = complex$2({
inputs: {
real: r,
imag: i
},
backend: backend
});
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(r);
backend.disposeIntermediateTensorInfo(imagPart);
backend.disposeIntermediateTensorInfo(i);
return result;
} else {
// TODO(cais, smilkov): Add WebGL shader for onesLike:
// https://github.com/tensorflow/tfjs/issues/1293
return fill$2({
attrs: {
shape: x.shape,
dtype: x.dtype,
value: 1
},
backend: backend
});
}
}
var onesLikeConfig$1 = {
kernelName: OnesLike,
backendName: 'webgl',
kernelFunc: onesLike$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function pack$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var axis = attrs.axis;
if (inputs.length === 1) {
return expandDims$3({
inputs: {
input: inputs[0]
},
backend: backend,
attrs: {
dim: axis
}
});
}
var shape = inputs[0].shape;
var dtype = inputs[0].dtype;
inputs.forEach(function (t) {
assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
assert(dtype === t.dtype, function () {
return 'All tensors passed to stack must have matching dtypes';
});
});
var intermediateTensorInfos = [];
var expandedTensors = inputs.map(function (t) {
var expandedT = expandDims$3({
inputs: {
input: t
},
backend: backend,
attrs: {
dim: axis
}
});
intermediateTensorInfos.push(expandedT);
return expandedT;
});
var result = concat$2({
inputs: expandedTensors,
backend: backend,
attrs: {
axis: axis
}
});
intermediateTensorInfos.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return result;
}
var packConfig$1 = {
kernelName: Pack,
backendName: 'webgl',
kernelFunc: pack$2
};
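// Worked example for pack: stacking two [2, 3] tensors along axis 0 expands each
// to [1, 2, 3] and concatenates them into a [2, 2, 3] result; with a single input
// the kernel short-circuits to a plain expandDims along `axis`.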
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PadProgram = function PadProgram(xShape, paddings, constantValue) {
this.variableNames = ['x'];
this.customUniforms = [{
name: 'value',
type: 'float'
}];
  this.outputShape = paddings.map(function (p, i) {
    return p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */;
  });
var rank = xShape.length;
var type = getCoordsDataType(rank);
var start = paddings.map(function (p) {
return p[0];
}).join(',');
var end = paddings.map(function (p, i) {
return p[0] + xShape[i];
}).join(',');
var unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
if (rank === 1) {
this.userCode = "\n int start = " + start + ";\n int end = " + end + ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start || outC >= end) {\n setOutput(value);\n } else {\n setOutput(getX(outC - start));\n }\n }\n ";
return;
}
this.userCode = "\n " + type + " start = " + type + "(" + start + ");\n " + type + " end = " + type + "(" + end + ");\n\n void main() {\n " + type + " outC = getOutputCoords();\n if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {\n setOutput(value);\n } else {\n " + type + " coords = outC - start;\n setOutput(getX(" + unpackedCoords + "));\n }\n }\n ";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var PadPackedProgram = function PadPackedProgram(xShape, paddings, constantValue) {
this.variableNames = ['x'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [{
name: 'value',
type: 'float'
}];
  this.outputShape = paddings.map(function (p, i) {
    return p[0] /* beforePad */ + xShape[i] + p[1] /* afterPad */;
  });
var rank = xShape.length;
var dtype = getCoordsDataType(rank);
var start = paddings.map(function (p) {
return p[0];
}).join(',');
var end = paddings.map(function (p, i) {
return p[0] + xShape[i];
}).join(',');
var coords = getChannels('rc', rank);
var source = getChannels('source', rank);
var cLimit = coords[rank - 1] + " < " + this.outputShape[rank - 1];
var innerDims = rank === 1 ? 'source' : "vec2(" + source.slice(-2).join() + ")";
var componentSetup = [dtype + " rc = outputLoc;", coords[rank - 1] + " += 1;\n if(" + cLimit + ") {\n ", rank === 1 ? '' : "}\n rc = outputLoc;\n " + coords[rank - 2] + " += 1;\n if(" + coords[rank - 2] + " < " + this.outputShape[rank - 2] + ") {", rank === 1 ? '' : " " + coords[rank - 1] + " += 1;\n if(" + cLimit + ") {"];
var paddingArea = rank === 1 ? 'rc < start || rc >= end' : 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
var mainLoop = '';
for (var i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
mainLoop += "\n " + componentSetup[i] + "\n if (" + paddingArea + ") {\n result[" + i + "] = float(value);\n } else {\n " + dtype + " source = rc - start;\n result[" + i + "] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n ";
}
mainLoop += rank === 1 ? "} " : "}}";
this.userCode = "\n const " + dtype + " start = " + dtype + "(" + start + ");\n const " + dtype + " end = " + dtype + "(" + end + ");\n\n void main() {\n " + dtype + " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n " + mainLoop + "\n setOutput(result);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var padV2$1 = function padV2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var paddings = attrs.paddings,
constantValue = attrs.constantValue;
if (sizeFromShape(x.shape) === 0) {
    // Short-circuit the computation: x has no values, so only its shape is needed
    // to determine the padded output shape.
    var outputShape = paddings.map(function (p, i) {
      return p[0] /* beforePad */ + x.shape[i] + p[1] /* afterPad */;
    });
return fill$2({
backend: backend,
attrs: {
shape: outputShape,
value: constantValue,
dtype: x.dtype
}
});
}
var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new PadPackedProgram(x.shape, paddings, constantValue) : new PadProgram(x.shape, paddings, constantValue);
var customValues = [[constantValue]];
return backend.runWebGLProgram(program, [x], x.dtype, customValues);
};
var padV2Config$1 = {
kernelName: PadV2,
backendName: 'webgl',
kernelFunc: padV2$1
};
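// Note on the zero-size shortcut in padV2 above: when x has no elements (for
// example shape [0, 3] with paddings [[1, 1], [2, 2]]), only the shape matters,
// so a [2, 7] output is filled with constantValue instead of running a shader.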
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var POW = "\n if(a < 0.0 && floor(b) < b){\n return NAN;\n }\n if (b == 0.0) {\n return 1.0;\n }\n return (round(mod(b, 2.0)) != 1) ?\n pow(abs(a), b) : sign(a) * pow(abs(a), b);\n";
var POW_PACKED = "\n // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.\n vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));\n vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);\n vec4 result = multiplier * pow(abs(a), b);\n\n // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS\n bvec4 isExpZero = equal(b, vec4(0.0));\n result.r = isExpZero.r ? 1.0 : result.r;\n result.g = isExpZero.g ? 1.0 : result.g;\n result.b = isExpZero.b ? 1.0 : result.b;\n result.a = isExpZero.a ? 1.0 : result.a;\n\n vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));\n " + CHECK_NAN_SNIPPET$2 + "\n return result;\n";
var pow$8 = binaryKernelFunc$1({
opSnippet: POW,
packedOpSnippet: POW_PACKED
});
var powConfig$1 = {
kernelName: Pow,
backendName: 'webgl',
kernelFunc: pow$8
};
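// Semantics of POW above: a negative base with a non-integer exponent yields NaN,
// b == 0.0 always yields 1.0 (including 0^0 = 1, matching TF and JS), and for odd
// integer exponents the sign of the base is reapplied, e.g.
// pow(-2.0, 3.0) = sign(-2.0) * pow(2.0, 3.0) = -8.0.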
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function prod$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var axis = attrs.axis,
keepDims = attrs.keepDims;
var xRank = x.shape.length;
var toDispose = [];
var origAxes = parseAxisParam(axis, x.shape);
var axes = origAxes;
var permutedAxes = getAxesPermutation(axes, xRank);
var permutedX = x;
if (permutedAxes != null) {
permutedX = transpose$2({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutedAxes
}
});
axes = getInnerMostAxes(axes.length, xRank);
toDispose.push(permutedX);
}
assertAxesAreInnerMostDims('prod', axes, xRank);
var res;
if (backend.shouldExecuteOnCPU([permutedX])) {
var xVals = backend.texData.get(permutedX.dataId).values;
var _prodImplCPU = prodImplCPU(permutedX.shape, permutedX.dtype, xVals, axes),
outVals = _prodImplCPU.outVals,
outShape = _prodImplCPU.outShape,
outDtype = _prodImplCPU.outDtype;
res = backend.makeTensorInfo(outShape, outDtype, outVals);
} else {
var _backend_util$compute = computeOutAndReduceShapes(permutedX.shape, axes),
_outShape = _backend_util$compute[0],
reduceShape = _backend_util$compute[1];
var inSize = sizeFromShape(reduceShape);
var a2D = reshape$3({
inputs: {
x: permutedX
},
backend: backend,
attrs: {
shape: [-1, inSize]
}
});
var outputDType = sumOutType(x.dtype);
var reduced = reduce(a2D, outputDType, 'prod', backend);
res = reshape$3({
inputs: {
x: reduced
},
backend: backend,
attrs: {
shape: _outShape
}
});
toDispose.push(a2D);
toDispose.push(reduced);
}
if (keepDims) {
toDispose.push(res);
var newShape = expandShapeToKeepDim(res.shape, origAxes);
res = reshape$3({
inputs: {
x: res
},
backend: backend,
attrs: {
shape: newShape
}
});
}
toDispose.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return res;
}
var prodConfig$1 = {
kernelName: Prod,
backendName: 'webgl',
kernelFunc: prod$2
};
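/*
 * Usage sketch (illustrative only), assuming the exported `tf.prod` op backed
 * by the Prod kernel above. The reduced axes are permuted to the innermost
 * dimensions, reduced, then reshaped (and re-expanded when keepDims is true).
 *
 *   const x = tf.tensor2d([[1, 2], [3, 4]]);
 *   tf.prod(x).print();            // 24 (all elements)
 *   tf.prod(x, 0).print();         // [3, 8]
 *   tf.prod(x, 1, true).print();   // [[2], [12]]
 */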
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var range$3 = function range(args) {
var backend = args.backend,
attrs = args.attrs;
var start = attrs.start,
stop = attrs.stop,
step = attrs.step,
dtype = attrs.dtype;
var values = rangeImplCPU(start, stop, step, dtype);
return backend.makeTensorInfo([values.length], dtype, values);
};
var rangeConfig$1 = {
kernelName: Range,
backendName: 'webgl',
kernelFunc: range$3
};
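/*
 * Usage sketch (illustrative only): Range is computed on the CPU
 * (rangeImplCPU) and uploaded as a 1-D tensor, assuming the exported
 * `tf.range` op.
 *
 *   tf.range(0, 9, 2).print();            // [0, 2, 4, 6, 8]
 *   tf.range(3, 0, -1, 'int32').print();  // [3, 2, 1]
 */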
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var RECIPROCAL = "return 1.0 / x;";
var reciprocal$2 = unaryKernelFunc$1({
opSnippet: RECIPROCAL
});
var reciprocalConfig$1 = {
kernelName: Reciprocal,
backendName: 'webgl',
kernelFunc: reciprocal$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var RELU$2 = CHECK_NAN_SNIPPET + "\n return (x < 0.0) ? 0.0 : x;\n";
var RELU_PACKED = "\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
var relu$2 = unaryKernelFunc$1({
opSnippet: RELU$2,
packedOpSnippet: RELU_PACKED
});
var reluConfig$1 = {
kernelName: Relu,
backendName: 'webgl',
kernelFunc: relu$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var RELU6$2 = CHECK_NAN_SNIPPET + "\n return (x < 0.0) ? 0.0 : min(6.0, x);\n";
var RELU6_PACKED = "\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
var relu6$2 = unaryKernelFunc$1({
opSnippet: RELU6$2,
packedOpSnippet: RELU6_PACKED
});
var relu6Config$1 = {
kernelName: Relu6,
backendName: 'webgl',
kernelFunc: relu6$2
};
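/*
 * Usage sketch (illustrative only), assuming the exported `tf.relu` and
 * `tf.relu6` ops backed by the two kernels above. NaN inputs propagate
 * unchanged in both variants.
 *
 *   const x = tf.tensor1d([-2, 0.5, 7, NaN]);
 *   tf.relu(x).print();    // [0, 0.5, 7, NaN]
 *   tf.relu6(x).print();   // [0, 0.5, 6, NaN]
 */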
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeBilinearProgram = function ResizeBilinearProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
this.variableNames = ['A'];
this.outputShape = [];
var batch = inputShape[0],
oldHeight = inputShape[1],
oldWidth = inputShape[2],
depth = inputShape[3];
this.outputShape = [batch, newHeight, newWidth, depth];
var effectiveInSize = [alignCorners && newHeight > 1 ? oldHeight - 1 : oldHeight, alignCorners && newWidth > 1 ? oldWidth - 1 : oldWidth];
var effectiveOutSize = [alignCorners && newHeight > 1 ? newHeight - 1 : newHeight, alignCorners && newWidth > 1 ? newWidth - 1 : newWidth];
var sourceFracIndexRC;
if (halfPixelCenters) {
sourceFracIndexRC = "(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC" + " - vec2(0.5)";
} else {
sourceFracIndexRC = "vec2(yRC) * effectiveInputOverOutputRatioRC";
}
this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = " + sourceFracIndexRC + ";\n\n // Compute the four integer indices.\n ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0)));\n ivec2 sourceCeilRC = ivec2(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);\n float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);\n float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);\n float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);\n\n vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);\n\n float top = topLeft + (topRight - topLeft) * fracRC.y;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;\n float newValue = top + (bottom - top) * fracRC.x;\n\n setOutput(newValue);\n }\n ";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeBilinearPackedProgram = function ResizeBilinearPackedProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = [];
var batch = inputShape[0],
oldHeight = inputShape[1],
oldWidth = inputShape[2],
depth = inputShape[3];
this.outputShape = [batch, newHeight, newWidth, depth];
var effectiveInSize = [alignCorners && newHeight > 1 ? oldHeight - 1 : oldHeight, alignCorners && newWidth > 1 ? oldWidth - 1 : oldWidth];
var effectiveOutSize = [alignCorners && newHeight > 1 ? newHeight - 1 : newHeight, alignCorners && newWidth > 1 ? newWidth - 1 : newWidth];
var sourceFracIndexRC;
if (halfPixelCenters) {
sourceFracIndexRC = "(vec3(yRC) + vec3(0.5)) * " + "effectiveInputOverOutputRatioRC - vec3(0.5)";
} else {
sourceFracIndexRC = "vec3(yRC) * effectiveInputOverOutputRatioRC";
}
this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec3 inputShapeRC = vec3(" + oldHeight + ".0, " + oldWidth + ".0,\n " + oldWidth + ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = " + sourceFracIndexRC + ";\n\n // Compute the four integer indices.\n ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0)));\n ivec3 sourceCeilRC = ivec3(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < " + (depth - 1) + ";\n bool hasNextRow = coords.z < " + (newWidth - 1) + ";\n\n // In parallel, construct four corners for all four components in\n // packed 2x2 cell.\n vec4 topLeft = vec4(\n getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 bottomLeft = vec4(\n getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 topRight = vec4(\n getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec4 bottomRight = vec4(\n getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);\n\n vec4 top = mix(topLeft, topRight, fracRC.yyzz);\n vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);\n vec4 newValue = mix(top, bottom, fracRC.x);\n\n setOutput(newValue);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function resizeBilinear$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var images = inputs.images;
var alignCorners = attrs.alignCorners,
halfPixelCenters = attrs.halfPixelCenters,
size = attrs.size;
var newHeight = size[0],
newWidth = size[1];
var program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? new ResizeBilinearPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) : new ResizeBilinearProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
return backend.runWebGLProgram(program, [images], 'float32');
}
var resizeBilinearConfig$1 = {
kernelName: ResizeBilinear,
backendName: 'webgl',
kernelFunc: resizeBilinear$2
};
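/*
 * Usage sketch (illustrative only), assuming the exported
 * `tf.image.resizeBilinear` op backed by the kernel above. The input is
 * NHWC (or HWC) and the output is always float32.
 *
 *   const img = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);     // 2x2, 1 channel
 *   tf.image.resizeBilinear(img, [4, 4]).print();         // upsampled 4x4
 *   tf.image.resizeBilinear(img, [4, 4], true).print();   // alignCorners
 */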
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeBilinearBackpropProgram = function ResizeBilinearBackpropProgram(dyShape, inputShape, alignCorners) {
this.variableNames = ['dy'];
this.outputShape = [];
this.outputShape = inputShape;
var xHeight = inputShape[1],
xWidth = inputShape[2];
var yHeight = dyShape[1],
      yWidth = dyShape[2]; // In the backwards pass, we want to find the pixels that were generated
  // for each pixel of the input image in the forward pass and add the
  // corresponding coefficient from dy to the gradient (with some interpolation).
var effectiveXSize = [alignCorners && yHeight > 1 ? xHeight - 1 : xHeight, alignCorners && yWidth > 1 ? xWidth - 1 : xWidth];
var effectiveYSize = [alignCorners && yHeight > 1 ? yHeight - 1 : yHeight, alignCorners && yWidth > 1 ? yWidth - 1 : yWidth];
var heightScale = effectiveXSize[0] / effectiveYSize[0];
var widthScale = effectiveXSize[1] / effectiveYSize[1];
var invHeightScale = 1 / heightScale;
var invWidthScale = 1 / widthScale; // This defines the size of the window of values around a particular
// index in dy that we want to search for contributions to dx.
var winHeight = Math.ceil(invHeightScale) * 2 + 2;
var winWidth = Math.ceil(invWidthScale) * 2 + 2;
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(startRLerp - float(winHeight / 2));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(startCLerp - float(winWidth / 2));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float dxR = float(dyR) * heightScale;\n int topDxRIndex = int(floor(dxR));\n int bottomDxRIndex = int(min(ceil(dxR), " + (xHeight - 1) + ".0));\n float dxRLerp = dxR - float(topDxRIndex);\n float inverseDxRLerp = 1.0 - dxRLerp;\n\n float dxC = float(dyC) * widthScale;\n int leftDxCIndex = int(floor(dxC));\n int rightDxCIndex = int(min(ceil(dxC), " + (xWidth - 1) + ".0));\n float dxCLerp = dxC - float(leftDxCIndex);\n float inverseDxCLerp = 1.0 - dxCLerp;\n\n if (r == topDxRIndex && c == leftDxCIndex) {\n // topLeft\n accumulator +=\n getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;\n }\n\n if (r == topDxRIndex && c == rightDxCIndex) {\n // topRight\n accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;\n }\n\n if (r == bottomDxRIndex && c == leftDxCIndex) {\n // bottomLeft\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;\n }\n\n if (r == bottomDxRIndex && c == rightDxCIndex) {\n // bottomRight\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function resizeBilinearGrad$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var images = inputs.images,
dy = inputs.dy;
var alignCorners = attrs.alignCorners;
var program = new ResizeBilinearBackpropProgram(dy.shape, images.shape, alignCorners);
return backend.runWebGLProgram(program, [dy], dy.dtype);
}
var resizeBilinearGradConfig$2 = {
kernelName: ResizeBilinearGrad,
backendName: 'webgl',
kernelFunc: resizeBilinearGrad$1
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeNearestNeighborProgram = function ResizeNearestNeighborProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
this.variableNames = ['A'];
this.outputShape = [];
var batch = inputShape[0],
oldHeight = inputShape[1],
oldWidth = inputShape[2],
depth = inputShape[3];
this.outputShape = [batch, newHeight, newWidth, depth];
var effectiveInSize = [alignCorners && newHeight > 1 ? oldHeight - 1 : oldHeight, alignCorners && newWidth > 1 ? oldWidth - 1 : oldWidth];
  var effectiveOutSize = [alignCorners && newHeight > 1 ? newHeight - 1 : newHeight, alignCorners && newWidth > 1 ? newWidth - 1 : newWidth]; // When alignCorners is false, we round the value down with floor.
var roundBase = alignCorners ? '0.5' : '0.0';
var sourceFracIndexRC;
if (halfPixelCenters) {
sourceFracIndexRC = "max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC" + ", vec2(0.0))";
} else {
sourceFracIndexRC = "vec2(yRC) * effectiveInputOverOutputRatioRC";
}
this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = " + sourceFracIndexRC + ";\n\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestRC = ivec2(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + " + roundBase + ")));\n float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);\n\n setOutput(newValue);\n }\n ";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeNearestNeighborPackedProgram = function ResizeNearestNeighborPackedProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = [];
var batch = inputShape[0],
oldHeight = inputShape[1],
oldWidth = inputShape[2],
depth = inputShape[3];
this.outputShape = [batch, newHeight, newWidth, depth];
var effectiveInSize = [alignCorners && newHeight > 1 ? oldHeight - 1 : oldHeight, alignCorners && newWidth > 1 ? oldWidth - 1 : oldWidth];
  var effectiveOutSize = [alignCorners && newHeight > 1 ? newHeight - 1 : newHeight, alignCorners && newWidth > 1 ? newWidth - 1 : newWidth]; // When alignCorners is false, we round the value down with floor.
var roundBase = alignCorners ? '0.5' : '0.0';
var sourceFracIndexRC;
if (halfPixelCenters) {
sourceFracIndexRC = "max((vec3(yRC) + vec3(0.5)) * " + "effectiveInputOverOutputRatioRC, vec3(0.0))";
} else {
sourceFracIndexRC = "vec3(yRC) * effectiveInputOverOutputRatioRC";
}
this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec3 inputShapeRC = vec3(" + oldHeight + ".0, " + oldWidth + ".0,\n " + oldWidth + ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = " + sourceFracIndexRC + ";\n\n // Compute the coordinators of nearest neighbor point.\n ivec3 sourceNearestRC = ivec3(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + " + roundBase + ")));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < " + (depth - 1) + ";\n bool hasNextRow = coords.z < " + (newWidth - 1) + ";\n\n vec4 newValue = vec4(\n getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d),\n hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0);\n\n setOutput(newValue);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function resizeNearestNeighbor$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var images = inputs.images;
var alignCorners = attrs.alignCorners,
halfPixelCenters = attrs.halfPixelCenters,
size = attrs.size;
var newHeight = size[0],
newWidth = size[1];
var program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? new ResizeNearestNeighborPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) : new ResizeNearestNeighborProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
return backend.runWebGLProgram(program, [images], images.dtype);
}
var resizeNearestNeighborConfig$1 = {
kernelName: ResizeNearestNeighbor,
backendName: 'webgl',
kernelFunc: resizeNearestNeighbor$2
};
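/*
 * Usage sketch (illustrative only), assuming the exported
 * `tf.image.resizeNearestNeighbor` op backed by the kernel above. Unlike
 * resizeBilinear, the output keeps the input dtype.
 *
 *   const img = tf.tensor3d([1, 2, 3, 4], [2, 2, 1], 'int32');
 *   tf.image.resizeNearestNeighbor(img, [4, 4]).print();   // int32, 4x4
 */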
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ResizeNearestNeigborBackpropProgram = function ResizeNearestNeigborBackpropProgram(dyShape, inputShape, alignCorners) {
this.variableNames = ['dy'];
this.outputShape = [];
this.outputShape = inputShape;
var xHeight = inputShape[1],
xWidth = inputShape[2];
var yHeight = dyShape[1],
      yWidth = dyShape[2]; // In the backwards pass, we want to find the pixels that were generated
  // for each pixel of the input image in the forward pass and add the
  // corresponding coefficient from dy to the gradient (with some interpolation).
var effectiveXSize = [alignCorners && yHeight > 1 ? xHeight - 1 : xHeight, alignCorners && yWidth > 1 ? xWidth - 1 : xWidth];
var effectiveYSize = [alignCorners && yHeight > 1 ? yHeight - 1 : yHeight, alignCorners && yWidth > 1 ? yWidth - 1 : yWidth];
var heightScale = effectiveXSize[0] / effectiveYSize[0];
var widthScale = effectiveXSize[1] / effectiveYSize[1];
var invHeightScale = 1 / heightScale;
var invWidthScale = 1 / widthScale; // This defines the size of the window of values around a particular
// index in dy that we want to search for contributions to dx.
var winHeight = Math.ceil(invHeightScale) * 2 + 2;
var winWidth = Math.ceil(invWidthScale) * 2 + 2;
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(floor(startRLerp - float(winHeight / 2)));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(floor(startCLerp - float(winWidth / 2)));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float sourceFracRow =\n float(" + effectiveXSize[0] + ") *\n (float(dyR) / float(" + effectiveYSize[0] + "));\n\n float sourceFracCol =\n float(" + effectiveXSize[1] + ") *\n (float(dyC) / float(" + effectiveYSize[1] + "));\n\n int sourceNearestRow = int(min(\n float(int(" + xHeight + ") - 1),\n " + alignCorners + " ? float(round(sourceFracRow)) :\n float(floor(sourceFracRow))));\n\n int sourceNearestCol = int(min(\n float(int(" + xWidth + ") - 1),\n " + alignCorners + " ? float(round(sourceFracCol)) :\n float(floor(sourceFracCol))));\n\n if (r == sourceNearestRow && c == sourceNearestCol) {\n accumulator += getDy(b, dyR, dyC, d);\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function resizeNearestNeighborGrad$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var images = inputs.images,
dy = inputs.dy;
var alignCorners = attrs.alignCorners;
var program = new ResizeNearestNeigborBackpropProgram(dy.shape, images.shape, alignCorners);
return backend.runWebGLProgram(program, [dy], dy.dtype);
}
var resizeNearestNeighborGradConfig$2 = {
kernelName: ResizeNearestNeighborGrad,
backendName: 'webgl',
kernelFunc: resizeNearestNeighborGrad$1
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ReverseProgram = function ReverseProgram(xShape, axis) {
this.variableNames = ['x'];
var rank = xShape.length;
if (rank > 4) {
throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported");
}
this.outputShape = xShape;
if (rank === 1) {
this.userCode = "\n void main() {\n int coord = getOutputCoords();\n setOutput(getX(" + xShape[0] + " - coord - 1));\n }\n ";
return;
}
var getInCoord = function getInCoord(i) {
if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
return xShape[i] + " - coords[" + i + "] - 1";
}
return "coords[" + i + "]";
};
var inCoords = xShape.map(function (_, i) {
return getInCoord(i);
}).join(',');
var type = getCoordsDataType(rank);
this.userCode = "\n void main() {\n " + type + " coords = getOutputCoords();\n setOutput(getX(" + inCoords + "));\n }\n ";
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ReversePackedProgram = function ReversePackedProgram(xShape, axis) {
this.variableNames = ['x'];
this.packedInputs = true;
this.packedOutput = true;
var rank = xShape.length;
if (rank > 4) {
throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported");
}
this.outputShape = xShape;
var channels = getChannels('rc', rank);
var nextColumn = channels[rank - 1] + " + 1 < " + this.outputShape[rank - 1];
var nextRow = channels[rank - 2] + " + 1 < " + this.outputShape[rank - 2];
var type = getCoordsDataType(rank);
if (rank === 1) {
this.userCode = "\n void main(){\n int rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = getChannel(getX(" + xShape[0] + " - rc - 1),\n " + xShape[0] + " - rc - 1);\n if(" + nextColumn + "){\n result.g = getChannel(getX(" + xShape[0] + " - (rc + 1) - 1),\n " + xShape[0] + " - (rc + 1) - 1);\n }\n setOutput(result);\n }\n ";
} else {
this.userCode = "\n void main() {\n " + type + " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = " + getR(channels.slice()) + ";\n if(" + nextColumn + "){\n result.g = " + getG(channels.slice()) + ";\n }\n if(" + nextRow + ") {\n result.b = " + getB(channels.slice()) + ";\n if(" + nextColumn + ") {\n result.a = " + getA(channels.slice()) + ";\n }\n }\n setOutput(result);\n }\n ";
}
function getR(channels) {
return getChannel(channels);
}
function getG(channels) {
channels[rank - 1] = '(' + channels[rank - 1] + " + 1)";
return getChannel(channels);
}
function getB(channels) {
channels[rank - 2] = '(' + channels[rank - 2] + " + 1)";
return getChannel(channels);
}
function getA(channels) {
channels[rank - 1] = '(' + channels[rank - 1] + " + 1)";
channels[rank - 2] = '(' + channels[rank - 2] + " + 1)";
return getChannel(channels);
}
function getChannel(channels) {
var inCoordsArray = xShape.map(function (_, i) {
return getInCoord(i, channels);
});
var inCoords = inCoordsArray.join(',');
var innerDims = inCoordsArray.slice(-2).join(',');
return "getChannel(getX(" + inCoords + "), vec2(" + innerDims + "))";
}
function getInCoord(i, channels1) {
if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
return xShape[i] + " - " + channels1[i] + " - 1";
} else {
return "" + channels1[i];
}
}
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function reverse$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var dims = attrs.dims;
var xRank = x.shape.length;
var $dims = parseAxisParam(dims, x.shape);
if (xRank === 0) {
return identity$2({
inputs: {
x: x
},
backend: backend
});
}
var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new ReversePackedProgram(x.shape, $dims) : new ReverseProgram(x.shape, $dims);
return backend.runWebGLProgram(program, [x], x.dtype);
}
var reverseConfig$1 = {
kernelName: Reverse,
backendName: 'webgl',
kernelFunc: reverse$2
};
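/*
 * Usage sketch (illustrative only), assuming the exported `tf.reverse` op
 * backed by the kernel above (rank 0 falls through to identity; rank > 4
 * throws).
 *
 *   const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
 *   tf.reverse(x, 0).print();      // [[4, 5, 6], [1, 2, 3]]
 *   tf.reverse(x, [0, 1]).print(); // [[6, 5, 4], [3, 2, 1]]
 */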
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var RotateProgram = function RotateProgram(imageShape, fillValue) {
this.variableNames = ['Image'];
this.outputShape = [];
this.customUniforms = [{
name: 'params',
type: 'vec4'
}];
var imageHeight = imageShape[1];
var imageWidth = imageShape[2];
this.outputShape = imageShape;
var fillSnippet = '';
if (typeof fillValue === 'number') {
fillSnippet = "float outputValue = " + fillValue.toFixed(2) + ";";
} else {
fillSnippet = "\n vec3 fill = vec3(" + fillValue.join(',') + ");\n float outputValue = fill[coords[3]];";
}
this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n int y = coords[1];\n float coordXFloat = (float(x) - params[0]) * params[3] -\n (float(y) - params[1]) * params[2];\n float coordYFloat = (float(x) - params[0]) * params[2] +\n (float(y) - params[1]) * params[3];\n int coordX = int(round(coordXFloat + params[0]));\n int coordY = int(round(coordYFloat + params[1]));\n " + fillSnippet + "\n if(coordX >= 0 && coordX < " + imageWidth + " && coordY >= 0 && coordY < " + imageHeight + ") {\n outputValue = getImage(coords[0], coordY, coordX, coords[3]);\n }\n setOutput(outputValue);\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var rotateWithOffsetConfig$1 = {
kernelName: RotateWithOffset,
backendName: 'webgl',
kernelFunc: function kernelFunc(_ref) {
var inputs = _ref.inputs,
attrs = _ref.attrs,
backend = _ref.backend;
var image = inputs.image;
var radians = attrs.radians,
fillValue = attrs.fillValue,
center = attrs.center;
var webglBackend = backend;
var program = new RotateProgram(image.shape, fillValue);
var _backend_util$getImag = getImageCenter(center, image.shape[1], image.shape[2]),
centerX = _backend_util$getImag[0],
centerY = _backend_util$getImag[1];
var customValues = [[centerX, centerY, Math.sin(radians), Math.cos(radians)]];
var output = webglBackend.runWebGLProgram(program, [image], image.dtype, customValues);
return output;
}
};
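/*
 * Usage sketch (illustrative only), assuming the exported
 * `tf.image.rotateWithOffset` op backed by the config above. The image is a
 * 4-D float32 tensor [batch, height, width, channels]; fillValue may be a
 * single number or an [r, g, b] triple, and center is relative to the image
 * size.
 *
 *   const img = tf.zeros([1, 16, 16, 3]);
 *   const rotated = tf.image.rotateWithOffset(img, Math.PI / 4, 0, [0.5, 0.5]);
 */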
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ROUND = "\n // OpenGL ES does not support round function.\n // The algorithm is based on banker's rounding.\n float base = floor(x);\n if ((x - base) < 0.5) {\n return floor(x);\n } else if ((x - base) > 0.5) {\n return ceil(x);\n } else {\n if (mod(base, 2.0) == 0.0) {\n return base;\n } else {\n return base + 1.0;\n }\n }\n";
var round$3 = unaryKernelFunc$1({
opSnippet: ROUND
});
var roundConfig$1 = {
kernelName: Round,
backendName: 'webgl',
kernelFunc: round$3
};
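/*
 * Usage sketch (illustrative only), assuming the exported `tf.round` op. The
 * GLSL above implements banker's rounding (ties go to the nearest even
 * integer), matching TensorFlow rather than JavaScript's Math.round.
 *
 *   tf.round(tf.tensor1d([0.5, 1.5, 2.5, -0.5, 2.4])).print();
 *   // [0, 2, 2, 0, 2]
 */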
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var RSQRT = "return inversesqrt(x);";
var rsqrt$2 = unaryKernelFunc$1({
opSnippet: RSQRT,
cpuKernelImpl: rsqrtImplCPU
});
var rsqrtConfig$1 = {
kernelName: Rsqrt,
backendName: 'webgl',
kernelFunc: rsqrt$2
};
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var ScatterProgram = function ScatterProgram(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex) {
if (summingDupeIndex === void 0) {
summingDupeIndex = true;
}
this.variableNames = ['updates', 'indices', 'defaultValue'];
this.outputShape = shape;
var stridesType = getCoordsDataType(strides.length);
var dtype = getCoordsDataType(shape.length);
var indicesString = '';
if (indicesRank === 1) {
indicesString = 'i';
} else if (indicesRank === 2) {
indicesString = 'i, j';
}
var indicesSnippet = "getIndices(" + indicesString + ")";
var updatesString = '';
if (updatesRank === 1) {
updatesString = 'i';
} else if (updatesRank === 2) {
updatesString = 'i, coords[1]';
}
var updatesSnippet = "getUpdates(" + updatesString + ")";
var strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n float sum = 0.0;\n bool found = false;\n for (int i = 0; i < " + updateSize + "; i++) {\n int flattenedIndex = 0;\n for (int j = 0; j < " + sliceDim + "; j++) {\n int index = round(" + indicesSnippet + ");\n flattenedIndex += index * " + strideString + ";\n }\n if (flattenedIndex == coords[0]) {\n sum += " + updatesSnippet + ";\n found = true;\n }\n }\n setOutput(mix(getDefaultValue(), sum, float(found)));\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function scatterNd$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var indices = inputs.indices,
updates = inputs.updates;
var shape = attrs.shape;
var _backend_util$calcula = calculateShapes(updates, indices, shape),
sliceRank = _backend_util$calcula.sliceRank,
numUpdates = _backend_util$calcula.numUpdates,
sliceSize = _backend_util$calcula.sliceSize,
strides = _backend_util$calcula.strides,
outputSize = _backend_util$calcula.outputSize;
var flattenShape = [outputSize / sliceSize, sliceSize];
if (outputSize === 0) {
return backend.makeTensorInfo(shape, indices.dtype);
}
var flattenIndices = reshape$3({
inputs: {
x: indices
},
backend: backend,
attrs: {
shape: [numUpdates, sliceRank]
}
});
var flattenX = reshape$3({
inputs: {
x: updates
},
backend: backend,
attrs: {
shape: [numUpdates, sliceSize]
}
});
var defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0])); // scalar(0)
var program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
var res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype);
var reshaped = reshape$3({
inputs: {
x: res
},
backend: backend,
attrs: {
shape: shape
}
});
backend.disposeIntermediateTensorInfo(flattenIndices);
backend.disposeIntermediateTensorInfo(flattenX);
backend.disposeIntermediateTensorInfo(res);
backend.disposeIntermediateTensorInfo(defaultValue);
return reshaped;
}
var scatterNdConfig$1 = {
kernelName: ScatterNd,
backendName: 'webgl',
kernelFunc: scatterNd$1
};
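/*
 * Usage sketch (illustrative only), assuming the exported `tf.scatterND` op
 * backed by the kernel above: indices and updates are flattened to 2-D,
 * scattered into a zero-filled output of the requested shape, then reshaped.
 *
 *   const indices = tf.tensor2d([[4], [3], [1], [7]], [4, 1], 'int32');
 *   const updates = tf.tensor1d([9, 10, 11, 12]);
 *   tf.scatterND(indices, updates, [8]).print();
 *   // [0, 11, 0, 10, 9, 0, 0, 12]
 */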
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SelectProgram = function SelectProgram(cRank, shape, rank) {
this.variableNames = ['c', 'a', 'b'];
this.outputShape = shape;
var cCoords;
var abCoords;
if (rank > 4) {
throw Error("Where for rank " + rank + " is not yet supported");
}
if (rank === 1) {
abCoords = "resRC";
cCoords = "resRC";
} else {
var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
var cCoordVars = [];
var abCoordVars = [];
for (var i = 0; i < shape.length; i++) {
abCoordVars.push("" + currentCoords[i]);
if (i < cRank) {
cCoordVars.push("" + currentCoords[i]);
}
}
cCoords = cCoordVars.join();
abCoords = abCoordVars.join();
}
var dtype = getCoordsDataType(rank);
this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n float cVal = getC(" + cCoords + ");\n if (cVal >= 1.0) {\n setOutput(getA(" + abCoords + "));\n } else {\n setOutput(getB(" + abCoords + "));\n }\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function select$2(args) {
var inputs = args.inputs,
backend = args.backend;
var condition = inputs.condition,
t = inputs.t,
e = inputs.e;
var program = new SelectProgram(condition.shape.length, t.shape, t.shape.length);
return backend.runWebGLProgram(program, [condition, t, e], upcastType(t.dtype, e.dtype));
}
var selectConfig$1 = {
kernelName: Select,
backendName: 'webgl',
kernelFunc: select$2
};
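/*
 * Usage sketch (illustrative only), assuming the exported `tf.where` op,
 * which dispatches to the Select kernel above (the condition may match the
 * full shape of a and b, or broadcast along the outer dimension).
 *
 *   const cond = tf.tensor1d([false, true, true], 'bool');
 *   const a = tf.tensor1d([1, 2, 3]);
 *   const b = tf.tensor1d([-1, -2, -3]);
 *   tf.where(cond, a, b).print();   // [-1, 2, 3]
 */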
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SELU = "\n // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.\n // see: https://arxiv.org/abs/1706.02515\n float scaleAlpha = " + SELU_SCALEALPHA + ";\n float scale = " + SELU_SCALE + ";\n return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);\n";
var selu$2 = unaryKernelFunc$1({
opSnippet: SELU
});
var seluConfig$1 = {
kernelName: Selu,
backendName: 'webgl',
kernelFunc: selu$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SIGMOID$2 = "return 1.0 / (1.0 + exp(-1.0 * x));";
var sigmoid$2 = unaryKernelFunc$1({
opSnippet: SIGMOID$2,
packedOpSnippet: SIGMOID$2,
cpuKernelImpl: sigmoidImplCPU
});
var sigmoidConfig$1 = {
kernelName: Sigmoid,
backendName: 'webgl',
kernelFunc: sigmoid$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SIGN = "\n if (isnan(x)) { return 0.0; }\n return sign(x);\n";
var sign$3 = unaryKernelFunc$1({
opSnippet: SIGN
});
var signConfig$1 = {
kernelName: Sign,
backendName: 'webgl',
kernelFunc: sign$3
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SIN = CHECK_NAN_SNIPPET_UNARY + "\n return sin(x);\n";
var sin$2 = unaryKernelFunc$1({
opSnippet: SIN
});
var sinConfig$1 = {
kernelName: Sin,
backendName: 'webgl',
kernelFunc: sin$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SINH = "\n float e2x = exp(x);\n return (e2x - 1.0 / e2x) / 2.0;\n";
var sinh$2 = unaryKernelFunc$1({
opSnippet: SINH
});
var sinhConfig$1 = {
kernelName: Sinh,
backendName: 'webgl',
kernelFunc: sinh$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SOFTPLUS = "\n float epsilon = 1.1920928955078125e-7;\n float threshold = log(epsilon) + 2.0;\n\n bool too_large = x > -threshold;\n bool too_small = x < threshold;\n\n float result;\n float exp_x = exp(x);\n\n if (too_large){\n result = x;\n }\n else if (too_small){\n result = exp_x;\n }\n else{\n result = log(exp_x + 1.0);\n }\n return result;\n";
var softplus$2 = unaryKernelFunc$1({
opSnippet: SOFTPLUS
});
var softplusConfig$1 = {
kernelName: Softplus,
backendName: 'webgl',
kernelFunc: softplus$2
};
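/*
 * Usage sketch (illustrative only) for the unary kernels registered above
 * (Sigmoid, Sign, Sin, Sinh, Softplus), assuming the corresponding exported
 * ops:
 *
 *   const x = tf.tensor1d([-1, 0, 2]);
 *   tf.sigmoid(x).print();    // [0.2689, 0.5, 0.8808]
 *   tf.sign(x).print();       // [-1, 0, 1]
 *   tf.softplus(x).print();   // [0.3133, 0.6931, 2.1269]
 */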
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var spaceToBatchND$2 = function spaceToBatchND(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var blockShape = attrs.blockShape,
paddings = attrs.paddings;
assert(x.shape.length <= 4, function () {
return 'spaceToBatchND for rank > 4 with a WebGL backend not ' + 'implemented yet';
});
var prod = blockShape.reduce(function (a, b) {
return a * b;
});
var completePaddings = [[0, 0]];
completePaddings.push.apply(completePaddings, paddings);
for (var i = 1 + blockShape.length; i < x.shape.length; ++i) {
completePaddings.push([0, 0]);
}
var toDispose = [];
var paddedX = padV2$1({
inputs: {
x: x
},
backend: backend,
attrs: {
paddings: completePaddings,
constantValue: 0
}
});
var reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
var permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
var flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
var reshapedPaddedX = reshape$3({
inputs: {
x: paddedX
},
backend: backend,
attrs: {
shape: reshapedPaddedShape
}
});
var paddedXT = transpose$2({
inputs: {
x: reshapedPaddedX
},
backend: backend,
attrs: {
perm: permutedReshapedPaddedPermutation
}
});
var result = reshape$3({
inputs: {
x: paddedXT
},
backend: backend,
attrs: {
shape: flattenShape
}
});
toDispose.push(paddedX);
toDispose.push(reshapedPaddedX);
toDispose.push(paddedXT);
toDispose.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return result;
};
var spaceToBatchNDConfig$1 = {
kernelName: SpaceToBatchND,
backendName: 'webgl',
kernelFunc: spaceToBatchND$2
};
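/**
 * Hedged usage sketch for the public op this kernel backs (assumed to be
 * tf.spaceToBatchND in this bundle). The kernel pads, reshapes, transposes
 * and reshapes again, which is why the intermediates are tracked in
 * toDispose. Kept inside a comment so nothing runs at load time:
 *
 *   const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]);
 *   const out = tf.spaceToBatchND(x, [2, 2], [[0, 0], [0, 0]]);
 *   // out has shape [4, 1, 1, 1]: each 2x2 spatial block moves into batch.
 */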
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseFillEmptyRows$2(args) {
var inputs = args.inputs,
backend = args.backend;
var indices = inputs.indices,
values = inputs.values,
denseShape = inputs.denseShape,
defaultValue = inputs.defaultValue;
if (denseShape.shape.length !== 1) {
throw new Error("Dense shape must be a vector, saw:\n " + denseShape.shape);
}
if (indices.shape.length !== 2) {
throw new Error("Indices must be a matrix, saw:\n " + indices.shape);
}
if (values.shape.length !== 1) {
throw new Error("Values must be a vector, saw:\n " + values.shape);
}
if (defaultValue.shape.length !== 0) {
throw new Error("Default value must be a scalar, saw:\n " + defaultValue.shape);
}
var $indices = backend.readSync(indices.dataId);
var $values = backend.readSync(values.dataId);
var $denseShape = backend.readSync(denseShape.dataId);
var $defaultValue = backend.readSync(defaultValue.dataId)[0];
var _sparseFillEmptyRowsI = sparseFillEmptyRowsImplCPU($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue),
outputIndices = _sparseFillEmptyRowsI[0],
outputIndicesShape = _sparseFillEmptyRowsI[1],
outputValues = _sparseFillEmptyRowsI[2],
emptyRowIndicator = _sparseFillEmptyRowsI[3],
reverseIndexMap = _sparseFillEmptyRowsI[4];
return [backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices), backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues), backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map(function (value) {
return Number(value);
}))), backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap))];
}
var sparseFillEmptyRowsConfig$1 = {
kernelName: SparseFillEmptyRows,
backendName: 'webgl',
kernelFunc: sparseFillEmptyRows$2
};
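/**
 * This kernel reads its inputs synchronously and delegates to the shared CPU
 * implementation (sparseFillEmptyRowsImplCPU). A hedged sketch of the public
 * op it backs (assumed to be tf.sparse.sparseFillEmptyRows in this bundle):
 *
 *   const indices = tf.tensor2d([[0, 0], [2, 0]], [2, 2], 'int32');
 *   const values = tf.tensor1d([5, 7], 'int32');
 *   const denseShape = tf.tensor1d([4, 1], 'int32');
 *   const out = tf.sparse.sparseFillEmptyRows(
 *       indices, values, denseShape, tf.scalar(-1, 'int32'));
 *   // Rows 1 and 3 have no entries, so they are filled with the default -1.
 */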
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseReshape$2(args) {
var inputs = args.inputs,
backend = args.backend;
var inputIndices = inputs.inputIndices,
inputShape = inputs.inputShape,
newShape = inputs.newShape;
if (inputIndices.shape.length !== 2) {
throw new Error("Input indices should be a matrix but received shape " + inputIndices.shape);
}
if (inputShape.shape.length !== 1) {
throw new Error("Input shape should be a vector but received shape " + inputShape.shape);
}
if (newShape.shape.length !== 1) {
throw new Error("Target shape should be a vector but received shape " + newShape.shape);
}
var $inputShape = Array.from(backend.readSync(inputShape.dataId));
var $inputIndices = backend.readSync(inputIndices.dataId);
var targetShape = Array.from(backend.readSync(newShape.dataId));
var _sparseReshapeImplCPU = sparseReshapeImplCPU($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape),
newIndices = _sparseReshapeImplCPU[0],
indicesShape = _sparseReshapeImplCPU[1],
outputShape = _sparseReshapeImplCPU[2];
return [backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices), backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape))];
}
var sparseReshapeConfig$1 = {
kernelName: SparseReshape,
backendName: 'webgl',
kernelFunc: sparseReshape$2
};
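/**
 * Like the kernel above, this one runs the CPU implementation after a
 * synchronous read. A hedged sketch of the public op (assumed to be
 * tf.sparse.sparseReshape in this bundle), illustrative only:
 *
 *   const indices = tf.tensor2d([[0, 0], [1, 2]], [2, 2], 'int32');
 *   const shape = tf.tensor1d([2, 3], 'int32');
 *   const out = tf.sparse.sparseReshape(indices, shape,
 *                                       tf.tensor1d([6], 'int32'));
 *   // Flat positions 0 and 5 are re-expressed against the new shape [6],
 *   // so the reshaped indices are [[0], [5]].
 */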
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseSegmentMean$2(args) {
var inputs = args.inputs,
backend = args.backend;
var data = inputs.data,
indices = inputs.indices,
segmentIds = inputs.segmentIds;
if (data.shape.length < 1) {
throw new Error("Data should be at least 1 dimensional but received scalar");
}
if (indices.shape.length !== 1) {
throw new Error("Indices should be a vector but received shape\n " + indices.shape);
}
if (segmentIds.shape.length !== 1) {
throw new Error("Segment ids should be a vector but received shape\n " + segmentIds.shape);
}
var $data = backend.readSync(data.dataId);
var $indices = backend.readSync(indices.dataId);
var $segmentIds = backend.readSync(segmentIds.dataId);
var _sparseSegmentReducti = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds, true),
outputData = _sparseSegmentReducti[0],
outputDataShape = _sparseSegmentReducti[1];
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
var sparseSegmentMeanConfig$1 = {
kernelName: SparseSegmentMean,
backendName: 'webgl',
kernelFunc: sparseSegmentMean$2
};
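/**
 * sparseSegmentMean and the sparseSegmentSum kernel below share
 * sparseSegmentReductionImplCPU; the trailing `true` argument here selects
 * the mean reduction. A hedged sketch of the public op (assumed to be
 * tf.sparse.sparseSegmentMean in this bundle), illustrative only:
 *
 *   const data = tf.tensor2d([[1, 2], [3, 4], [5, 6]]);
 *   const indices = tf.tensor1d([0, 1], 'int32');
 *   const segmentIds = tf.tensor1d([0, 0], 'int32');
 *   const out = tf.sparse.sparseSegmentMean(data, indices, segmentIds);
 *   // out: [[2, 3]] -- the mean of rows 0 and 1, both assigned to segment 0.
 */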
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseSegmentSum$2(args) {
var inputs = args.inputs,
backend = args.backend;
var data = inputs.data,
indices = inputs.indices,
segmentIds = inputs.segmentIds;
if (data.shape.length < 1) {
throw new Error("Data should be at least 1 dimensional but received scalar");
}
if (indices.shape.length !== 1) {
throw new Error("Indices should be a vector but received shape\n " + indices.shape);
}
if (segmentIds.shape.length !== 1) {
throw new Error("Segment ids should be a vector but received shape\n " + segmentIds.shape);
}
var $data = backend.readSync(data.dataId);
var $indices = backend.readSync(indices.dataId);
var $segmentIds = backend.readSync(segmentIds.dataId);
var _sparseSegmentReducti = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds),
outputData = _sparseSegmentReducti[0],
outputDataShape = _sparseSegmentReducti[1];
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
var sparseSegmentSumConfig$1 = {
kernelName: SparseSegmentSum,
backendName: 'webgl',
kernelFunc: sparseSegmentSum$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function sparseToDense$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var sparseIndices = inputs.sparseIndices,
sparseValues = inputs.sparseValues,
defaultValue = inputs.defaultValue;
var outputShape = attrs.outputShape;
var _backend_util$calcula = calculateShapes(sparseValues, sparseIndices, outputShape),
sliceRank = _backend_util$calcula.sliceRank,
numUpdates = _backend_util$calcula.numUpdates,
strides = _backend_util$calcula.strides,
outputSize = _backend_util$calcula.outputSize;
var sumDupeIndices = false;
var program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.shape.length, sparseValues.shape.length, strides, [outputSize, 1], sumDupeIndices);
var res = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype);
var reshaped = reshape$3({
inputs: {
x: res
},
backend: backend,
attrs: {
shape: outputShape
}
});
backend.disposeIntermediateTensorInfo(res);
return reshaped;
}
var sparseToDenseConfig$1 = {
kernelName: SparseToDense,
backendName: 'webgl',
kernelFunc: sparseToDense$2
};
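/**
 * Unlike the sparse ops above, sparseToDense stays on the GPU: it reuses the
 * ScatterProgram with sumDupeIndices = false and then reshapes the result.
 * A hedged usage sketch of the public op (assumed to be tf.sparseToDense in
 * this bundle), kept inside a comment:
 *
 *   const indices = tf.tensor1d([1, 3], 'int32');
 *   const values = tf.tensor1d([10, 20], 'int32');
 *   const dense = tf.sparseToDense(indices, values, [5], 0);
 *   // dense: [0, 10, 0, 20, 0]
 */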
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function splitV$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var numOrSizeSplits = attrs.numOrSizeSplits,
axis = attrs.axis;
var $axis = parseAxisParam(axis, x.shape)[0];
var splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
var xRank = x.shape.length;
var begin = new Array(xRank).fill(0);
var size = x.shape.slice();
return splitSizes.map(function (s) {
var sliceSize = [].concat(size);
sliceSize[$axis] = s;
var sliceT = slice$4({
inputs: {
x: x
},
backend: backend,
attrs: {
begin: begin,
size: sliceSize
}
});
begin[$axis] += s;
return sliceT;
});
}
var splitVConfig$1 = {
kernelName: SplitV,
backendName: 'webgl',
kernelFunc: splitV$1
};
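/**
 * splitV implements the split as a sequence of slices: `begin` advances along
 * the split axis by each requested size. A hedged sketch of the public op
 * (assumed to be tf.split in this bundle), illustrative only:
 *
 *   const x = tf.tensor2d([[1, 2, 3, 4], [5, 6, 7, 8]]);
 *   const [a, b] = tf.split(x, [1, 3], 1);
 *   // a has shape [2, 1] and b has shape [2, 3]; together they tile axis 1.
 */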
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SQRT = "return sqrt(x);";
var sqrt$5 = unaryKernelFunc$1({
opSnippet: SQRT,
packedOpSnippet: SQRT,
cpuKernelImpl: sqrtImplCPU
});
var sqrtConfig$1 = {
kernelName: Sqrt,
backendName: 'webgl',
kernelFunc: sqrt$5
};
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SQUARE = "return x * x;";
var square$2 = unaryKernelFunc$1({
opSnippet: SQUARE
});
var squareConfig$1 = {
kernelName: Square,
backendName: 'webgl',
kernelFunc: square$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SQUARED_DIFFERENCE$1 = 'return (a - b) * (a - b);';
var squaredDifference$2 = binaryKernelFunc$1({
opSnippet: SQUARED_DIFFERENCE$1,
packedOpSnippet: SQUARED_DIFFERENCE$1
});
var squaredDifferenceConfig$1 = {
kernelName: SquaredDifference,
backendName: 'webgl',
kernelFunc: squaredDifference$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function step$2(_ref) {
var inputs = _ref.inputs,
attrs = _ref.attrs,
backend = _ref.backend;
var x = inputs.x;
var opSnippet = CHECK_NAN_SNIPPET + ("\n return x > 0.0 ? 1.0 : float(" + attrs.alpha + ");\n ");
var program = new UnaryOpProgram(x.shape, opSnippet);
return backend.runWebGLProgram(program, [x], x.dtype);
}
var stepConfig$1 = {
kernelName: Step,
backendName: 'webgl',
kernelFunc: step$2
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var StridedSliceProgram = function StridedSliceProgram(begin, strides, size) {
this.variableNames = ['x'];
this.outputShape = size;
var rank = size.length;
var inputDtype = getCoordsDataType(size.length);
var dtype = getCoordsDataType(size.length);
var newCoords = '';
if (rank === 1) {
newCoords = 'coords * strides + begin';
} else {
var outputAxis = 0;
newCoords = size.map(function (_, i) {
outputAxis++;
return size.length === 1 ? "coords * strides[" + i + "] + begin[" + i + "]" : "coords[" + (outputAxis - 1) + "] * strides[" + i + "] + begin[" + i + "]";
}).join(',');
}
this.userCode = "\n " + inputDtype + " begin = " + inputDtype + "(" + begin + ");\n " + inputDtype + " strides = " + inputDtype + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n setOutput(getX(" + newCoords + "));\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function stridedSlice$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var begin = attrs.begin,
end = attrs.end,
strides = attrs.strides,
beginMask = attrs.beginMask,
endMask = attrs.endMask,
ellipsisMask = attrs.ellipsisMask,
newAxisMask = attrs.newAxisMask,
shrinkAxisMask = attrs.shrinkAxisMask;
var _slice_util$sliceInfo = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask),
nonStrided = _slice_util$sliceInfo.nonStrided,
$begin = _slice_util$sliceInfo.$begin,
$strides = _slice_util$sliceInfo.$strides,
size = _slice_util$sliceInfo.size,
newShape = _slice_util$sliceInfo.newShape,
outShape = _slice_util$sliceInfo.outShape;
var $x = reshape$3({
inputs: {
x: x
},
backend: backend,
attrs: {
shape: newShape
}
});
var result;
if (nonStrided) {
var sliced = slice$4({
inputs: {
x: $x
},
backend: backend,
attrs: {
begin: $begin,
size: size
}
});
result = reshape$3({
inputs: {
x: sliced
},
backend: backend,
attrs: {
shape: outShape
}
});
backend.disposeIntermediateTensorInfo(sliced);
} else if (outShape.some(function (axis) {
return axis === 0;
})) {
result = backend.makeTensorInfo(outShape, x.dtype, []);
} else {
var shouldExecuteOnCPU = backend.shouldExecuteOnCPU([$x]);
if (shouldExecuteOnCPU) {
var xTexData = backend.texData.get($x.dataId);
var values = xTexData.values;
var xBuf = buffer($x.shape, $x.dtype, values);
var resultValues = stridedSliceImplCPU(outShape, xBuf, $strides, $begin);
result = backend.makeTensorInfo(outShape, $x.dtype, resultValues.values);
} else {
var program = new StridedSliceProgram($begin, $strides, outShape);
result = backend.runWebGLProgram(program, [$x], $x.dtype);
}
}
var resultReshaped = reshape$3({
inputs: {
x: result
},
backend: backend,
attrs: {
shape: outShape
}
});
backend.disposeIntermediateTensorInfo($x);
backend.disposeIntermediateTensorInfo(result);
return resultReshaped;
}
var stridedSliceConfig$1 = {
kernelName: StridedSlice,
backendName: 'webgl',
kernelFunc: stridedSlice$2
};
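/**
 * The kernel above picks one of three paths: a plain slice when no stride is
 * needed, an empty output when any output dimension is 0, and otherwise
 * either the CPU implementation (when the data already lives on the CPU) or
 * StridedSliceProgram. A hedged sketch of the public op (assumed to be
 * tf.stridedSlice in this bundle), illustrative only:
 *
 *   const x = tf.tensor1d([0, 1, 2, 3, 4, 5]);
 *   const y = tf.stridedSlice(x, [1], [5], [2]);
 *   // y: [1, 3] -- every second element of the half-open range [1, 5).
 */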
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function stringNGrams$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var separator = attrs.separator,
nGramWidths = attrs.nGramWidths,
leftPad = attrs.leftPad,
rightPad = attrs.rightPad,
padWidth = attrs.padWidth,
preserveShortSequences = attrs.preserveShortSequences;
var data = inputs.data,
dataSplits = inputs.dataSplits;
var $data = backend.readSync(data.dataId);
var $dataSplits = backend.readSync(dataSplits.dataId);
var _stringNGramsImplCPU = stringNGramsImplCPU($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences),
nGrams = _stringNGramsImplCPU[0],
nGramsSplits = _stringNGramsImplCPU[1];
return [backend.makeTensorInfo([nGrams.length], 'string', nGrams), backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits)];
}
var stringNGramsConfig$1 = {
kernelName: StringNGrams,
backendName: 'webgl',
kernelFunc: stringNGrams$2
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function stringSplit$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var skipEmpty = attrs.skipEmpty;
var input = inputs.input,
delimiter = inputs.delimiter;
if (input.dtype !== 'string') {
throw new Error('Input must be of datatype string');
}
if (input.shape.length !== 1) {
throw new Error("Input must be a vector, got shape: " + input.shape);
}
if (delimiter.shape.length !== 0) {
throw new Error("Delimiter must be a scalar, got shape: " + delimiter.shape);
}
var $input = backend.readSync(input.dataId);
var $delimiter = backend.readSync(delimiter.dataId)[0];
var _stringSplitImplCPU = stringSplitImplCPU($input, $delimiter, skipEmpty),
indices = _stringSplitImplCPU[0],
values = _stringSplitImplCPU[1],
shape = _stringSplitImplCPU[2];
var outputSize = values.length;
return [backend.makeTensorInfo([outputSize, 2], 'int32', indices), backend.makeTensorInfo([outputSize], 'string', values), backend.makeTensorInfo([2], 'int32', new Int32Array(shape))];
}
var stringSplitConfig$1 = {
kernelName: StringSplit,
backendName: 'webgl',
kernelFunc: stringSplit$2
};
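/**
 * stringSplit also delegates to the CPU implementation and returns a sparse
 * representation: 2-D indices, the flat string values, and the dense shape.
 * A hedged sketch of the public op (assumed to be tf.string.stringSplit in
 * this bundle), illustrative only:
 *
 *   const input = tf.tensor1d(['hello world', 'a b c']);
 *   const out = tf.string.stringSplit(input, tf.scalar(' '));
 *   // split values: ['hello', 'world', 'a', 'b', 'c'], dense shape: [2, 3].
 */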
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function stringToHashBucketFast$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var numBuckets = attrs.numBuckets;
var input = inputs.input;
if (input.dtype !== 'string') {
throw new Error('Input must be of datatype string');
}
if (numBuckets <= 0) {
throw new Error("Number of buckets must be at least 1");
}
var $input = backend.readSync(input.dataId);
var output = stringToHashBucketFastImplCPU($input, numBuckets);
return backend.makeTensorInfo(input.shape, 'int32', output);
}
var stringToHashBucketFastConfig$1 = {
kernelName: StringToHashBucketFast,
backendName: 'webgl',
kernelFunc: stringToHashBucketFast$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TAN = "return tan(x);";
var tan$2 = unaryKernelFunc$1({
opSnippet: TAN
});
var tanConfig$1 = {
kernelName: Tan,
backendName: 'webgl',
kernelFunc: tan$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TANH = "\n float e2x = exp(-2.0 * abs(x));\n return sign(x) * (1.0 - e2x) / (1.0 + e2x);\n";
var tanh$3 = unaryKernelFunc$1({
opSnippet: TANH
});
var tanhConfig$1 = {
kernelName: Tanh,
backendName: 'webgl',
kernelFunc: tanh$3
};
/**
* @license
* Copyright 2017 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TileProgram = function TileProgram(aShape, reps) {
this.variableNames = ['A'];
var outputShape = new Array(aShape.length);
for (var i = 0; i < outputShape.length; i++) {
outputShape[i] = aShape[i] * reps[i];
}
this.outputShape = outputShape;
this.rank = outputShape.length;
var dtype = getCoordsDataType(this.rank);
var sourceCoords = getSourceCoords$2(aShape);
this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + sourceCoords + "));\n }\n ";
};
function getSourceCoords$2(aShape) {
var rank = aShape.length;
if (rank > 5) {
throw Error("Tile for rank " + rank + " is not yet supported");
}
if (rank === 1) {
return "imod(resRC, " + aShape[0] + ")";
}
var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];
var sourceCoords = [];
for (var i = 0; i < aShape.length; i++) {
sourceCoords.push("imod(" + currentCoords[i] + ", " + aShape[i] + ")");
}
return sourceCoords.join();
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function tile$3(params) {
var inputs = params.inputs,
backend = params.backend,
attrs = params.attrs;
var x = inputs.x;
var reps = attrs.reps; // The tile GPU program cannot handle the rank > 5 case.
if (x.dtype === 'string' || x.shape.length > 5) {
// Even though a string tensor always lives on the CPU, read through the
// backend to stay consistent in how tensor data is accessed.
var data = backend.readSync(x.dataId);
var value = x.dtype === 'string' ? data.map(function (d) {
return decodeString(d);
}) : data;
var buf = buffer(x.shape, x.dtype, value);
var outBuf = tileImplCPU(buf, reps);
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
var program = new TileProgram(x.shape, reps);
var output = backend.runWebGLProgram(program, [x], x.dtype);
return output;
}
var tileConfig$1 = {
kernelName: Tile,
backendName: 'webgl',
kernelFunc: tile$3
};
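/**
 * TileProgram maps every output coordinate back to the input with
 * imod(coord, aShape[d]), so the input is read repeatedly rather than copied.
 * A hedged usage sketch of the public op (assumed to be tf.tile in this
 * bundle), illustrative only:
 *
 *   const x = tf.tensor2d([[1, 2], [3, 4]]);
 *   const y = tf.tile(x, [2, 1]);
 *   // y: [[1, 2], [3, 4], [1, 2], [3, 4]] -- shape [4, 2].
 */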
// Based on Algorithm 2 of Bitonic Top K, ref:
// https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
// The original algorithm computes only the top K values; however, since TFJS
// also requires the indices of the top K values, the algorithm here is
// slightly modified: rather than producing the values at each step, it
// generates the indices containing the top K. The output values are not
// produced, which reduces the number of GPU outputs; they can easily be
// retrieved from the indices afterwards with a gather op.
var SwapProgram =
/**
* @param shape desired output shape (can be larger than input shape, output
* will be padded with -Infinity)
*/
function SwapProgram(shape) {
this.variableNames = ['x', 'indices']; // |n| Size of the original input of TopK.
// |firstPass| indicates if this is the first time swap is being used, which
// means no indices input containing the top K is present yet.
// |inc| Swaps pairs of indices (0, inc), (1, inc + 1), (2, inc + 2) ...
this.customUniforms = [{
name: 'n',
type: 'int'
}, {
name: 'firstPass',
type: 'int'
}, {
name: 'negativeInf',
type: 'float'
}, {
name: 'dir',
type: 'int'
}, {
name: 'inc',
type: 'int'
}];
this.outputShape = shape;
this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int elemIdx = coords[1];\n\n // We compare elements pair-wise within a group of size 2 * inc.\n // The comparing rule for each group alternates between ascending\n // and descending. Within each group, we compare each pair at\n // positions i and i+inc. To decide whether an element at position i\n // is x0 or x1, we mod it by 2 * inc, if the result is smaller than\n // inc, it is in the first half of the group, we denote it as x0,\n // otherwise we denote it as x1.\n // For example, as shown in the Bitonic top K paper referenced above,\n // Figure5(a) shows that element[1] is in the\n // second half of the group when group size is 2, but it is in the\n // first half of the group when group size is 4.\n\n bool isFirstInPair = imod(elemIdx, 2 * inc) < inc;\n int i = isFirstInPair ? elemIdx : elemIdx - inc;\n\n int i0 = firstPass == 1 ? i : int(getIndices(batch, i));\n int i1 = firstPass == 1 ? i + inc : int(getIndices(batch, i + inc));\n float x0 = i0 < n ? getX(batch, i0) : negativeInf;\n float x1 = i1 < n ? getX(batch, i1) : negativeInf;\n\n // Denotes which direction indices are in (ascending or descending).\n bool reverse = imod(elemIdx, 2 * dir) >= dir;\n bool isGreater = x0 > x1 || (x0 == x1 && i1 > i0);\n if (reverse == isGreater) { // Elements in opposite order of direction\n int iTemp = i0;\n i0 = i1;\n i1 = iTemp;\n }\n if (isFirstInPair) {\n setOutput(float(i0));\n } else {\n setOutput(float(i1));\n }\n }\n ";
};
var MergeProgram =
/**
* @param shape desired output shape (must be half of the input size)
*/
function MergeProgram(shape) {
this.variableNames = ['x', 'indices']; // |n| Size of the original input of TopK
// |firstPass| indicates if this is the first time swap is being used which
// means no indices input containing the top K is present yet.
// |k| Top k elements desired
this.customUniforms = [{
name: 'n',
type: 'int'
}, {
name: 'firstPass',
type: 'int'
}, {
name: 'k',
type: 'int'
}];
this.outputShape = shape;
this.userCode = "\n void main() {\n // Takes max of indices (0, k), (1, k + 1), (2, k + 2) ...\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int elemIdx = coords[1];\n\n // The output size is half of the previous size.\n // If the previous sequence is | | | | _ _ _ _ | | | | _ _ _ _ (k=4),\n // we only need to output the indices at positions |, the indices at\n // positions _ can be thrown away, see Figure5(b) After Phase 2\n // (Merge phase) in the Bitonic Top K paper referenced above.\n // For example, the paper shows we only need to output the orange bars.\n // The output sequence should look like this | | | | | | | |.\n // Because the sequence is halved, to map the output index back\n // to the previous sequence to find the corresponding value,\n // we need to double the index. When we double the index,\n // we basically interpolate a position, so 2i looks like\n // | _ | _ | _ | _ | _ | _ | _. We move the | to the first k position\n // of each 2k positions by - elemIdx % k. E.g. for output at\n // index 4,5,6,7, we want to get the corresponding element at\n // original index 8,9,10,11, for output at index 8,9,10,11,\n // we want to get the corresponding element at original index\n // 16,17,18,19, so on and so forth.\n\n int i = elemIdx < k ? elemIdx : (elemIdx * 2 - imod(elemIdx, k));\n int i0 = firstPass == 1 ? i : int(getIndices(batch, i));\n int i1 = firstPass == 1 ? i + k : int(getIndices(batch, i + k));\n\n float x0 = getX(batch, i0);\n float x1 = i1 < n ? getX(batch, i1) : x0;\n\n setOutput(x0 >= x1 ? float(i0) : float(i1));\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) {
if (tensorInfo !== null) {
backend.disposeIntermediateTensorInfo(tensorInfo);
}
}
function roundUpToPow2(num) {
var pow2 = 1;
while (pow2 < num) {
pow2 *= 2;
}
return pow2;
} // Based on Algorithm 2 of Bitonic Top K, ref:
// https://anilshanbhag.in/static/papers/gputopk_sigmod18.pdf
function topK$1(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x;
var k = attrs.k,
sorted = attrs.sorted; // Empirically determined constant used to determine last dim threshold for
// handing off execution to the CPU.
var TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD'); // Empirically determined constant used to determine k threshold for handing
// off execution to the CPU.
var TOPK_K_CPU_HANDOFF_THRESHOLD = env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');
var xShape = x.shape;
var lastDim = xShape[xShape.length - 1];
if (backend.shouldExecuteOnCPU([x]) || lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD || k > TOPK_K_CPU_HANDOFF_THRESHOLD) {
var xVals = backend.readSync(x.dataId);
var _topKImplCPU = topKImplCPU(xVals, xShape, x.dtype, k, sorted),
allTopKVals = _topKImplCPU[0],
allTopKIndices = _topKImplCPU[1];
return [backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values), backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)];
}
if (k === 0) {
xShape[xShape.length - 1] = 0;
return [backend.makeTensorInfo(xShape, x.dtype, []), backend.makeTensorInfo(xShape, 'int32', [])];
}
if (lastDim === 1
/* firstPass */
) {
return [x, fill$2({
attrs: {
shape: xShape,
dtype: 'int32',
value: 0
},
backend: backend
})];
} // Eagerly unpack x input since it is passed in to all the shaders which
// require unpacked inputs.
var xtexData = backend.texData.get(x.dataId);
var xIsPacked = xtexData !== null && xtexData.isPacked;
var xUnPacked = xIsPacked ? backend.unpackTensor(x) : x; // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
var xSize = sizeFromShape(xShape);
var batch = xSize / lastDim;
var x2D = reshape$3({
inputs: {
x: xUnPacked
},
attrs: {
shape: [batch, lastDim]
},
backend: backend
});
if (xIsPacked) {
disposeIntermediateTensorInfoOrNull(backend, xUnPacked);
}
var kPow2 = roundUpToPow2(k);
var lastDimPow2 = roundUpToPow2(lastDim); // Only the indices containing the top K are kept at every step to reduce
// number of outputs in the GPU algorithms, so once the final set of indices
// is computed then gather is used to grab the corresponding values
// from the original input.
var indices = null; // The GPU algorithm always takes an indices input, but that input is not
// used on the first run; when indices is null we therefore pass in x2D in
// its place, and its value is never actually read.
var getInputs = function getInputs() {
return indices === null ? [x2D, x2D] : [x2D, indices];
};
var runSwap = function runSwap(dir, inc, shape) {
var inputs = getInputs();
var program = new SwapProgram(shape);
var firstPass = indices === null ? 1 : 0;
var customValues = [[lastDim], [firstPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];
var prevIndices = indices;
indices = backend.runWebGLProgram(program, inputs, 'int32', customValues);
disposeIntermediateTensorInfoOrNull(backend, prevIndices);
}; // Step 1: local sort
for (var len = 1; len < kPow2; len *= 2) {
var dir = len * 2;
for (var inc = len; inc >= 1; inc /= 2) {
runSwap(dir, inc, [batch, lastDimPow2]);
}
} // Step 2: merge
for (var indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {
var _inputs = getInputs();
var mergeProgram = new MergeProgram([batch, indicesSize / 2]);
var firstPass = indices === null ? 1 : 0;
var customValues = [[lastDim], [firstPass], [kPow2]];
var _prevIndices = indices;
indices = backend.runWebGLProgram(mergeProgram, _inputs, 'int32', customValues);
disposeIntermediateTensorInfoOrNull(backend, _prevIndices); // Step 3: rebuild
var _len = kPow2 / 2;
var _dir = _len * 2;
for (var _inc = _len; _inc >= 1; _inc /= 2) {
runSwap(_dir, _inc, indices.shape);
}
} // Keep only the requested top K results instead of kPow2
var prevIndices = indices;
indices = slice$4({
inputs: {
x: indices
},
backend: backend,
attrs: {
begin: 0,
size: [batch, k]
}
});
disposeIntermediateTensorInfoOrNull(backend, prevIndices); // Gather values on last dimension
var values = gatherV2$1({
inputs: {
x: x2D,
indices: indices
},
backend: backend,
attrs: {
axis: 1,
batchDims: 1
}
});
disposeIntermediateTensorInfoOrNull(backend, x2D); // Reshape back to the original input shape, except that the last
// dimension is k.
var newShape = xShape.slice(0, -1);
newShape.push(k);
prevIndices = indices;
indices = reshape$3({
inputs: {
x: indices
},
attrs: {
shape: newShape
},
backend: backend
});
disposeIntermediateTensorInfoOrNull(backend, prevIndices);
var prevValues = values;
values = reshape$3({
inputs: {
x: values
},
attrs: {
shape: newShape
},
backend: backend
});
disposeIntermediateTensorInfoOrNull(backend, prevValues);
return [values, indices];
}
var topKConfig$1 = {
kernelName: TopK,
backendName: 'webgl',
kernelFunc: topK$1
};
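/**
 * The bitonic search above works on power-of-two sizes, so both k and the
 * last dimension are rounded up with roundUpToPow2 (e.g. roundUpToPow2(5)
 * returns 8) and out-of-range positions compare as -Infinity. Small inputs
 * and large k are handed off to topKImplCPU instead. A hedged usage sketch
 * of the public op (assumed to be tf.topk in this bundle), illustrative only:
 *
 *   const x = tf.tensor1d([1, 9, 3, 7]);
 *   const {values, indices} = tf.topk(x, 2);
 *   // values: [9, 7], indices: [1, 3]
 */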
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var TransformProgram = function TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
this.variableNames = ['Image', 'Transforms'];
this.outputShape = outShape;
var interpolationModeId = interpolation === 'nearest' ? 1 : 2;
var fillModeId;
switch (fillMode) {
case 'constant':
fillModeId = 1;
break;
case 'reflect':
fillModeId = 2;
break;
case 'wrap':
fillModeId = 3;
break;
case 'nearest':
fillModeId = 4;
break;
default:
fillModeId = 1;
break;
}
this.userCode = "\n float mapCoord(float outCoord, float len) {\n float inCoord = outCoord;\n if(" + fillModeId + " == 2) {\n if (inCoord < 0.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz2 = 2.0 * len;\n if (inCoord < sz2) {\n inCoord = sz2 * float(int(float(-inCoord / sz2))) +\n inCoord;\n }\n inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;\n }\n } else if (inCoord > len - 1.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz2 = 2.0 * len;\n inCoord -= sz2 * float(int(float(inCoord / sz2)));\n if (inCoord >= len) {\n inCoord = sz2 - inCoord - 1.0;\n }\n }\n }\n return clamp(inCoord, 0.0, len - 1.0);\n } else if (" + fillModeId + " == 3) {\n if (inCoord < 0.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz = len - 1.0;\n inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);\n }\n } else if (inCoord > len - 1.0) {\n if (len <= 1.0) {\n inCoord = 0.0;\n } else {\n float sz = len - 1.0;\n inCoord -= len * float(int(float(inCoord / sz)));\n }\n }\n return clamp(inCoord, 0.0, len - 1.0);\n } else if (" + fillModeId + " == 4) {\n return clamp(outCoord, 0.0, len - 1.0);\n } else {\n return outCoord;\n }\n }\n\n float readWithFillValue(int batch, int coordY, int coordX,\n int channel) {\n float outputValue;\n if (0 <= coordY && coordY < " + imageHeight + " && 0 <= coordX && coordX < " + imageWidth + ") {\n outputValue = getImage(batch, coordY, coordX, channel);\n } else {\n outputValue = float(" + fillValue + ");\n }\n return outputValue;\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n float outputValue;\n int batch = coords[0];\n int x = coords[2];\n int y = coords[1];\n int channel = coords[3];\n float xf = float(x);\n float yf = float(y);\n float a1 = getTransforms(batch, 0);\n float a2 = getTransforms(batch, 1);\n float a3 = getTransforms(batch, 2);\n float b1 = getTransforms(batch, 3);\n float b2 = getTransforms(batch, 4);\n float b3 = getTransforms(batch, 5);\n float c1 = getTransforms(batch, 6);\n float c2 = getTransforms(batch, 7);\n float projection = c1 * xf + c2 * yf + 1.0;\n if (projection == 0.0) {\n outputValue = float(" + fillValue + ");\n } else {\n float inX = (a1 * xf + a2 * yf + a3) / projection;\n float inY = (b1 * xf + b2 * yf + b3) / projection;\n float mapX = mapCoord(inX, float(" + imageWidth + "));\n float mapY = mapCoord(inY, float(" + imageHeight + "));\n\n if (" + interpolationModeId + " == 1) {\n int coordY = int(round(mapY));\n int coordX = int(round(mapX));\n outputValue = readWithFillValue(batch, coordY, coordX,\n channel);\n } else {\n float yFloor = floor(mapY);\n float xFloor = floor(mapX);\n float yCeil = yFloor + 1.0;\n float xCeil = xFloor + 1.0;\n float valueYFloor = (xCeil - mapX) *\n readWithFillValue(batch, int(yFloor), int(xFloor), channel) +\n (mapX - xFloor) *\n readWithFillValue(batch, int(yFloor), int(xCeil), channel);\n float valueYCeil = (xCeil - mapX) *\n readWithFillValue(batch, int(yCeil), int(xFloor), channel) +\n (mapX - xFloor) *\n readWithFillValue(batch, int(yCeil), int(xCeil), channel);\n outputValue = (yCeil - mapY) * valueYFloor +\n (mapY - yFloor) * valueYCeil;\n }\n }\n setOutput(outputValue);\n }\n ";
};
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function transform$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var image = inputs.image,
transforms = inputs.transforms;
var interpolation = attrs.interpolation,
fillMode = attrs.fillMode,
fillValue = attrs.fillValue,
outputShape = attrs.outputShape;
var _image$shape = image.shape,
batch = _image$shape[0],
imageHeight = _image$shape[1],
imageWidth = _image$shape[2],
numChannels = _image$shape[3];
var _ref = outputShape != null ? outputShape : [imageHeight, imageWidth],
outHeight = _ref[0],
outWidth = _ref[1];
var outShape = [batch, outHeight, outWidth, numChannels];
var program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape);
return backend.runWebGLProgram(program, [image, transforms], 'float32');
}
var transformConfig$1 = {
kernelName: Transform,
backendName: 'webgl',
kernelFunc: transform$2
};
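/**
 * Each row of `transforms` holds the eight coefficients [a1, a2, a3, b1, b2,
 * b3, c1, c2] of a projective transform, matching the reads in
 * TransformProgram above. A hedged usage sketch of the public op (assumed to
 * be tf.image.transform in this bundle), illustrative only:
 *
 *   const img = tf.ones([1, 4, 4, 1]);
 *   const identity = tf.tensor2d([[1, 0, 0, 0, 1, 0, 0, 0]]);
 *   const out = tf.image.transform(img, identity);
 *   // out equals img: every output pixel maps straight back to its input.
 */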
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function unique$3(args) {
var inputs = args.inputs,
attrs = args.attrs,
backend = args.backend;
var axis = attrs.axis;
var x = inputs.x;
assertNotComplex$1(x, 'unique'); // For now, always forward calculation to the CPU backend.
console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded');
var values = backend.readSync(x.dataId);
var _uniqueImplCPU = uniqueImplCPU(values, axis, x.shape, x.dtype),
outputValues = _uniqueImplCPU.outputValues,
outputShape = _uniqueImplCPU.outputShape,
indices = _uniqueImplCPU.indices;
return [backend.makeTensorInfo(outputShape, x.dtype, outputValues), backend.makeTensorInfo([indices.length], 'int32', indices)];
}
var uniqueConfig$1 = {
kernelName: Unique,
backendName: 'webgl',
kernelFunc: unique$3
};
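/**
 * unique always runs on the CPU implementation, which is why the kernel warns
 * that the synchronous download may briefly lock the UI. A hedged usage
 * sketch of the public op (assumed to be tf.unique in this bundle),
 * illustrative only:
 *
 *   const x = tf.tensor1d([1, 1, 2, 4, 4, 4, 7]);
 *   const {values, indices} = tf.unique(x);
 *   // values: [1, 2, 4, 7], indices: [0, 0, 1, 2, 2, 2, 3]
 */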
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function unpack$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var value = inputs.value;
var axis = attrs.axis;
if (axis < 0) {
axis += value.shape.length;
}
var x = value;
var xRank = x.shape.length;
var num = value.shape[axis];
var outShape = new Array(xRank - 1);
var outIndex = 0;
for (var i = 0; i < xRank; i++) {
if (i !== axis) {
outShape[outIndex++] = x.shape[i];
}
}
var toDispose = [];
var begin = new Array(xRank).fill(0);
var size = x.shape.slice();
size[axis] = 1;
var res = new Array(num);
for (var _i = 0; _i < res.length; _i++) {
begin[axis] = _i;
var sliced = slice$4({
inputs: {
x: x
},
backend: backend,
attrs: {
begin: begin,
size: size
}
});
var reshaped = reshape$3({
inputs: {
x: sliced
},
backend: backend,
attrs: {
shape: outShape
}
});
res[_i] = reshaped;
toDispose.push(sliced);
}
toDispose.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return res;
}
var unpackConfig$1 = {
kernelName: Unpack,
backendName: 'webgl',
kernelFunc: unpack$2
};
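/**
 * unpack produces one tensor per position along `axis` by slicing with a size
 * of 1 on that axis and then squeezing it away via reshape. A hedged usage
 * sketch of the public op that relies on this kernel (assumed to be
 * tf.unstack in this bundle), illustrative only:
 *
 *   const x = tf.tensor2d([[1, 2], [3, 4]]);
 *   const [row0, row1] = tf.unstack(x); // axis 0 by default
 *   // row0: [1, 2], row1: [3, 4]
 */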
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
var SegmentOpProgram = function SegmentOpProgram(segOpInfo, segOpType) {
this.variableNames = ['x', 'segmentIds'];
var windowSize = segOpInfo.windowSize;
var batchSize = segOpInfo.batchSize;
var inSize = segOpInfo.inSize;
var numSegments = segOpInfo.numSegments;
var outSize = numSegments * Math.ceil(inSize / windowSize);
this.outputShape = [batchSize, outSize];
var initializationValue = '0.0';
var returnValue = "sumValue";
var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
var windowSizeVec4Remainder = windowSize % 4;
var updateSnippet = "\n sumValue += dot(values, segFilter);\n ";
var checkValueOutOfBounds = '';
if (inSize % windowSize > 0) {
checkValueOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n ";
}
var checkSegmentIdOutOfBounds = '';
if (inSize % windowSize > 0) {
checkSegmentIdOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return -1.0;\n }\n ";
}
this.userCode = "\n const float initializationValue = " + initializationValue + ";\n\n float getValue(int batch, int inIdx) {\n " + checkValueOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n float getSegmentIdAtIndex(int inIdx) {\n " + checkSegmentIdOutOfBounds + "\n return getSegmentIds(inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = int(floor(float(outIdx) / float(\n " + numSegments + ")) * float(" + windowSize + "));\n int currentSeg = int(mod(float(outIdx), float(" + numSegments + ")));\n\n float sumValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n int inIdxSeg = int(getSegmentIdAtIndex(inIdx));\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n 0\n );\n\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n ";
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
function unsortedSegmentSum$2(args) {
var inputs = args.inputs,
backend = args.backend,
attrs = args.attrs;
var x = inputs.x,
segmentIds = inputs.segmentIds;
var numSegments = attrs.numSegments;
var xRank = x.shape.length;
var toDispose = [];
var axis = 0;
var permutation = getAxesPermutation([axis], xRank);
var permutedX = x;
if (permutation != null) {
permutedX = transpose$2({
inputs: {
x: x
},
backend: backend,
attrs: {
perm: permutation
}
});
toDispose.push(permutedX);
axis = getInnerMostAxes(1, xRank)[0];
}
var outShape = computeOutShape$2(permutedX.shape, axis, numSegments);
var inSize = sizeFromShape([permutedX.shape[axis]]);
var a2D = reshape$3({
inputs: {
x: permutedX
},
backend: backend,
attrs: {
shape: [-1, inSize]
}
});
toDispose.push(a2D);
var outputDType = sumOutType(x.dtype);
var segOpCompute = function segOpCompute(x, segOpType, segmentIds, dtype, numSegments) {
var batchSize = x.shape[0];
var inSize = x.shape[1];
var windowSize = segOpComputeOptimalWindowSize(inSize, numSegments);
var segOpInfo = {
windowSize: windowSize,
inSize: inSize,
batchSize: batchSize,
numSegments: numSegments
};
var program = new SegmentOpProgram(segOpInfo, segOpType);
var output = backend.compileAndRun(program, [x, segmentIds], dtype);
toDispose.push(output); // No need to run another GPGPU program.
if (output.shape[1] === numSegments) {
return output;
}
var rangeInfo = range$3({
backend: backend,
attrs: {
start: 0,
stop: numSegments,
step: 1,
dtype: 'float32'
}
});
var tileInfo = tile$3({
inputs: {
x: rangeInfo
},
backend: backend,
attrs: {
reps: [inSize / windowSize]
}
});
toDispose.push(rangeInfo);
toDispose.push(tileInfo);
var result = segOpCompute(output, segOpType, tileInfo, dtype, numSegments);
return result;
};
var segOpResult = segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments);
var reshaped = reshape$3({
inputs: {
x: segOpResult
},
backend: backend,
attrs: {
shape: outShape
}
});
var result = reshaped;
if (permutation != null) {
toDispose.push(reshaped);
var perm = getUndoAxesPermutation(permutation);
result = transpose$2({
inputs: {
x: result
},
backend: backend,
attrs: {
perm: perm
}
});
}
toDispose.forEach(function (t) {
return backend.disposeIntermediateTensorInfo(t);
});
return result;
}
var unsortedSegmentSumConfig$1 = {
kernelName: UnsortedSegmentSum,
backendName: 'webgl',
kernelFunc: unsortedSegmentSum$2
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
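// Note: registers every WebGL kernel config below with the global kernel
// registry so the 'webgl' backend can execute these ops.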
var kernelConfigs$1 = [LRNConfig, LRNGradConfig, _fusedMatMulConfig$1, absConfig$1, acosConfig$1, acoshConfig$1, addConfig$1, addNConfig$1, allConfig$1, anyConfig$1, argMaxConfig$1, argMinConfig$1, asinConfig$1, asinhConfig$1, atan2Config$1, atanConfig$1, atanhConfig$1, avgPool3DConfig$1, avgPoolConfig$1, avgPoolGrad3DConfig, avgPoolGradConfig$2, batchMatMulConfig$1, batchNormConfig$1, batchToSpaceNDConfig$1, bincountConfig$1, castConfig$1, ceilConfig$1, clipByValueConfig, complexAbsConfig$1, complexConfig$1, concatConfig$1, conv2DBackpropFilterConfig$1, conv2DBackpropInputConfig$1, conv2DConfig$1, conv3DBackpropFilterV2Config$1, conv3DBackpropInputConfig, conv3DConfig$1, cosConfig$1, coshConfig$1, cropAndResizeConfig$1, cumsumConfig$1, denseBincountConfig$1, depthToSpaceConfig$1, depthwiseConv2dNativeBackpropFilterConfig$1, depthwiseConv2dNativeBackpropInputConfig$1, depthwiseConv2dNativeConfig$1, diagConfig$1, dilation2DConfig, einsumConfig$1, eluConfig$1, eluGradConfig$2, equalConfig$1, erfConfig$1, expConfig$1, expandDimsConfig$1, expm1Config$1, fftConfig$1, fillConfig$1, flipLeftRightConfig$1, floorConfig$1, floorDivConfig$1, fromPixelsConfig, fusedConv2DConfig$1, fusedDepthwiseConv2DConfig$1, gatherNdConfig$1, gatherV2Config$1, greaterConfig$1, greaterEqualConfig$1, identityConfig$1, ifftConfig$1, imagConfig$1, isFiniteConfig$1, isInfConfig$1, isNaNConfig$1, leakyReluConfig$1, lessConfig$1, lessEqualConfig$1, linSpaceConfig$1, log1pConfig$1, logConfig$1, logicalAndConfig$1, logicalNotConfig$1, logicalOrConfig$1, maxConfig$1, maxPool3DConfig$1, maxPoolConfig$1, maxPoolGrad3DConfig, maxPoolGradConfig$2, maxPoolWithArgmaxConfig$1, maximumConfig$1, meanConfig$1, minConfig$1, minimumConfig$1, mirrorPadConfig$1, modConfig$1, multinomialConfig$1, multiplyConfig$1, negConfig$1, nonMaxSuppressionV3Config$1, nonMaxSuppressionV4Config$1, nonMaxSuppressionV5Config$1, notEqualConfig$1, oneHotConfig$1, onesLikeConfig$1, packConfig$1, padV2Config$1, powConfig$1, preluConfig$1, prodConfig$1, rangeConfig$1, realConfig$1, realDivConfig$1, reciprocalConfig$1, relu6Config$1, reluConfig$1, reshapeConfig$1, resizeBilinearConfig$1, resizeBilinearGradConfig$2, resizeNearestNeighborConfig$1, resizeNearestNeighborGradConfig$2, reverseConfig$1, rotateWithOffsetConfig$1, roundConfig$1, rsqrtConfig$1, scatterNdConfig$1, selectConfig$1, seluConfig$1, sigmoidConfig$1, signConfig$1, sinConfig$1, sinhConfig$1, sliceConfig$1, softmaxConfig$1, softplusConfig$1, spaceToBatchNDConfig$1, sparseFillEmptyRowsConfig$1, sparseReshapeConfig$1, sparseSegmentMeanConfig$1, sparseSegmentSumConfig$1, sparseToDenseConfig$1, splitVConfig$1, sqrtConfig$1, squareConfig$1, squaredDifferenceConfig$1, stepConfig$1, stridedSliceConfig$1, stringNGramsConfig$1, stringSplitConfig$1, stringToHashBucketFastConfig$1, subConfig$1, sumConfig$1, tanConfig$1, tanhConfig$1, tileConfig$1, topKConfig$1, transformConfig$1, transposeConfig$1, uniqueConfig$1, unpackConfig$1, unsortedSegmentSumConfig$1, zerosLikeConfig$1];
for (var _i$2 = 0, _kernelConfigs$1 = kernelConfigs$1; _i$2 < _kernelConfigs$1.length; _i$2++) {
var kernelConfig$1 = _kernelConfigs$1[_i$2];
registerKernel(kernelConfig$1);
}
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** @license See the LICENSE file. */
// This code is auto-generated, do not modify this file!
var version$7 = '3.9.0';
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
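// Note: version map of all bundled packages; exported below as tf.version.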
var version$8 = {
'tfjs-core': version$1,
'tfjs-backend-cpu': version$5,
'tfjs-backend-webgl': version$6,
'tfjs-data': version$4,
'tfjs-layers': version$2,
'tfjs-converter': version$3,
'tfjs': version$7
};
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
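// Note: public API surface of the tf namespace -- kernel names, classes, ops
// and utilities re-exported from the bundled packages.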
exports.Abs = Abs;
exports.Acos = Acos;
exports.Acosh = Acosh;
exports.AdadeltaOptimizer = AdadeltaOptimizer;
exports.AdagradOptimizer = AdagradOptimizer;
exports.AdamOptimizer = AdamOptimizer;
exports.AdamaxOptimizer = AdamaxOptimizer;
exports.Add = Add;
exports.AddN = AddN;
exports.All = All;
exports.Any = Any;
exports.ArgMax = ArgMax;
exports.ArgMin = ArgMin;
exports.Asin = Asin;
exports.Asinh = Asinh;
exports.Atan = Atan;
exports.Atan2 = Atan2;
exports.Atanh = Atanh;
exports.AvgPool = AvgPool;
exports.AvgPool3D = AvgPool3D;
exports.AvgPool3DGrad = AvgPool3DGrad;
exports.AvgPoolGrad = AvgPoolGrad;
exports.BatchMatMul = BatchMatMul;
exports.BatchToSpaceND = BatchToSpaceND;
exports.Bincount = Bincount;
exports.BroadcastArgs = BroadcastArgs;
exports.BroadcastTo = BroadcastTo;
exports.Callback = Callback;
exports.CallbackList = CallbackList;
exports.Cast = Cast;
exports.Ceil = Ceil;
exports.ClipByValue = ClipByValue;
exports.Complex = Complex;
exports.ComplexAbs = ComplexAbs;
exports.Concat = Concat;
exports.Conv2D = Conv2D;
exports.Conv2DBackpropFilter = Conv2DBackpropFilter;
exports.Conv2DBackpropInput = Conv2DBackpropInput;
exports.Conv3D = Conv3D;
exports.Conv3DBackpropFilterV2 = Conv3DBackpropFilterV2;
exports.Conv3DBackpropInputV2 = Conv3DBackpropInputV2;
exports.Cos = Cos;
exports.Cosh = Cosh;
exports.CropAndResize = CropAndResize;
exports.Cumsum = Cumsum;
exports.CustomCallback = CustomCallback;
exports.DataStorage = DataStorage;
exports.DenseBincount = DenseBincount;
exports.DepthToSpace = DepthToSpace;
exports.DepthwiseConv2dNative = DepthwiseConv2dNative;
exports.DepthwiseConv2dNativeBackpropFilter = DepthwiseConv2dNativeBackpropFilter;
exports.DepthwiseConv2dNativeBackpropInput = DepthwiseConv2dNativeBackpropInput;
exports.Diag = Diag;
exports.Dilation2D = Dilation2D;
exports.Dilation2DBackpropFilter = Dilation2DBackpropFilter;
exports.Dilation2DBackpropInput = Dilation2DBackpropInput;
exports.EarlyStopping = EarlyStopping;
exports.Einsum = Einsum;
exports.Elu = Elu;
exports.EluGrad = EluGrad;
exports.Environment = Environment;
exports.Equal = Equal;
exports.Erf = Erf;
exports.Exp = Exp;
exports.ExpandDims = ExpandDims;
exports.Expm1 = Expm1;
exports.FFT = FFT;
exports.Fill = Fill;
exports.FlipLeftRight = FlipLeftRight;
exports.Floor = Floor;
exports.FloorDiv = FloorDiv;
exports.FromPixels = FromPixels;
exports.FusedBatchNorm = FusedBatchNorm;
exports.FusedConv2D = FusedConv2D;
exports.FusedDepthwiseConv2D = FusedDepthwiseConv2D;
exports.GatherNd = GatherNd;
exports.GatherV2 = GatherV2;
exports.GraphModel = GraphModel;
exports.Greater = Greater;
exports.GreaterEqual = GreaterEqual;
exports.History = History;
exports.IFFT = IFFT;
exports.Identity = Identity;
exports.Imag = Imag;
exports.InputSpec = InputSpec;
exports.IsFinite = IsFinite;
exports.IsInf = IsInf;
exports.IsNan = IsNan;
exports.KernelBackend = KernelBackend;
exports.LRN = LRN;
exports.LRNGrad = LRNGrad;
exports.LayerVariable = LayerVariable;
exports.LayersModel = LayersModel;
exports.LeakyRelu = LeakyRelu;
exports.Less = Less;
exports.LessEqual = LessEqual;
exports.LinSpace = LinSpace;
exports.Log = Log;
exports.Log1p = Log1p;
exports.LogSoftmax = LogSoftmax;
exports.LogicalAnd = LogicalAnd;
exports.LogicalNot = LogicalNot;
exports.LogicalOr = LogicalOr;
exports.Max = Max;
exports.MaxPool = MaxPool;
exports.MaxPool3D = MaxPool3D;
exports.MaxPool3DGrad = MaxPool3DGrad;
exports.MaxPoolGrad = MaxPoolGrad;
exports.MaxPoolWithArgmax = MaxPoolWithArgmax;
exports.Maximum = Maximum;
exports.Mean = Mean;
exports.Min = Min;
exports.Minimum = Minimum;
exports.MirrorPad = MirrorPad;
exports.Mod = Mod;
exports.MomentumOptimizer = MomentumOptimizer;
exports.Multinomial = Multinomial;
exports.Multiply = Multiply;
exports.Neg = Neg;
exports.NonMaxSuppressionV3 = NonMaxSuppressionV3;
exports.NonMaxSuppressionV4 = NonMaxSuppressionV4;
exports.NonMaxSuppressionV5 = NonMaxSuppressionV5;
exports.NotEqual = NotEqual;
exports.OP_SCOPE_SUFFIX = OP_SCOPE_SUFFIX;
exports.OneHot = OneHot;
exports.OnesLike = OnesLike;
exports.Optimizer = Optimizer;
exports.Pack = Pack;
exports.PadV2 = PadV2;
exports.Pool = Pool;
exports.Pow = Pow;
exports.Prelu = Prelu;
exports.Prod = Prod;
exports.RMSPropOptimizer = RMSPropOptimizer;
exports.RNN = RNN;
exports.Range = Range;
exports.Real = Real;
exports.RealDiv = RealDiv;
exports.Reciprocal = Reciprocal;
exports.Relu = Relu;
exports.Relu6 = Relu6;
exports.Reshape = Reshape;
exports.ResizeBilinear = ResizeBilinear;
exports.ResizeBilinearGrad = ResizeBilinearGrad;
exports.ResizeNearestNeighbor = ResizeNearestNeighbor;
exports.ResizeNearestNeighborGrad = ResizeNearestNeighborGrad;
exports.Reverse = Reverse;
exports.RotateWithOffset = RotateWithOffset;
exports.Round = Round;
exports.Rsqrt = Rsqrt;
exports.SGDOptimizer = SGDOptimizer;
exports.ScatterNd = ScatterNd;
exports.Select = Select;
exports.Selu = Selu;
exports.Sequential = Sequential;
exports.Sigmoid = Sigmoid;
exports.Sign = Sign;
exports.Sin = Sin;
exports.Sinh = Sinh;
exports.Slice = Slice;
exports.Softmax = Softmax;
exports.Softplus = Softplus;
exports.SpaceToBatchND = SpaceToBatchND;
exports.SparseFillEmptyRows = SparseFillEmptyRows;
exports.SparseReshape = SparseReshape;
exports.SparseSegmentMean = SparseSegmentMean;
exports.SparseSegmentSum = SparseSegmentSum;
exports.SparseToDense = SparseToDense;
exports.SplitV = SplitV;
exports.Sqrt = Sqrt;
exports.Square = Square;
exports.SquaredDifference = SquaredDifference;
exports.Step = Step;
exports.StridedSlice = StridedSlice;
exports.StringNGrams = StringNGrams;
exports.StringSplit = StringSplit;
exports.StringToHashBucketFast = StringToHashBucketFast;
exports.Sub = Sub;
exports.Sum = Sum;
exports.SymbolicTensor = SymbolicTensor;
exports.Tan = Tan;
exports.Tanh = Tanh;
exports.Tensor = Tensor;
exports.TensorBuffer = TensorBuffer;
exports.Tile = Tile;
exports.TopK = TopK;
exports.Transform = Transform;
exports.Transpose = Transpose;
exports.Unique = Unique;
exports.Unpack = Unpack;
exports.UnsortedSegmentSum = UnsortedSegmentSum;
exports.Variable = Variable;
exports.ZerosLike = ZerosLike;
exports._FusedMatMul = _FusedMatMul;
exports.abs = abs$8;
exports.acos = acos;
exports.acosh = acosh;
exports.add = add$1;
exports.addN = addN;
exports.all = all;
exports.any = any;
exports.argMax = argMax;
exports.argMin = argMin;
exports.asin = asin;
exports.asinh = asinh$1;
exports.atan = atan;
exports.atan2 = atan2;
exports.atanh = atanh;
exports.avgPool = avgPool;
exports.avgPool3d = avgPool3d;
exports.backend = backend;
exports.backend_util = backend_util;
exports.basicLSTMCell = basicLSTMCell;
exports.batchNorm = batchNorm;
exports.batchNorm2d = batchNorm2d;
exports.batchNorm3d = batchNorm3d;
exports.batchNorm4d = batchNorm4d;
exports.batchToSpaceND = batchToSpaceND;
exports.bincount = bincount;
exports.booleanMaskAsync = booleanMaskAsync;
exports.broadcastArgs = broadcastArgs;
exports.broadcastTo = broadcastTo;
exports.browser = browser;
exports.buffer = buffer;
exports.callbacks = callbacks;
exports.cast = cast;
exports.ceil = ceil$3;
exports.clipByValue = clipByValue;
exports.clone = clone;
exports.complex = complex;
exports.concat = concat;
exports.concat1d = concat1d;
exports.concat2d = concat2d;
exports.concat3d = concat3d;
exports.concat4d = concat4d;
exports.constraints = exports_constraints;
exports.conv1d = conv1d;
exports.conv2d = conv2d;
exports.conv2dTranspose = conv2dTranspose;
exports.conv3d = conv3d;
exports.conv3dTranspose = conv3dTranspose;
exports.copyRegisteredKernels = copyRegisteredKernels;
exports.cos = cos;
exports.cosh = cosh;
exports.cosineWindow = cosineWindow;
exports.cumsum = cumsum;
exports.customGrad = customGrad;
exports.data = index$1;
exports.denseBincount = denseBincount;
exports.deprecationWarn = deprecationWarn;
exports.depthToSpace = depthToSpace;
exports.depthwiseConv2d = depthwiseConv2d;
exports.deregisterOp = deregisterOp;
exports.device_util = device_util;
exports.diag = diag;
exports.dilation2d = dilation2d;
exports.disableDeprecationWarnings = disableDeprecationWarnings;
exports.dispose = dispose;
exports.disposeVariables = disposeVariables;
exports.div = div;
exports.divNoNan = divNoNan;
exports.dot = dot;
exports.dropout = dropout;
exports.einsum = einsum;
exports.elu = elu;
exports.enableDebugMode = enableDebugMode;
exports.enableProdMode = enableProdMode;
exports.enclosingPowerOfTwo = enclosingPowerOfTwo;
exports.engine = engine;
exports.env = env;
exports.equal = equal;
exports.erf = erf;
exports.exp = exp$3;
exports.expandDims = expandDims;
exports.expm1 = expm1;
exports.eye = eye;
exports.fft = fft;
exports.fill = fill;
exports.findBackend = findBackend;
exports.findBackendFactory = findBackendFactory;
exports.floor = floor$a;
exports.floorDiv = floorDiv;
exports.fused = fused_ops;
exports.gather = gather;
exports.gatherND = gatherND;
exports.gather_util = gather_nd_util;
exports.getBackend = getBackend;
exports.getGradient = getGradient;
exports.getKernel = getKernel;
exports.getKernelsForBackend = getKernelsForBackend;
exports.grad = grad;
exports.grads = grads;
exports.greater = greater;
exports.greaterEqual = greaterEqual;
exports.ifft = ifft;
exports.imag = imag;
exports.image = image;
exports.inTopKAsync = inTopKAsync;
exports.initializers = exports_initializers;
exports.input = input;
exports.io = io;
exports.irfft = irfft;
exports.isFinite = isFinite$1;
exports.isInf = isInf;
exports.isNaN = isNaN$1;
exports.keep = keep;
exports.kernel_impls = kernel_impls;
exports.layers = exports_layers;
exports.leakyRelu = leakyRelu;
exports.less = less;
exports.lessEqual = lessEqual;
exports.linalg = linalg;
exports.linspace = linspace;
exports.loadGraphModel = loadGraphModel;
exports.loadLayersModel = loadLayersModel;
exports.localResponseNormalization = localResponseNormalization;
exports.log = log$a;
exports.log1p = log1p;
exports.logSigmoid = logSigmoid;
exports.logSoftmax = logSoftmax;
exports.logSumExp = logSumExp;
exports.logicalAnd = logicalAnd;
exports.logicalNot = logicalNot;
exports.logicalOr = logicalOr;
exports.logicalXor = logicalXor;
exports.losses = losses;
exports.matMul = matMul;
exports.math = math;
exports.max = max$5;
exports.maxPool = maxPool;
exports.maxPool3d = maxPool3d;
exports.maxPoolWithArgmax = maxPoolWithArgmax;
exports.maximum = maximum;
exports.mean = mean;
exports.memory = memory;
exports.meshgrid = meshgrid;
exports.metrics = exports_metrics;
exports.min = min$9;
exports.minimum = minimum;
exports.mirrorPad = mirrorPad;
exports.mod = mod;
exports.model = model;
exports.models = exports_models;
exports.moments = moments;
exports.movingAverage = movingAverage;
exports.mul = mul;
exports.multiRNNCell = multiRNNCell;
exports.multinomial = multinomial;
exports.neg = neg;
exports.nextFrame = nextFrame;
exports.norm = norm;
exports.notEqual = notEqual;
exports.oneHot = oneHot;
exports.ones = ones$1;
exports.onesLike = onesLike;
exports.op = op;
exports.outerProduct = outerProduct;
exports.pad = pad;
exports.pad1d = pad1d;
exports.pad2d = pad2d;
exports.pad3d = pad3d;
exports.pad4d = pad4d;
exports.pool = pool;
exports.pow = pow$5;
exports.prelu = prelu;
exports.print = print;
exports.prod = prod;
exports.profile = profile;
exports.rand = rand;
exports.randomGamma = randomGamma;
exports.randomNormal = randomNormal;
exports.randomUniform = randomUniform;
exports.range = range;
exports.ready = ready;
exports.real = real;
exports.reciprocal = reciprocal;
exports.registerBackend = registerBackend;
exports.registerCallbackConstructor = registerCallbackConstructor;
exports.registerGradient = registerGradient;
exports.registerKernel = registerKernel;
exports.registerOp = registerOp;
exports.regularizers = exports_regularizers;
exports.relu = relu;
exports.relu6 = relu6;
exports.removeBackend = removeBackend;
exports.reshape = reshape;
exports.reverse = reverse;
exports.reverse1d = reverse1d;
exports.reverse2d = reverse2d;
exports.reverse3d = reverse3d;
exports.reverse4d = reverse4d;
exports.rfft = rfft;
exports.round = round$1;
exports.rsqrt = rsqrt;
exports.scalar = scalar;
exports.scatterND = scatterND;
exports.scatter_util = scatter_nd_util;
exports.selu = selu;
exports.separableConv2d = separableConv2d;
exports.sequential = sequential;
exports.serialization = serialization;
exports.setBackend = setBackend;
exports.setPlatform = setPlatform;
exports.setdiff1dAsync = setdiff1dAsync;
exports.sigmoid = sigmoid;
exports.sign = sign;
exports.signal = signal;
exports.sin = sin;
exports.sinh = sinh;
exports.slice = slice$2;
exports.slice1d = slice1d;
exports.slice2d = slice2d;
exports.slice3d = slice3d;
exports.slice4d = slice4d;
exports.slice_util = slice_util;
exports.softmax = softmax;
exports.softplus = softplus;
exports.spaceToBatchND = spaceToBatchND;
exports.sparse = sparse;
exports.sparseToDense = sparseToDense;
exports.spectral = spectral;
exports.split = split$1;
exports.sqrt = sqrt$3;
exports.square = square;
exports.squaredDifference = squaredDifference;
exports.squeeze = squeeze;
exports.stack = stack;
exports.step = step;
exports.stridedSlice = stridedSlice;
exports.string = string;
exports.sub = sub;
exports.sum = sum$1;
exports.sumOutType = sumOutType;
exports.tan = tan;
exports.tanh = tanh$1;
exports.tensor = tensor;
exports.tensor1d = tensor1d;
exports.tensor2d = tensor2d;
exports.tensor3d = tensor3d;
exports.tensor4d = tensor4d;
exports.tensor5d = tensor5d;
exports.tensor6d = tensor6d;
exports.tensor_util = tensor_util;
exports.test_util = test_util;
exports.tidy = tidy;
exports.tile = tile;
exports.time = time;
exports.topk = topk;
exports.train = train;
exports.transpose = transpose;
exports.truncatedNormal = truncatedNormal;
exports.unique = unique;
exports.unregisterGradient = unregisterGradient;
exports.unregisterKernel = unregisterKernel;
exports.unsortedSegmentSum = unsortedSegmentSum;
exports.unstack = unstack;
exports.upcastType = upcastType;
exports.util = util;
exports.valueAndGrad = valueAndGrad;
exports.valueAndGrads = valueAndGrads;
exports.variable = variable;
exports.variableGrads = variableGrads;
exports.version = version$8;
exports.version_converter = version$3;
exports.version_core = version$1;
exports.version_layers = version$2;
exports.where = where;
exports.whereAsync = whereAsync;
exports.zeros = zeros;
exports.zerosLike = zerosLike;
Object.defineProperty(exports, '__esModule', { value: true });
})));
//# sourceMappingURL=tf.js.map