diff --git a/flink-python/pyflink/datastream/functions.py b/flink-python/pyflink/datastream/functions.py index db4366390e1ab..bc6480032a5c3 100644 --- a/flink-python/pyflink/datastream/functions.py +++ b/flink-python/pyflink/datastream/functions.py @@ -19,7 +19,7 @@ from enum import Enum from py4j.java_gateway import JavaObject -from typing import Union, Any, Generic, TypeVar, Iterable, List, Callable, Optional +from typing import Union, Any, Generic, TypeVar, Iterable, Iterator, List, Callable, Optional from pyflink.datastream.state import ValueState, ValueStateDescriptor, ListStateDescriptor, \ ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, ReducingState, \ @@ -203,14 +203,14 @@ class Function(ABC): """ The base class for all user-defined functions. """ - def open(self, runtime_context: RuntimeContext): + def open(self, runtime_context: RuntimeContext) -> None: pass - def close(self): + def close(self) -> None: pass -class MapFunction(Function): +class MapFunction(Function, Generic[IN, OUT]): """ Base class for Map functions. Map functions take elements and transform them, element wise. A Map function always produces a single result element for each input element. Typical @@ -225,7 +225,7 @@ class MapFunction(Function): """ @abstractmethod - def map(self, value): + def map(self, value: IN) -> OUT: """ The mapping method. Takes an element from the input data and transforms it into exactly one element. @@ -236,7 +236,7 @@ def map(self, value): pass -class CoMapFunction(Function): +class CoMapFunction(Function, Generic[IN, OUT]): """ A CoMapFunction implements a map() transformation over two connected streams. @@ -252,7 +252,7 @@ class CoMapFunction(Function): """ @abstractmethod - def map1(self, value): + def map1(self, value: IN) -> OUT: """ This method is called for each element in the first of the connected streams. 
@@ -262,7 +262,7 @@ def map1(self, value): pass @abstractmethod - def map2(self, value): + def map2(self, value: IN) -> OUT: """ This method is called for each element in the second of the connected streams. @@ -278,7 +278,7 @@ class FlatMapFunction(Function): one, or more elements. Typical applications can be splitting elements, or unnesting lists and arrays. Operations that produce multiple strictly one result element per input element can also use the MapFunction. - The basic syntax for using a MapFUnction is as follows: + The basic syntax for using a MapFunction is as follows: :: >>> ds = ... @@ -286,17 +286,17 @@ class FlatMapFunction(Function): """ @abstractmethod - def flat_map(self, value): + def flat_map(self, value: IN) -> Iterator[OUT]: """ - The core mthod of the FlatMapFunction. Takes an element from the input data and transforms + The core method of the FlatMapFunction. Takes an element from the input data and transforms it into zero, one, or more elements. A basic implementation of flat map is as follows: :: >>> class MyFlatMapFunction(FlatMapFunction): - >>> def flat_map(self, value): - >>> for i in range(value): - >>> yield i + ... def flat_map(self, value: IN) -> Iterator[OUT]: + ... for i in range(value): + ... yield i :param value: The input value. :return: A generator @@ -336,7 +336,7 @@ class CoFlatMapFunction(Function): """ @abstractmethod - def flat_map1(self, value): + def flat_map1(self, value: IN) -> Iterator[OUT]: """ This method is called for each element in the first of the connected streams. @@ -346,7 +346,7 @@ def flat_map1(self, value): pass @abstractmethod - def flat_map2(self, value): + def flat_map2(self, value: IN) -> Iterator[OUT]: """ This method is called for each element in the second of the connected streams. 
@@ -371,7 +371,7 @@ class ReduceFunction(Function): """ @abstractmethod - def reduce(self, value1, value2): + def reduce(self, value1: IN, value2: IN) -> IN: """ The core method of ReduceFunction, combining two values into one value of the same type. The reduce function is consecutively applied to all values of a group until only a single @@ -461,7 +461,7 @@ def merge(self, acc_a, acc_b): pass -class KeySelector(Function): +class KeySelector(Function, Generic[IN, KEY]): """ The KeySelector allows to use deterministic objects for operations such as reduce, reduceGroup, join coGroup, etc. If invoked multiple times on the same object, the returned key must be the @@ -469,7 +469,7 @@ class KeySelector(Function): """ @abstractmethod - def get_key(self, value): + def get_key(self, value: IN) -> KEY: """ User-defined function that deterministically extracts the key from an object. @@ -505,7 +505,7 @@ class FilterFunction(Function): """ @abstractmethod - def filter(self, value): + def filter(self, value: IN) -> bool: """ The filter function that evaluates the predicate. @@ -655,7 +655,7 @@ def timestamp(self) -> int: pass @abstractmethod - def process_element(self, value, ctx: 'ProcessFunction.Context'): + def process_element(self, value: IN, ctx: 'ProcessFunction.Context') -> Iterator[OUT]: """ Process one element from the input stream. @@ -716,7 +716,7 @@ def time_domain(self) -> TimeDomain: pass @abstractmethod - def process_element(self, value, ctx: 'KeyedProcessFunction.Context'): + def process_element(self, value: IN, ctx: 'KeyedProcessFunction.Context') -> Iterator[OUT]: """ Process one element from the input stream. @@ -780,7 +780,7 @@ def timestamp(self) -> int: pass @abstractmethod - def process_element1(self, value, ctx: 'CoProcessFunction.Context'): + def process_element1(self, value: IN, ctx: 'CoProcessFunction.Context') -> Iterator[OUT]: """ This method is called for each element in the first of the connected streams. 
@@ -795,7 +795,7 @@ def process_element1(self, value, ctx: 'CoProcessFunction.Context'):
         pass
 
     @abstractmethod
-    def process_element2(self, value, ctx: 'CoProcessFunction.Context'):
+    def process_element2(self, value: IN, ctx: 'CoProcessFunction.Context') -> Iterator[OUT]:
         """
         This method is called for each element in the second of the connected streams.
 
diff --git a/flink-python/pyflink/datastream/tests/test_typing.py b/flink-python/pyflink/datastream/tests/test_typing.py
new file mode 100644
index 0000000000000..316fedbfeb2ab
--- /dev/null
+++ b/flink-python/pyflink/datastream/tests/test_typing.py
@@ -0,0 +1,63 @@
+"""Smoke tests for the type-annotated user-defined function base classes."""
+import unittest
+from typing import Iterator
+
+from pyflink.datastream.functions import (
+    MapFunction, FlatMapFunction, FilterFunction, KeySelector
+)
+
+
+class TypedMapFunction(MapFunction[int, str]):
+    """Map an int to its decimal string; MapFunction is Generic[IN, OUT]."""
+
+    def map(self, value: int) -> str:
+        return str(value)
+
+
+# FlatMapFunction and FilterFunction are plain (non-Generic) classes, so
+# subscripting them (e.g. FlatMapFunction[int, str]) raises TypeError as soon
+# as this module is imported. Subclass them unparameterised instead.
+class TypedFlatMapFunction(FlatMapFunction):
+    """Emit the decimal string of the int input exactly once."""
+
+    def flat_map(self, value: int) -> Iterator[str]:
+        yield str(value)
+
+
+class TypedFilterFunction(FilterFunction):
+    """Keep strictly positive ints."""
+
+    def filter(self, value: int) -> bool:
+        return value > 0
+
+
+class TypedKeySelector(KeySelector[int, str]):
+    """Key an int by its decimal string; KeySelector is Generic[IN, KEY]."""
+
+    def get_key(self, value: int) -> str:
+        return str(value)
+
+
+class TestGenericTyping(unittest.TestCase):
+    """Check that the typed subclasses can be defined, instantiated and called."""
+
+    def test_map_function_generics(self):
+        f = TypedMapFunction()
+        self.assertEqual(f.map(1), "1")
+
+    def test_flat_map_function_generics(self):
+        f = TypedFlatMapFunction()
+        self.assertEqual(list(f.flat_map(1)), ["1"])
+
+    def test_filter_function_generics(self):
+        f = TypedFilterFunction()
+        self.assertTrue(f.filter(1))
+        self.assertFalse(f.filter(-1))
+
+    def test_key_selector_generics(self):
+        f = TypedKeySelector()
+        self.assertEqual(f.get_key(1), "1")
+
+
+if __name__ == '__main__':
+    unittest.main()