Support PyTensor deterministic operations as observations #7656

Open
wants to merge 14 commits into base: main
11 changes: 2 additions & 9 deletions pymc/data.py
@@ -30,14 +30,13 @@
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Variable
 from pytensor.raise_op import Assert
-from pytensor.scalar import Cast
-from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.basic import IntegersRV
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable

 import pymc as pm

+from pymc.logprob.utils import rvs_in_graph
 from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
 from pymc.vartypes import isgenerator

@@ -158,13 +157,7 @@ def is_valid_observed(v) -> bool:
         return True

     return (
-        # The only PyTensor operation we allow on observed data is type casting
-        # Although we could allow for any graph that does not depend on other RVs
-        (
-            isinstance(v.owner.op, Elemwise)
-            and isinstance(v.owner.op.scalar_op, Cast)
-            and is_valid_observed(v.owner.inputs[0])
-        )
+        not rvs_in_graph(v)
         # Or Minibatch
         or (
             isinstance(v.owner.op, MinibatchOp)
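An illustrative sketch (not part of the diff) of what the relaxed check means in practice: any observed graph is now accepted as long as no random variables appear in it, which is what rvs_in_graph detects. The variable names below are made up for the example.

import numpy as np
import pymc as pm
import pytensor.tensor as pt

from pymc.logprob.utils import rvs_in_graph

data = pt.as_tensor(np.arange(6.0))

# A purely deterministic transform of the data contains no RVs, so it is a valid observation.
standardized = (data - data.mean()) / data.std()
assert not rvs_in_graph(standardized)

# A graph that mixes in a random variable is still rejected as an observation.
noisy = data + pm.Normal.dist(0, 1)
assert rvs_in_graph(noisy)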
6 changes: 6 additions & 0 deletions pymc/pytensorf.py
@@ -177,6 +177,12 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
         mask[mask_idx] = 1
         return np.ma.MaskedArray(array_data, mask)

+    from pymc.logprob.utils import rvs_in_graph
+
+    if not inputvars(x) and not rvs_in_graph(x):
+        cheap_eval_mode = Mode(linker="py", optimizer=None)
+        return x.eval(mode=cheap_eval_mode)
+
     raise TypeError(f"Data cannot be extracted from {x}")

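A hedged usage sketch of the new branch, mirroring the test added in tests/test_pytensorf.py below: a graph built only from constants and deterministic operations is evaluated eagerly in a cheap Python mode instead of raising TypeError.

import numpy as np
import pytensor.tensor as pt

from pymc.pytensorf import extract_obs_data

data = np.array([10.0, 11.0, 12.0])
scaled = pt.as_tensor(data) / 12  # no free inputs and no RVs in the graph

res = extract_obs_data(scaled)  # evaluated with Mode(linker="py", optimizer=None)
np.testing.assert_allclose(res, data / 12)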
19 changes: 15 additions & 4 deletions tests/test_data.py
@@ -602,11 +602,11 @@ def test_allowed(self):
         mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20)
         assert isinstance(mb.owner.op, MinibatchOp)

-        with pytest.raises(ValueError, match="not valid for Minibatch"):
-            pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
+        mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
+        assert isinstance(mb.owner.op, MinibatchOp)

-        with pytest.raises(ValueError, match="not valid for Minibatch"):
-            pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
+        for mb in pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20):
+            assert isinstance(mb.owner.op, MinibatchOp)

     def test_assert(self):
         d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
@@ -623,3 +623,14 @@ def test_multiple_vars(self):
         [draw_mA, draw_mB] = pm.draw([mA, mB])
         assert draw_mA.shape == (10,)
         np.testing.assert_allclose(draw_mA, -draw_mB)
+
+
+def test_scaling_data_works_in_likelihood() -> None:
+    data = np.array([10, 11, 12, 13, 14, 15])
+
+    with pm.Model():
+        target = pm.Data("target", data)
+        scale = 12
+        scaled_target = target / scale
+        mu = pm.Normal("mu", mu=0, sigma=1)
+        pm.Normal("x", mu=mu, sigma=1, observed=scaled_target)
Contributor Author:
Should I sample this to check that it has the correct data in the InferenceData?

Member:
No, we have more direct ways of testing it.

Member:
Well, maybe. Just make sure to do a cheap sampling run, since we don't care about the draws at all.

Contributor Author:
I've tested the "extract_..." function directly.

Member:
Yeah, that's a more direct unit test. This would be more of a CI-level check: we don't care how it's done, we just want to be sure the data is there in the end.

Member:
I still want some sampling here to check the outputs. This test is not asserting anything explicitly at the moment.

Contributor Author:
I will add sampling to this test when I can.

Contributor Author:
Done.
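For reference, a hedged sketch of what such a cheap sampling check could look like (the exact assertion that was committed may differ); it assumes the evaluated scaled values land in idata.observed_data under the observed variable's name:

import numpy as np
import pymc as pm

data = np.array([10, 11, 12, 13, 14, 15])
scale = 12

with pm.Model():
    target = pm.Data("target", data)
    mu = pm.Normal("mu", mu=0, sigma=1)
    pm.Normal("x", mu=mu, sigma=1, observed=target / scale)
    # Cheap sampling: we only care that the observed data is recorded, not about the draws.
    idata = pm.sample(draws=10, tune=10, chains=1, progressbar=False)

np.testing.assert_allclose(idata.observed_data["x"].values, data / scale)

Checking idata.observed_data like this (or calling extract_obs_data directly, as done in tests/test_pytensorf.py) keeps the test cheap while still verifying that the deterministic observation ends up in the InferenceData.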

15 changes: 15 additions & 0 deletions tests/test_pytensorf.py
@@ -199,6 +199,21 @@ def test_minibatch_variable(self):
         assert isinstance(res, np.ndarray)
         np.testing.assert_array_equal(res, y)

+    def test_pytensor_operations(self):
+        x = np.array([1, 2, 3])
+        target = 1 + 3 * pt.as_tensor_variable(x)
+
+        res = extract_obs_data(target)
+        assert isinstance(res, np.ndarray)
+        np.testing.assert_array_equal(res, np.array([4, 7, 10]))
+
+    def test_pytensor_operations_raises(self):
+        x = pt.scalar("x")
+        target = 1 + 3 * x
+
+        with pytest.raises(TypeError, match="Data cannot be extracted from"):
+            extract_obs_data(target)


@pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
def test_convert_data(input_dtype):