Tensorflow no puede obtener `image.shape` del método en` dataset.map (mapFn) `


10

Estoy tratando de hacer el tensorflowequivalente de torch.transforms.Resize(TRAIN_IMAGE_SIZE), que cambia el tamaño de la dimensión de imagen más pequeñaTRAIN_IMAGE_SIZE . Algo como esto

def transforms(filename):
  parts = tf.strings.split(filename, '/')
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)

  # this doesn't work with Dataset.map() because image.shape=(None,None,3) from Dataset.map()
  image = largest_sq_crop(image) 

  image = tf.image.resize(image, (256,256))
  return image, label

list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)

La respuesta simple está aquí: Tensorflow: recorta la región cuadrada central más grande de la imagen

Pero cuando uso el método con tf.data.Dataset.map(transforms), me sale shape=(None,None,3)de adentro largest_sq_crop(image). El método funciona bien cuando lo llamo normalmente.


1
Creo que el problema tiene que ver con el hecho de que EagerTensorsno están disponibles dentro, Dataset.map()por lo que se desconoce la forma. ¿hay alguna solución?
Michael

¿Puedes incluir la definición de largest_sq_crop?
jakub

Respuestas:


1

Encontré la respuesta. Tenía que ver con el hecho de que mi método de cambio de tamaño funcionaba bien con una ejecución ansiosa, por ejemplo, tf.executing_eagerly()==Truepero fallaba cuando se usaba dentro dataset.map(). Al parecer, en ese entorno de ejecución, tf.executing_eagerly()==False.

Mi error fue en la forma en que estaba desempacando la forma de la imagen para obtener dimensiones para escalar. La ejecución del gráfico de Tensorflow no parece admitir el acceso a la tensor.shapetupla.

  # wrong
  b,h,w,c = img.shape
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # also wrong
  b = img.shape[0]
  h = img.shape[1]
  w = img.shape[2]
  c = img.shape[3]
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # but this works!!!
  shape = tf.shape(img)
  b = shape[0]
  h = shape[1]
  w = shape[2]
  c = shape[3]
  img = tf.reshape( img, (-1,h,w,c))
  print("OK> ", h,w,c)
  # OK>  Tensor("strided_slice_2:0", shape=(), dtype=int32) Tensor("strided_slice_3:0", shape=(), dtype=int32) Tensor("strided_slice_4:0", shape=(), dtype=int32)

Estaba usando dimensiones de forma aguas abajo en mi dataset.map()función y arrojó la siguiente excepción porque estaba obteniendo en Nonelugar de un valor.

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (-1, None, None, 3). Consider casting elements to a supported type.

Cuando cambié a desempaquetar manualmente la forma tf.shape(), todo funcionó bien.

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.