diff --git a/server/bert_serving/server/__init__.py b/server/bert_serving/server/__init__.py index aab902d95..bab06b37e 100644 --- a/server/bert_serving/server/__init__.py +++ b/server/bert_serving/server/__init__.py @@ -491,8 +491,8 @@ def get_estimator(self, tf): from tensorflow.python.estimator.model_fn import EstimatorSpec def model_fn(features, labels, mode, params): - with tf.gfile.GFile(self.graph_path, 'rb') as f: - graph_def = tf.GraphDef() + with tf.io.gfile.GFile(self.graph_path, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) input_names = ['input_ids', 'input_mask', 'input_type_ids'] @@ -506,7 +506,7 @@ def model_fn(features, labels, mode, params): 'encodes': output[0] }) - config = tf.ConfigProto(device_count={'GPU': 0 if self.device_id < 0 else 1}) + config = tf.compat.v1.ConfigProto(device_count={'GPU': 0 if self.device_id < 0 else 1}) config.gpu_options.allow_growth = True config.gpu_options.per_process_gpu_memory_fraction = self.gpu_memory_fraction config.log_device_placement = False diff --git a/server/bert_serving/server/bert/modeling.py b/server/bert_serving/server/bert/modeling.py index 7813e2d62..bb365d65d 100644 --- a/server/bert_serving/server/bert/modeling.py +++ b/server/bert_serving/server/bert/modeling.py @@ -88,7 +88,7 @@ def from_dict(cls, json_object): @classmethod def from_json_file(cls, json_file): """Constructs a `BertConfig` from a json file of parameters.""" - with tf.gfile.GFile(json_file, "r") as reader: + with tf.io.gfile.GFile(json_file, "r") as reader: text = reader.read() return cls.from_dict(json.loads(text)) @@ -119,7 +119,7 @@ class BertModel(object): model = modeling.BertModel(config=config, is_training=True, input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) - label_embeddings = tf.get_variable(...) + label_embeddings = tf.compat.v1.get_variable(...) pooled_output = model.get_pooled_output() logits = tf.matmul(pooled_output, label_embeddings) ... @@ -169,8 +169,8 @@ def __init__(self, if token_type_ids is None: token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) - with tf.variable_scope(scope, default_name="bert"): - with tf.variable_scope("embeddings"): + with tf.compat.v1.variable_scope(scope, default_name="bert"): + with tf.compat.v1.variable_scope("embeddings"): # Perform embedding lookup on the word ids. (self.embedding_output, self.embedding_table) = embedding_lookup( input_ids=input_ids, @@ -194,7 +194,7 @@ def __init__(self, max_position_embeddings=config.max_position_embeddings, dropout_prob=config.hidden_dropout_prob) - with tf.variable_scope("encoder"): + with tf.compat.v1.variable_scope("encoder"): # This converts a 2D mask of shape [batch_size, seq_length] to a 3D # mask of shape [batch_size, seq_length, seq_length] which is used # for the attention scores. @@ -222,12 +222,12 @@ def __init__(self, # [batch_size, hidden_size]. This is necessary for segment-level # (or segment-pair-level) classification tasks where we need a fixed # dimensional representation of the segment. - with tf.variable_scope("pooler"): + with tf.compat.v1.variable_scope("pooler"): # We "pool" the model by simply taking the hidden state corresponding # to the first token. We assume that this has been pre-trained first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) # https://github.com/google-research/bert/issues/43#issuecomment-435980269 - self.pooled_output = tf.layers.dense( + self.pooled_output = tf.compat.v1.layers.dense( first_token_tensor, config.hidden_size, activation=tf.tanh, @@ -275,7 +275,7 @@ def gelu(input_tensor): Returns: `input_tensor` with the GELU activation applied. """ - cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) + cdf = 0.5 * (1.0 + tf.math.erf(input_tensor / tf.sqrt(2.0))) return input_tensor * cdf @@ -363,8 +363,7 @@ def dropout(input_tensor, dropout_prob): def layer_norm(input_tensor, name=None): """Run layer normalization on the last dimension of the tensor.""" - return tf.contrib.layers.layer_norm( - inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) + return tf.keras.layers.LayerNormalization(axis=-1)(input_tensor) def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): @@ -376,7 +375,7 @@ def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): def create_initializer(initializer_range=0.02): """Creates a `truncated_normal_initializer` with the given range.""" - return tf.truncated_normal_initializer(stddev=initializer_range) + return tf.compat.v1.truncated_normal_initializer(stddev=initializer_range) def embedding_lookup(input_ids, @@ -409,7 +408,7 @@ def embedding_lookup(input_ids, if input_ids.shape.ndims == 2: input_ids = tf.expand_dims(input_ids, axis=[-1]) - embedding_table = tf.get_variable( + embedding_table = tf.compat.v1.get_variable( name=word_embedding_name, shape=[vocab_size, embedding_size], initializer=create_initializer(initializer_range)) @@ -482,7 +481,7 @@ def embedding_postprocessor(input_tensor, if token_type_ids is None: raise ValueError("`token_type_ids` must be specified if" "`use_token_type` is True.") - token_type_table = tf.get_variable( + token_type_table = tf.compat.v1.get_variable( name=token_type_embedding_name, shape=[token_type_vocab_size, width], initializer=create_initializer(initializer_range)) @@ -496,7 +495,7 @@ def embedding_postprocessor(input_tensor, output += token_type_embeddings if use_position_embeddings: - full_position_embeddings = tf.get_variable( + full_position_embeddings = tf.compat.v1.get_variable( name=position_embedding_name, shape=[max_position_embeddings, width], initializer=create_initializer(initializer_range)) @@ -675,7 +674,7 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, to_tensor_2d = reshape_to_matrix(to_tensor) # `query_layer` = [B*F, N*H] - query_layer = tf.layers.dense( + query_layer = tf.compat.v1.layers.dense( from_tensor_2d, num_attention_heads * size_per_head, activation=query_act, @@ -683,7 +682,7 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, kernel_initializer=create_initializer(initializer_range)) # `key_layer` = [B*T, N*H] - key_layer = tf.layers.dense( + key_layer = tf.compat.v1.layers.dense( to_tensor_2d, num_attention_heads * size_per_head, activation=key_act, @@ -691,7 +690,7 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, kernel_initializer=create_initializer(initializer_range)) # `value_layer` = [B*T, N*H] - value_layer = tf.layers.dense( + value_layer = tf.compat.v1.layers.dense( to_tensor_2d, num_attention_heads * size_per_head, activation=value_act, @@ -836,12 +835,12 @@ def transformer_model(input_tensor, all_layer_outputs = [] for layer_idx in range(num_hidden_layers): - with tf.variable_scope("layer_%d" % layer_idx): + with tf.compat.v1.variable_scope("layer_%d" % layer_idx): layer_input = prev_output - with tf.variable_scope("attention"): + with tf.compat.v1.variable_scope("attention"): attention_heads = [] - with tf.variable_scope("self"): + with tf.compat.v1.variable_scope("self"): attention_head = attention_layer( from_tensor=layer_input, to_tensor=layer_input, @@ -866,8 +865,8 @@ def transformer_model(input_tensor, # Run a linear projection of `hidden_size` then add a residual # with `layer_input`. - with tf.variable_scope("output"): - attention_output = tf.layers.dense( + with tf.compat.v1.variable_scope("output"): + attention_output = tf.compat.v1.layers.dense( attention_output, hidden_size, kernel_initializer=create_initializer(initializer_range)) @@ -875,16 +874,16 @@ def transformer_model(input_tensor, attention_output = layer_norm(attention_output + layer_input) # The activation is only applied to the "intermediate" hidden layer. - with tf.variable_scope("intermediate"): - intermediate_output = tf.layers.dense( + with tf.compat.v1.variable_scope("intermediate"): + intermediate_output = tf.compat.v1.layers.dense( attention_output, intermediate_size, activation=intermediate_act_fn, kernel_initializer=create_initializer(initializer_range)) # Down-project back to `hidden_size` then add the residual. - with tf.variable_scope("output"): - layer_output = tf.layers.dense( + with tf.compat.v1.variable_scope("output"): + layer_output = tf.compat.v1.layers.dense( intermediate_output, hidden_size, kernel_initializer=create_initializer(initializer_range)) @@ -991,7 +990,7 @@ def assert_rank(tensor, expected_rank, name=None): actual_rank = tensor.shape.ndims if actual_rank not in expected_rank_dict: - scope_name = tf.get_variable_scope().name + scope_name = tf.compat.v1.get_variable_scope().name raise ValueError( "For the tensor `%s` in scope `%s`, the actual rank " "`%d` (shape = %s) is not equal to the expected rank `%s`" % diff --git a/server/bert_serving/server/bert/optimization.py b/server/bert_serving/server/bert/optimization.py index 6512336a7..d58da9ea9 100644 --- a/server/bert_serving/server/bert/optimization.py +++ b/server/bert_serving/server/bert/optimization.py @@ -64,9 +64,9 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) if use_tpu: - optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) + optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer) - tvars = tf.trainable_variables() + tvars = tf.compat.v1.trainable_variables() grads = tf.gradients(loss, tvars) # This is how the model was pre-trained. @@ -110,13 +110,13 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None): param_name = self._get_variable_name(param.name) - m = tf.get_variable( + m = tf.compat.v1.get_variable( name=param_name + "/adam_m", shape=param.shape.as_list(), dtype=tf.float32, trainable=False, initializer=tf.zeros_initializer()) - v = tf.get_variable( + v = tf.compat.v1.get_variable( name=param_name + "/adam_v", shape=param.shape.as_list(), dtype=tf.float32, diff --git a/server/bert_serving/server/bert/tokenization.py b/server/bert_serving/server/bert/tokenization.py index 2fb527f14..d0ca50c82 100644 --- a/server/bert_serving/server/bert/tokenization.py +++ b/server/bert_serving/server/bert/tokenization.py @@ -72,7 +72,7 @@ def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() index = 0 - with tf.gfile.GFile(vocab_file, "r") as reader: + with tf.io.gfile.GFile(vocab_file, "r") as reader: while True: token = convert_to_unicode(reader.readline()) if not token: diff --git a/server/bert_serving/server/cli/__init__.py b/server/bert_serving/server/cli/__init__.py index 69f123d71..9105a268b 100644 --- a/server/bert_serving/server/cli/__init__.py +++ b/server/bert_serving/server/cli/__init__.py @@ -1,3 +1,7 @@ +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() + + def main(): from bert_serving.server import BertServer from bert_serving.server.helper import get_run_args diff --git a/server/bert_serving/server/graph.py b/server/bert_serving/server/graph.py index 3d68bd5f3..29432041e 100644 --- a/server/bert_serving/server/graph.py +++ b/server/bert_serving/server/graph.py @@ -44,7 +44,7 @@ def optimize_graph(args, logger=None): tf = import_tf(verbose=args.verbose) from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference - config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True) + config = tf.compat.v1.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True) config_fp = os.path.join(args.model_dir, args.config_name) init_checkpoint = os.path.join(args.tuned_model_dir or args.model_dir, args.ckpt_name) @@ -56,16 +56,16 @@ def optimize_graph(args, logger=None): logger.info( 'checkpoint%s: %s' % ( ' (override by the fine-tuned model)' if args.tuned_model_dir else '', init_checkpoint)) - with tf.gfile.GFile(config_fp, 'r') as f: + with tf.io.gfile.GFile(config_fp, 'r') as f: bert_config = modeling.BertConfig.from_dict(json.load(f)) logger.info('build graph...') # input placeholders, not sure if they are friendly to XLA - input_ids = tf.placeholder(tf.int32, (None, None), 'input_ids') - input_mask = tf.placeholder(tf.int32, (None, None), 'input_mask') - input_type_ids = tf.placeholder(tf.int32, (None, None), 'input_type_ids') + input_ids = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_ids') + input_mask = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_mask') + input_type_ids = tf.compat.v1.placeholder(tf.int32, (None, None), 'input_type_ids') - jit_scope = tf.contrib.compiler.jit.experimental_jit_scope if args.xla else contextlib.suppress + jit_scope = tf.xla.experimental.jit_scope if args.xla else contextlib.suppress with jit_scope(): input_tensors = [input_ids, input_mask, input_type_ids] @@ -81,28 +81,28 @@ def optimize_graph(args, logger=None): if args.pooling_strategy == PoolingStrategy.CLASSIFICATION: hidden_size = model.pooled_output.shape[-1].value - output_weights = tf.get_variable( + output_weights = tf.compat.v1.get_variable( 'output_weights', [args.num_labels, hidden_size], - initializer=tf.truncated_normal_initializer(stddev=0.02)) + initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02)) - output_bias = tf.get_variable( + output_bias = tf.compat.v1.get_variable( 'output_bias', [args.num_labels], initializer=tf.zeros_initializer()) if args.pooling_strategy == PoolingStrategy.REGRESSION: hidden_size = model.pooled_output.shape[-1].value - output_weights = tf.get_variable( + output_weights = tf.compat.v1.get_variable( 'output_weights', [1, hidden_size], - initializer=tf.truncated_normal_initializer(stddev=0.02)) + initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.02)) - output_bias = tf.get_variable( + output_bias = tf.compat.v1.get_variable( 'output_bias', [1], initializer=tf.zeros_initializer()) - tvars = tf.trainable_variables() + tvars = tf.compat.v1.trainable_variables() (assignment_map, initialized_variable_names ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) - tf.train.init_from_checkpoint(init_checkpoint, assignment_map) + tf.compat.v1.train.init_from_checkpoint(init_checkpoint, assignment_map) minus_mask = lambda x, m: x - tf.expand_dims(1.0 - m, axis=-1) * 1e30 mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1) @@ -110,7 +110,7 @@ def optimize_graph(args, logger=None): masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / ( tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10) - with tf.variable_scope("pooling"): + with tf.compat.v1.variable_scope("pooling"): if len(args.pooling_layer) == 1: encoder_layer = model.all_encoder_layers[args.pooling_layer[0]] else: @@ -156,12 +156,12 @@ def optimize_graph(args, logger=None): pooled = tf.identity(pooled, 'final_encodes') output_tensors = [pooled] - tmp_g = tf.get_default_graph().as_graph_def() + tmp_g = tf.compat.v1.get_default_graph().as_graph_def() - with tf.Session(config=config) as sess: + with tf.compat.v1.Session(config=config) as sess: logger.info('load parameters from checkpoint...') - sess.run(tf.global_variables_initializer()) + sess.run(tf.compat.v1.global_variables_initializer()) dtypes = [n.dtype for n in input_tensors] logger.info('optimize...') tmp_g = optimize_for_inference( @@ -177,7 +177,7 @@ def optimize_graph(args, logger=None): tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.graph_tmp_dir).name logger.info('write graph to a tmp file: %s' % tmp_file) - with tf.gfile.GFile(tmp_file, 'wb') as f: + with tf.io.gfile.GFile(tmp_file, 'wb') as f: f.write(tmp_g.SerializeToString()) return tmp_file, bert_config except Exception: